GestureRecognitionToolkit  Version: 1.0 Revision: 04-03-15
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DTW.h
Go to the documentation of this file.
1 
43 #ifndef GRT_DTW_HEADER
44 #define GRT_DTW_HEADER
45 
46 #include "../../CoreModules/Classifier.h"
47 #include "../../Util/TimeSeriesClassificationSampleTrimmer.h"
48 
49 namespace GRT{
50 
/**
 Plain value type that groups two integer indices (x, y) with a distance value.
 It is the element type of the warping paths handled by the DTW class
 (vector< IndexDist >).

 The type owns no resources, so copy construction, copy assignment and
 destruction are left to the compiler (member-wise copy), which behaves exactly
 like the previous hand-written versions.
*/
class IndexDist{
public:
    /**
     Builds an IndexDist from the given values; every argument defaults to zero.

     @param x: the first index
     @param y: the second index
     @param dist: the distance value associated with the (x,y) pair
    */
    IndexDist(int x=0,int y=0,double dist=0) : x(x), y(y), dist(dist) {}

    int x;          //The first index
    int y;          //The second index
    double dist;    //The distance value stored for (x,y)
};
72 
75 public:
76  DTWTemplate(){
77  classLabel = 0;
78  trainingMu = 0.0;
79  trainingSigma = 0.0;
80  averageTemplateLength=0;
81  }
82  ~DTWTemplate(){};
83 
84  UINT classLabel; //The class that this template belongs to
85  MatrixDouble timeSeries; //The raw time series
86  double trainingMu; //The mean distance value of the training data with the trained template
87  double trainingSigma; //The sigma of the distance value of the training data with the trained template
88  UINT averageTemplateLength; //The average length of the examples used to train this template
89 };
90 
91 class DTW : public Classifier
92 {
93 public:
94 
108  DTW(bool useScaling=false,bool useNullRejection=false,double nullRejectionCoeff=3.0,UINT rejectionMode = DTW::TEMPLATE_THRESHOLDS,bool dtwConstrain=true,double radius=0.2,bool offsetUsingFirstSample=false,bool useSmoothing = false,UINT smoothingFactor = 5);
109 
117  DTW(const DTW &rhs);
118 
122  virtual ~DTW(void);
123 
130  DTW& operator=(const DTW &rhs);
131 
139  virtual bool deepCopyFrom(const Classifier *classifier);
140 
148  virtual bool train_(TimeSeriesClassificationData &trainingData);
149 
157  virtual bool predict_(VectorDouble &inputVector);
158 
166  virtual bool predict_(MatrixDouble &timeSeries);
167 
173  virtual bool reset();
174 
181  virtual bool clear();
182 
190  virtual bool saveModelToFile(fstream &file) const;
191 
199  virtual bool loadModelFromFile(fstream &file);
200 
208  virtual bool recomputeNullRejectionThresholds();
209 
215  UINT getNumTemplates(){ return numTemplates; }
216 
222  bool setRejectionMode(UINT rejectionMode);
223 
232  bool setOffsetTimeseriesUsingFirstSample(bool offsetUsingFirstSample);
233 
240  bool setContrainWarpingPath(bool constrain);
241 
250  bool setWarpingRadius(double radius);
251 
257  UINT getRejectionMode(){ return rejectionMode; }
258 
266  bool enableZNormalization(bool useZNormalization,bool constrainZNorm = true);
267 
284  bool enableTrimTrainingData(bool trimTrainingData,double trimThreshold,double maximumTrimPercentage);
285 
291  vector< DTWTemplate > getModels(){ return templatesBuffer; }
292 
298  bool setModels( vector< DTWTemplate > newTemplates );
299 
305  vector< VectorDouble > getInputDataBuffer(){ return continuousInputDataBuffer.getDataAsVector(); }
306 
312  vector< MatrixDouble > getDistanceMatrices(){ return distanceMatrices; }
313 
319  vector< vector< IndexDist > > getWarpingPaths(){ return warpPaths; }
320 
321  //Tell the compiler we are using the base class train method to stop hidden virtual function warnings
324  using MLBase::train;
325  using MLBase::predict;
326 
327 private:
328  //Public training and prediction methods
329  bool train_NDDTW(TimeSeriesClassificationData &trainingData,DTWTemplate &dtwTemplate,UINT &bestIndex);
330 
331  //The actual DTW function
332  double computeDistance(MatrixDouble &timeSeriesA,MatrixDouble &timeSeriesB,MatrixDouble &distanceMatrix,vector< IndexDist > &warpPath);
333  double d(int m,int n,MatrixDouble &distanceMatrix,const int M,const int N);
334  double inline MIN_(double a,double b, double c);
335 
336  //Private Scaling and Utility Functions
337  void scaleData(TimeSeriesClassificationData &trainingData);
338  void scaleData(MatrixDouble &data,MatrixDouble &scaledData);
339  void znormData(TimeSeriesClassificationData &trainingData);
340  void znormData(MatrixDouble &data,MatrixDouble &normData);
341  void smoothData(VectorDouble &data,UINT smoothFactor,VectorDouble &resultsData);
342  void smoothData(MatrixDouble &data,UINT smoothFactor,MatrixDouble &resultsData);
343  void offsetTimeseries(MatrixDouble &timeseries);
344 
345  static RegisterClassifierModule< DTW > registerModule;
346 
347 protected:
348  bool loadLegacyModelFromFile( fstream &file );
349 
350  vector< DTWTemplate > templatesBuffer; //A buffer to store the templates for each time series
351  vector< MatrixDouble > distanceMatrices;
352  vector< vector< IndexDist > > warpPaths;
353  CircularBuffer< VectorDouble > continuousInputDataBuffer;
354  UINT numTemplates; //The number of templates in our buffer
355  UINT rejectionMode; //The rejection mode used to reject null gestures during the prediction phase
356 
357  //Flags
358  bool useSmoothing; //A flag to check if we need to smooth the data
359  bool useZNormalisation; //A flag to check if we need to znorm the training and prediction data
360  bool offsetUsingFirstSample; //A flag to check if each timeseries should be offset by the first sample in the time series
361  bool constrainZNorm; //A flag to check if we need to constrain zNorm (only zNorm if stdDev > zNormConstrainThreshold)
362  bool constrainWarpingPath; //A flag to check if we need to constrain the dtw cost matrix and search
363  bool trimTrainingData; //A flag to check if we need to trim the training data first before training
364 
365  double zNormConstrainThreshold;//The threshold value to be used if constrainZNorm is turned on
366  double radius;
367  double trimThreshold; //Sets the threshold under which training data should be trimmed (default 0.1)
368  double maximumTrimPercentage; //Sets the maximum amount of data that can be trimmed for each training sample (default 20)
369 
370  UINT smoothingFactor; //The smoothing factor if smoothing is used
371  UINT distanceMethod; //The distance method to be used (should be of enum DISTANCE_METHOD)
372  UINT averageTemplateLength; //The overall average template length (over all the templates)
373 
374 public:
375  enum DistanceMethods{ABSOLUTE_DIST=0,EUCLIDEAN_DIST,NORM_ABSOLUTE_DIST};
376  enum RejectionModes{TEMPLATE_THRESHOLDS=0,CLASS_LIKELIHOODS,THRESHOLDS_AND_LIKELIHOODS};
377 
378 };
379 
380 }//End of namespace GRT
381 
382 #endif //GRT_DTW_HEADER
virtual bool saveModelToFile(string filename) const
Definition: MLBase.cpp:135
virtual bool reset()
Definition: DTW.cpp:505
bool setRejectionMode(UINT rejectionMode)
Definition: DTW.cpp:1219
bool enableZNormalization(bool useZNormalization, bool constrainZNorm=true)
Definition: DTW.cpp:1242
virtual bool loadModelFromFile(string filename)
Definition: MLBase.cpp:157
Definition: AdaBoost.cpp:25
DTW(bool useScaling=false, bool useNullRejection=false, double nullRejectionCoeff=3.0, UINT rejectionMode=DTW::TEMPLATE_THRESHOLDS, bool dtwConstrain=true, double radius=0.2, bool offsetUsingFirstSample=false, bool useSmoothing=false, UINT smoothingFactor=5)
Definition: DTW.cpp:30
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:80
Definition: DTW.h:91
virtual ~DTW(void)
Definition: DTW.cpp:70
bool setOffsetTimeseriesUsingFirstSample(bool offsetUsingFirstSample)
Definition: DTW.cpp:1227
virtual bool deepCopyFrom(const Classifier *classifier)
Definition: DTW.cpp:105
virtual bool predict(VectorDouble inputVector)
Definition: MLBase.cpp:104
UINT getNumTemplates()
Definition: DTW.h:215
vector< MatrixDouble > getDistanceMatrices()
Definition: DTW.h:312
DTW & operator=(const DTW &rhs)
Definition: DTW.cpp:74
bool setContrainWarpingPath(bool constrain)
Definition: DTW.cpp:1232
UINT getRejectionMode()
Definition: DTW.h:257
virtual bool loadModelFromFile(fstream &file)
Definition: DTW.cpp:994
virtual bool clear()
Definition: DTW.cpp:514
vector< DTWTemplate > getModels()
Definition: DTW.h:291
vector< vector< IndexDist > > getWarpingPaths()
Definition: DTW.h:319
vector< T > getDataAsVector() const
bool setWarpingRadius(double radius)
Definition: DTW.cpp:1237
virtual bool predict_(VectorDouble &inputVector)
Definition: DTW.cpp:466
virtual bool recomputeNullRejectionThresholds()
Definition: DTW.cpp:528
vector< VectorDouble > getInputDataBuffer()
Definition: DTW.h:305
bool setModels(vector< DTWTemplate > newTemplates)
Definition: DTW.cpp:543
bool enableTrimTrainingData(bool trimTrainingData, double trimThreshold, double maximumTrimPercentage)
Definition: DTW.cpp:1248
virtual bool train_(TimeSeriesClassificationData &trainingData)
Definition: DTW.cpp:140
virtual bool saveModelToFile(fstream &file) const
Definition: DTW.cpp:933