GestureRecognitionToolkit Version: 1.0 Revision: 04-03-15
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
SVM.h
Go to the documentation of this file.
1 
36 #ifndef GRT_SVM_HEADER
37 #define GRT_SVM_HEADER
38 
39 #include "../../CoreModules/Classifier.h"
40 #include "LIBSVM/libsvm.h"
41 
42 namespace GRT {
43 
44 using namespace LIBSVM;
45 
// Lower/upper bounds of the SVM feature-scaling range.
// NOTE(review): presumably the target range used when useScaling is enabled —
// confirm against SVM.cpp.
// The replacement lists are parenthesized so each macro always expands to a
// single operand, independent of the surrounding expression.
// Macros ignore namespaces: these names are global even though they are
// defined inside namespace GRT.
#define SVM_MIN_SCALE_RANGE (-1.0)
#define SVM_MAX_SCALE_RANGE (1.0)
48 
49 class SVM : public Classifier{
50 public:
69  SVM(UINT kernelType = LINEAR_KERNEL,UINT svmType = C_SVC,bool useScaling = true,bool useNullRejection = false,bool useAutoGamma = true,double gamma = 0.1,UINT degree = 3,double coef0 = 0,double nu = 0.5,double C = 1,bool useCrossValidation = false,UINT kFoldValue = 10);
70 
76  SVM(const SVM &rhs);
77 
81  virtual ~SVM();
82 
89  SVM &operator=(const SVM &rhs);
90 
98  virtual bool deepCopyFrom(const Classifier *classifier);
99 
107  virtual bool train_(ClassificationData &trainingData);
108 
116  virtual bool predict_(VectorDouble &inputVector);
117 
121  virtual bool clear();
122 
130  virtual bool saveModelToFile(fstream &file) const;
131 
139  virtual bool loadModelFromFile(fstream &file);
140 
157  bool init(UINT kernelType,UINT svmType,bool useScaling,bool useNullRejection,bool useAutoGamma,double gamma,UINT degree,double coef0,double nu,double C,bool useCrossValidation,UINT kFoldValue);
158 
162  void initDefaultSVMSettings();
163 
169  bool getIsCrossValidationTrainingEnabled() const;
170 
177  bool getIsAutoGammaEnabled() const;
178 
186  string getSVMType() const;
187 
195  string getKernelType() const;
196 
202  UINT getDegree() const;
203 
211  virtual UINT getNumClasses() const;
212 
218  double getGamma() const;
219 
225  double getNu() const;
226 
232  double getCoef0() const;
233 
239  double getC() const;
240 
246  double getCrossValidationResult() const;
247 
248 
249 
250  struct svm_model *getModel() const { return model; }
251 
259  bool setSVMType(const UINT svmType);
260 
268  bool setKernelType(const UINT kernelType);
269 
276  bool setGamma(const double gamma);
277 
285  bool setDegree(const UINT degree);
286 
294  bool setNu(const double nu);
295 
303  bool setCoef0(const double coef0);
304 
312  bool setC(const double C);
313 
320  bool setKFoldCrossValidationValue(const UINT kFoldValue);
321 
328  bool enableAutoGamma(const bool useAutoGamma);
329 
336  bool enableCrossValidationTraining(const bool useCrossValidation);
337 
338  //Tell the compiler we are using the following functions from the MLBase class to stop hidden virtual function warnings
341  using MLBase::train;
342  using MLBase::train_;
343  using MLBase::predict;
344  using MLBase::predict_;
345 
346 protected:
347  void deleteProblemSet();
348  bool validateProblemAndParameters();
349  bool validateSVMType(UINT svmType);
350  bool validateKernelType(UINT kernelType);
351  bool convertClassificationDataToLIBSVMFormat(ClassificationData &trainingData);
352  bool trainSVM();
353 
354  bool predictSVM(VectorDouble &inputVector);
355  bool predictSVM(VectorDouble &inputVector,double &maxProbability, VectorDouble &probabilites);
356  bool loadLegacyModelFromFile( fstream &file );
357 
358  struct svm_model *deepCopyModel() const;
359  bool deepCopyProblem( const struct svm_problem &source_problem, struct svm_problem &target_problem, const unsigned int numInputDimensions ) const;
360  bool deepCopyParam( const svm_parameter &source_param, svm_parameter &target_param ) const;
361 
362  bool problemSet;
363  struct svm_model *model;
364  struct svm_parameter param;
365  struct svm_problem prob;
366  UINT kFoldValue;
367  double classificationThreshold;
368  double crossValidationResult;
369  bool useAutoGamma;
370  bool useCrossValidation;
371 
372  static RegisterClassifierModule< SVM > registerModule;
373 
374 public:
375 
376  enum SVMTypes{ C_SVC = 0, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR };
377  enum SVMKernelTypes{ LINEAR_KERNEL = 0, POLY_KERNEL, RBF_KERNEL, SIGMOID_KERNEL, PRECOMPUTED_KERNEL };
378 
379 };
380 
381 } //End of namespace GRT
382 
383 #endif //GRT_SVM_HEADER
384 
Definition: SVM.h:49
virtual bool saveModelToFile(string filename) const
Definition: MLBase.cpp:135
virtual bool loadModelFromFile(string filename)
Definition: MLBase.cpp:157
Definition: AdaBoost.cpp:25
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:80
virtual bool predict(VectorDouble inputVector)
Definition: MLBase.cpp:104
virtual bool predict_(VectorDouble &inputVector)
Definition: MLBase.cpp:106
Definition: libsvm.cpp:4
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:82