36 #ifndef GRT_SVM_HEADER
37 #define GRT_SVM_HEADER
39 #include "../../CoreModules/Classifier.h"
// Scaling range constants for the SVM. The names suggest the bounds that input features are scaled into
// when scaling is enabled — confirm against SVM.cpp.
// The negative bound is parenthesised so the macro always expands to a single primary expression,
// regardless of the operators surrounding the expansion site.
#define SVM_MIN_SCALE_RANGE (-1.0)
#define SVM_MAX_SCALE_RANGE 1.0
69 SVM(UINT kernelType = LINEAR_KERNEL,UINT svmType = C_SVC,
bool useScaling =
true,
bool useNullRejection =
false,
bool useAutoGamma =
true,
double gamma = 0.1,UINT degree = 3,
double coef0 = 0,
double nu = 0.5,
double C = 1,
bool useCrossValidation =
false,UINT kFoldValue = 10);
89 SVM &operator=(
const SVM &rhs);
98 virtual bool deepCopyFrom(
const Classifier *classifier);
116 virtual bool predict_(VectorDouble &inputVector);
121 virtual bool clear();
130 virtual bool saveModelToFile(fstream &file)
const;
139 virtual bool loadModelFromFile(fstream &file);
157 bool init(UINT kernelType,UINT svmType,
bool useScaling,
bool useNullRejection,
bool useAutoGamma,
double gamma,UINT degree,
double coef0,
double nu,
double C,
bool useCrossValidation,UINT kFoldValue);
/// Resets the SVM settings to the implementation's defaults; which values are used is defined in SVM.cpp, not here.
void initDefaultSVMSettings();
169 bool getIsCrossValidationTrainingEnabled()
const;
177 bool getIsAutoGammaEnabled()
const;
186 string getSVMType()
const;
195 string getKernelType()
const;
202 UINT getDegree()
const;
211 virtual UINT getNumClasses()
const;
218 double getGamma()
const;
225 double getNu()
const;
232 double getCoef0()
const;
246 double getCrossValidationResult()
const;
250 struct svm_model *getModel()
const {
return model; }
259 bool setSVMType(
const UINT svmType);
268 bool setKernelType(
const UINT kernelType);
/**
 @brief Sets the gamma parameter.
 @param gamma new value (its interaction with auto-gamma is defined in the implementation)
 @return bool success flag.
 */
bool setGamma(const double gamma);
285 bool setDegree(
const UINT degree);
/**
 @brief Sets the nu parameter.
 @param nu new value
 @return bool success flag.
 */
bool setNu(const double nu);
/**
 @brief Sets the coef0 parameter.
 @param coef0 new value
 @return bool success flag.
 */
bool setCoef0(const double coef0);
/**
 @brief Sets the C parameter.
 @param C new value
 @return bool success flag.
 */
bool setC(const double C);
320 bool setKFoldCrossValidationValue(
const UINT kFoldValue);
/**
 @brief Turns automatic gamma selection on or off (see getIsAutoGammaEnabled()).
 @param useAutoGamma new flag value
 @return bool success flag.
 */
bool enableAutoGamma(const bool useAutoGamma);
/**
 @brief Turns cross-validation training on or off (see getIsCrossValidationTrainingEnabled()).
 @param useCrossValidation new flag value
 @return bool success flag.
 */
bool enableCrossValidationTraining(const bool useCrossValidation);
/// Frees the internal libsvm problem data (per the name); what is released is defined in SVM.cpp.
void deleteProblemSet();
/// @return whether the current problem data and parameters pass validation (checks are defined in the implementation).
bool validateProblemAndParameters();
349 bool validateSVMType(UINT svmType);
350 bool validateKernelType(UINT kernelType);
354 bool predictSVM(VectorDouble &inputVector);
355 bool predictSVM(VectorDouble &inputVector,
double &maxProbability, VectorDouble &probabilites);
356 bool loadLegacyModelFromFile( fstream &file );
359 bool deepCopyProblem(
const struct svm_problem &source_problem,
struct svm_problem &target_problem,
const unsigned int numInputDimensions )
const;
// Threshold value; how it is applied (e.g. with null rejection) is not visible in this header — confirm in SVM.cpp.
double classificationThreshold;
// Cross-validation result; presumably the value getCrossValidationResult() returns.
double crossValidationResult;
// Cross-validation training flag; presumably read/written by getIsCrossValidationTrainingEnabled()/enableCrossValidationTraining().
bool useCrossValidation;
// SVM formulation selectors, numbered consecutively from 0 and passed around as UINT (e.g. the constructor's svmType).
// NOTE(review): the order appears to match libsvm's svm_type constants; confirm if values are forwarded directly.
enum SVMTypes{ C_SVC = 0, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR };
// Kernel selectors, numbered consecutively from 0 and passed around as UINT (e.g. the constructor's kernelType).
// NOTE(review): the order appears to match libsvm's kernel_type constants; confirm if values are forwarded directly.
enum SVMKernelTypes{ LINEAR_KERNEL = 0, POLY_KERNEL, RBF_KERNEL, SIGMOID_KERNEL, PRECOMPUTED_KERNEL };
383 #endif //GRT_SVM_HEADER
virtual bool saveModelToFile(string filename) const
virtual bool loadModelFromFile(string filename)
virtual bool train(ClassificationData trainingData)
virtual bool predict(VectorDouble inputVector)
virtual bool predict_(VectorDouble &inputVector)
virtual bool train_(ClassificationData &trainingData)