GestureRecognitionToolkit  Version: 1.0 Revision: 04-03-15
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
DecisionTree.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
21 #include "DecisionTree.h"
22 
23 using namespace GRT;
24 
25 //Register the DecisionTree module with the Classifier base class
26 RegisterClassifierModule< DecisionTree > DecisionTree::registerModule("DecisionTree");
27 
28 DecisionTree::DecisionTree(const DecisionTreeNode &decisionTreeNode,const UINT minNumSamplesPerNode,const UINT maxDepth,const bool removeFeaturesAtEachSpilt,const UINT trainingMode,const UINT numSplittingSteps,const bool useScaling)
29 {
30  this->tree = NULL;
31  this->decisionTreeNode = NULL;
32  this->minNumSamplesPerNode = minNumSamplesPerNode;
33  this->maxDepth = maxDepth;
34  this->removeFeaturesAtEachSpilt = removeFeaturesAtEachSpilt;
35  this->trainingMode = trainingMode;
36  this->numSplittingSteps = numSplittingSteps;
37  this->useScaling = useScaling;
38  this->supportsNullRejection = true;
39  Classifier::classType = "DecisionTree";
40  classifierType = Classifier::classType;
41  classifierMode = STANDARD_CLASSIFIER_MODE;
42  Classifier::debugLog.setProceedingText("[DEBUG DecisionTree]");
43  Classifier::errorLog.setProceedingText("[ERROR DecisionTree]");
44  Classifier::trainingLog.setProceedingText("[TRAINING DecisionTree]");
45  Classifier::warningLog.setProceedingText("[WARNING DecisionTree]");
46 
47  this->decisionTreeNode = decisionTreeNode.deepCopy();
48 
49 }
50 
52  tree = NULL;
53  decisionTreeNode = NULL;
54  Classifier::classType = "DecisionTree";
55  classifierType = Classifier::classType;
56  classifierMode = STANDARD_CLASSIFIER_MODE;
57  Classifier:: debugLog.setProceedingText("[DEBUG DecisionTree]");
58  Classifier::errorLog.setProceedingText("[ERROR DecisionTree]");
59  Classifier::trainingLog.setProceedingText("[TRAINING DecisionTree]");
60  Classifier::warningLog.setProceedingText("[WARNING DecisionTree]");
61  *this = rhs;
62 }
63 
65 {
66  clear();
67 
68  if( decisionTreeNode != NULL ){
69  delete decisionTreeNode;
70  decisionTreeNode = NULL;
71  }
72 }
73 
75  if( this != &rhs ){
76  //Clear this tree
77  clear();
78 
79  if( rhs.getTrained() ){
80  //Deep copy the tree
81  this->tree = (DecisionTreeNode*)rhs.deepCopyTree();
82  }
83 
84  //Deep copy the main node
85  if( this->decisionTreeNode != NULL ){
86  delete decisionTreeNode;
87  decisionTreeNode = NULL;
88  }
89  this->decisionTreeNode = rhs.deepCopyDecisionTreeNode();
90 
91  this->minNumSamplesPerNode = rhs.minNumSamplesPerNode;
92  this->maxDepth = rhs.maxDepth;
93  this->removeFeaturesAtEachSpilt = rhs.removeFeaturesAtEachSpilt;
94  this->trainingMode = rhs.trainingMode;
95  this->numSplittingSteps = rhs.numSplittingSteps;
96  this->nodeClusters = rhs.nodeClusters;
97 
98  //Copy the base classifier variables
99  copyBaseVariables( (Classifier*)&rhs );
100  }
101  return *this;
102 }
103 
104 bool DecisionTree::deepCopyFrom(const Classifier *classifier){
105 
106  if( classifier == NULL ) return false;
107 
108  if( this->getClassifierType() == classifier->getClassifierType() ){
109 
110  DecisionTree *ptr = (DecisionTree*)classifier;
111 
112  //Clear this tree
113  this->clear();
114 
115  if( ptr->getTrained() ){
116  //Deep copy the tree
117  this->tree = ptr->deepCopyTree();
118  }
119 
120  //Deep copy the main node
121  if( this->decisionTreeNode != NULL ){
122  delete decisionTreeNode;
123  decisionTreeNode = NULL;
124  }
125  this->decisionTreeNode = ptr->deepCopyDecisionTreeNode();
126 
127  this->minNumSamplesPerNode = ptr->minNumSamplesPerNode;
128  this->maxDepth = ptr->maxDepth;
129  this->removeFeaturesAtEachSpilt = ptr->removeFeaturesAtEachSpilt;
130  this->trainingMode = ptr->trainingMode;
131  this->numSplittingSteps = ptr->numSplittingSteps;
132  this->nodeClusters = ptr->nodeClusters;
133 
134  //Copy the base classifier variables
135  return copyBaseVariables( classifier );
136  }
137  return false;
138 }
139 
141 
142  //Clear any previous model
143  clear();
144 
145  if( decisionTreeNode == NULL ){
146  Classifier::errorLog << "train_(ClassificationData &trainingData) - The decision tree node has not been set! You must set this first before training a model." << endl;
147  return false;
148  }
149 
150  const unsigned int M = trainingData.getNumSamples();
151  const unsigned int N = trainingData.getNumDimensions();
152  const unsigned int K = trainingData.getNumClasses();
153 
154  if( M == 0 ){
155  Classifier::errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << endl;
156  return false;
157  }
158 
159  numInputDimensions = N;
160  numClasses = K;
161  classLabels = trainingData.getClassLabels();
162  ranges = trainingData.getRanges();
163 
164  //Scale the training data if needed
165  if( useScaling ){
166  //Scale the training data between 0 and 1
167  trainingData.scale(0, 1);
168  }
169 
170  //If we are using null rejection, then we need a copy of the training dataset for later
171  ClassificationData trainingDataCopy;
172  if( useNullRejection ){
173  trainingDataCopy = trainingData;
174  }
175 
176  //Setup the valid features - at this point all features can be used
177  vector< UINT > features(N);
178  for(UINT i=0; i<N; i++){
179  features[i] = i;
180  }
181 
182  //Build the tree
183  UINT nodeID = 0;
184  tree = buildTree( trainingData, NULL, features, classLabels, nodeID );
185 
186  if( tree == NULL ){
187  clear();
188  Classifier::errorLog << "train_(ClassificationData &trainingData) - Failed to build tree!" << endl;
189  return false;
190  }
191 
192  //Flag that the algorithm has been trained
193  trained = true;
194 
195  //Compute the null rejection thresholds if null rejection is enabled
196  if( useNullRejection ){
197  VectorDouble classLikelihoods( numClasses );
198  vector< UINT > predictions(M);
199  VectorDouble distances(M);
200  VectorDouble classCounter( numClasses, 0 );
201 
202  //Run over the training dataset and compute the distance between each training sample and the predicted node cluster
203  for(UINT i=0; i<M; i++){
204  //Run the prediction for this sample
205  if( !tree->predict( trainingDataCopy[i].getSample(), classLikelihoods ) ){
206  Classifier::errorLog << "predict_(VectorDouble &inputVector) - Failed to predict!" << endl;
207  return false;
208  }
209 
210  //Store the predicted class index and cluster distance
211  predictions[i] = Util::getMaxIndex( classLikelihoods );
212  distances[i] = getNodeDistance(trainingDataCopy[i].getSample(), tree->getPredictedNodeID() );
213 
214  classCounter[ predictions[i] ]++;
215  }
216 
217  //Compute the average distance for each class between the training data and the node clusters
218  classClusterMean.clear();
219  classClusterStdDev.clear();
220  classClusterMean.resize( numClasses, 0 );
221  classClusterStdDev.resize( numClasses, 0.01 ); //we start the std dev with a small value to ensure it is not zero
222 
223  for(UINT i=0; i<M; i++){
224  classClusterMean[ predictions[i] ] += distances[ i ];
225  }
226  for(UINT k=0; k<numClasses; k++){
227  classClusterMean[k] /= MAX( classCounter[k], 1 );
228  }
229 
230  //Compute the std deviation
231  for(UINT i=0; i<M; i++){
232  classClusterStdDev[ predictions[i] ] += MLBase::SQR( distances[ i ] - classClusterMean[ predictions[i] ] );
233  }
234  for(UINT k=0; k<numClasses; k++){
235  classClusterStdDev[k] = sqrt( classClusterStdDev[k] / MAX( classCounter[k], 1 ) );
236  }
237 
238  //Compute the null rejection thresholds using the class mean and std dev
240 
241  }
242 
243  return true;
244 }
245 
246 bool DecisionTree::predict_(VectorDouble &inputVector){
247 
248  predictedClassLabel = 0;
249  maxLikelihood = 0;
250 
251  //Validate the input is OK and the model is trained properly
252  if( !trained ){
253  Classifier::errorLog << "predict_(VectorDouble &inputVector) - Model Not Trained!" << endl;
254  return false;
255  }
256 
257  if( tree == NULL ){
258  Classifier::errorLog << "predict_(VectorDouble &inputVector) - DecisionTree pointer is null!" << endl;
259  return false;
260  }
261 
262  if( inputVector.size() != numInputDimensions ){
263  Classifier::errorLog << "predict_(VectorDouble &inputVector) - The size of the input vector (" << inputVector.size() << ") does not match the num features in the model (" << numInputDimensions << endl;
264  return false;
265  }
266 
267  //Scale the input data if needed
268  if( useScaling ){
269  for(UINT n=0; n<numInputDimensions; n++){
270  inputVector[n] = scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue, 0, 1);
271  }
272  }
273 
274  if( classLikelihoods.size() != numClasses ) classLikelihoods.resize(numClasses,0);
275  if( classDistances.size() != numClasses ) classDistances.resize(numClasses,0);
276 
277  //Run the decision tree prediction
278  if( !tree->predict( inputVector, classLikelihoods ) ){
279  Classifier::errorLog << "predict_(VectorDouble &inputVector) - Failed to predict!" << endl;
280  return false;
281  }
282 
283  //Find the maximum likelihood
284  //The tree automatically returns proper class likelihoods so we don't need to do anything else
285  UINT maxIndex = 0;
286  maxLikelihood = 0;
287  for(UINT k=0; k<numClasses; k++){
288  if( classLikelihoods[k] > maxLikelihood ){
289  maxLikelihood = classLikelihoods[k];
290  maxIndex = k;
291  }
292  }
293 
294  //Run the null rejection
295  if( useNullRejection ){
296 
297  //Get the distance between the input and the leaf mean
298  double leafDistance = getNodeDistance( inputVector, tree->getPredictedNodeID() );
299 
300  if( grt_isnan(leafDistance) ){
301  Classifier::errorLog << "predict_(VectorDouble &inputVector) - Failed to match leaf node ID to compute node distance!" << endl;
302  return false;
303  }
304 
305  //Set the predicted class distance as the leaf distance, all other classes will have a distance of zero
306  std::fill(classDistances.begin(),classDistances.end(),0);
307  classDistances[ maxIndex ] = leafDistance;
308 
309  //Use the distance to check if the class label should be rejected or not
310  if( leafDistance <= nullRejectionThresholds[ maxIndex ] ){
311  predictedClassLabel = classLabels[ maxIndex ];
312  }else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
313 
314  }else {
315  //Set the predicated class label
316  predictedClassLabel = classLabels[ maxIndex ];
317  }
318 
319  return true;
320 }
321 
323 
324  //Clear the Classifier variables
326 
327  //Clear the node clusters
328  nodeClusters.clear();
329 
330  //Delete the tree if it exists
331  if( tree != NULL ){
332  tree->clear();
333  delete tree;
334  tree = NULL;
335  }
336 
337  //NOTE: We do not want to clean up the decisionTreeNode here as we need to keep track of this, this is only delete in the destructor
338 
339  return true;
340 }
341 
343 
344  if( !trained ){
345  Classifier::warningLog << "recomputeNullRejectionThresholds() - Failed to recompute null rejection thresholds, the model has not been trained!" << endl;
346  return false;
347  }
348 
349  if( !useNullRejection ){
350  Classifier::warningLog << "recomputeNullRejectionThresholds() - Failed to recompute null rejection thresholds, null rejection is not enabled!" << endl;
351  return false;
352  }
353 
354  nullRejectionThresholds.resize( numClasses );
355 
356  //Compute the rejection threshold for each class using the mean and std dev
357  for(UINT k=0; k<numClasses; k++){
358  nullRejectionThresholds[k] = classClusterMean[k] + (classClusterStdDev[k]*nullRejectionCoeff);
359  }
360 
361  return true;
362 }
363 
//Serializes the model (header, template node, tree-growing options, the tree itself
//and, when null rejection is enabled, the per-class cluster statistics) to an open file stream.
364 bool DecisionTree::saveModelToFile(fstream &file) const{
365 
366  if(!file.is_open())
367  {
368  Classifier::errorLog <<"saveModelToFile(fstream &file) - The file is not open!" << endl;
369  return false;
370  }
371 
372  //Write the header info
373  file << "GRT_DECISION_TREE_MODEL_FILE_V4.0\n";
374 
375  //Write the classifier settings to the file
 //NOTE(review): the opening guard line (original line 376, presumably
 //`if( !Classifier::saveBaseSettingsToFile(file) ){`) appears to be missing from this
 //copy of the source — the error/return below is otherwise orphaned. TODO confirm
 //against the upstream GRT sources.
377  Classifier::errorLog <<"saveModelToFile(fstream &file) - Failed to save classifier base settings to file!" << endl;
378  return false;
379  }
380 
 //Write the template node type and its settings (or NULL when no template node is set)
381  if( decisionTreeNode != NULL ){
382  file << "DecisionTreeNodeType: " << decisionTreeNode->getNodeType() << endl;
383  if( !decisionTreeNode->saveToFile( file ) ){
384  Classifier::errorLog <<"saveModelToFile(fstream &file) - Failed to save decisionTreeNode settings to file!" << endl;
385  return false;
386  }
387  }else{
388  file << "DecisionTreeNodeType: " << "NULL" << endl;
389  }
390 
 //Write the tree-growing options
391  file << "MinNumSamplesPerNode: " << minNumSamplesPerNode << endl;
392  file << "MaxDepth: " << maxDepth << endl;
393  file << "RemoveFeaturesAtEachSpilt: " << removeFeaturesAtEachSpilt << endl;
394  file << "TrainingMode: " << trainingMode << endl;
395  file << "NumSplittingSteps: " << numSplittingSteps << endl;
396  file << "TreeBuilt: " << (tree != NULL ? 1 : 0) << endl;
397 
398  if( tree != NULL ){
399  file << "Tree:\n";
400  if( !tree->saveToFile( file ) ){
401  Classifier::errorLog << "saveModelToFile(fstream &file) - Failed to save tree to file!" << endl;
402  return false;
403  }
404 
405  //Save the null rejection data if needed
406  if( useNullRejection ){
407 
408  file << "ClassClusterMean:";
409  for(UINT k=0; k<numClasses; k++){
410  file << " " << classClusterMean[k];
411  }
412  file << endl;
413 
414  file << "ClassClusterStdDev:";
415  for(UINT k=0; k<numClasses; k++){
416  file << " " << classClusterStdDev[k];
417  }
418  file << endl;
419 
 //Write each node cluster as: nodeID followed by numInputDimensions values
420  file << "NumNodes: " << nodeClusters.size() << endl;
421  file << "NodeClusters:\n";
422 
423  std::map< UINT, VectorDouble >::const_iterator iter = nodeClusters.begin();
424 
425  while( iter != nodeClusters.end() ){
426 
427  //Write the nodeID
428  file << iter->first;
429 
430  //Write the node cluster
431  for(UINT j=0; j<numInputDimensions; j++){
432  file << " " << iter->second[j];
433  }
434  file << endl;
435 
436  iter++;
437  }
438  }
439 
440  }
441 
442  return true;
443 }
444 
446 
447  clear();
448 
449  if( decisionTreeNode != NULL ){
450  delete decisionTreeNode;
451  decisionTreeNode = NULL;
452  }
453 
454  if( !file.is_open() )
455  {
456  Classifier::errorLog << "loadModelFromFile(string filename) - Could not open file to load model" << endl;
457  return false;
458  }
459 
460  std::string word;
461  file >> word;
462 
463  //Check to see if we should load a legacy file
464  if( word == "GRT_DECISION_TREE_MODEL_FILE_V1.0" ){
465  return loadLegacyModelFromFile_v1( file );
466  }
467 
468  if( word == "GRT_DECISION_TREE_MODEL_FILE_V2.0" ){
469  return loadLegacyModelFromFile_v2( file );
470  }
471 
472  if( word == "GRT_DECISION_TREE_MODEL_FILE_V3.0" ){
473  return loadLegacyModelFromFile_v3( file );
474  }
475 
476  //Find the file type header
477  if( word != "GRT_DECISION_TREE_MODEL_FILE_V4.0" ){
478  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find Model File Header" << endl;
479  return false;
480  }
481 
482  //Load the base settings from the file
484  Classifier::errorLog << "loadModelFromFile(string filename) - Failed to load base settings from file!" << endl;
485  return false;
486  }
487 
488  file >> word;
489  if(word != "DecisionTreeNodeType:"){
490  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the DecisionTreeNodeType!" << endl;
491  return false;
492  }
493  file >> word;
494 
495  if( word != "NULL" ){
496 
497  decisionTreeNode = dynamic_cast< DecisionTreeNode* >( DecisionTreeNode::createInstanceFromString( word ) );
498 
499  if( decisionTreeNode == NULL ){
500  Classifier::errorLog << "loadModelFromFile(string filename) - Could not create new DecisionTreeNode from type: " << word << endl;
501  return false;
502  }
503 
504  if( !decisionTreeNode->loadFromFile( file ) ){
505  Classifier::errorLog <<"loadModelFromFile(fstream &file) - Failed to load decisionTreeNode settings from file!" << endl;
506  return false;
507  }
508  }else{
509  Classifier::errorLog <<"loadModelFromFile(fstream &file) - Failed to load decisionTreeNode! DecisionTreeNodeType is NULL!" << endl;
510  return false;
511  }
512 
513  file >> word;
514  if(word != "MinNumSamplesPerNode:"){
515  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MinNumSamplesPerNode!" << endl;
516  return false;
517  }
518  file >> minNumSamplesPerNode;
519 
520  file >> word;
521  if(word != "MaxDepth:"){
522  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MaxDepth!" << endl;
523  return false;
524  }
525  file >> maxDepth;
526 
527  file >> word;
528  if(word != "RemoveFeaturesAtEachSpilt:"){
529  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << endl;
530  return false;
531  }
532  file >> removeFeaturesAtEachSpilt;
533 
534  file >> word;
535  if(word != "TrainingMode:"){
536  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TrainingMode!" << endl;
537  return false;
538  }
539  file >> trainingMode;
540 
541  file >> word;
542  if(word != "NumSplittingSteps:"){
543  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumSplittingSteps!" << endl;
544  return false;
545  }
546  file >> numSplittingSteps;
547 
548  file >> word;
549  if(word != "TreeBuilt:"){
550  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TreeBuilt!" << endl;
551  return false;
552  }
553  file >> trained;
554 
555  if( trained ){
556  file >> word;
557  if(word != "Tree:"){
558  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the Tree!" << endl;
559  return false;
560  }
561 
562  //Create a new DTree
563  tree = dynamic_cast< DecisionTreeNode* >( decisionTreeNode->createNewInstance() );
564 
565  if( tree == NULL ){
566  clear();
567  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to create new DecisionTreeNode!" << endl;
568  return false;
569  }
570 
571  tree->setParent( NULL );
572  if( !tree->loadFromFile( file ) ){
573  clear();
574  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to load tree from file!" << endl;
575  return false;
576  }
577 
578  //Load the null rejection data if needed
579  if( useNullRejection ){
580 
581  UINT numNodes = 0;
582  classClusterMean.resize( numClasses );
583  classClusterStdDev.resize( numClasses );
584 
585  file >> word;
586  if(word != "ClassClusterMean:"){
587  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the ClassClusterMean header!" << endl;
588  return false;
589  }
590  for(UINT k=0; k<numClasses; k++){
591  file >> classClusterMean[k];
592  }
593 
594  file >> word;
595  if(word != "ClassClusterStdDev:"){
596  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the ClassClusterStdDev header!" << endl;
597  return false;
598  }
599  for(UINT k=0; k<numClasses; k++){
600  file >> classClusterStdDev[k];
601  }
602 
603  file >> word;
604  if(word != "NumNodes:"){
605  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumNodes header!" << endl;
606  return false;
607  }
608  file >> numNodes;
609 
610  file >> word;
611  if(word != "NodeClusters:"){
612  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NodeClusters header!" << endl;
613  return false;
614  }
615 
616  UINT nodeID = 0;
617  VectorDouble cluster( numInputDimensions );
618  for(UINT i=0; i<numNodes; i++){
619 
620  //load the nodeID
621  file >> nodeID;
622 
623  for(UINT j=0; j<numInputDimensions; j++){
624  file >> cluster[j];
625  }
626 
627  //Add the cluster to the cluster nodes map
628  nodeClusters[ nodeID ] = cluster;
629  }
630 
631  //Recompute the null rejection thresholds
633  }
634 
635  //Resize the prediction results to make sure it is setup for realtime prediction
636  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
637  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
638  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
639  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
640  }
641 
642  return true;
643 }
644 
645 bool DecisionTree::getModel(ostream &stream) const{
646 
647  if( tree != NULL )
648  return tree->getModel( stream );
649  return false;
650 
651 }
652 
654 
655  if( tree == NULL ){
656  return NULL;
657  }
658 
659  return dynamic_cast< DecisionTreeNode* >( tree->deepCopyNode() );
660 }
661 
663 
664  if( decisionTreeNode == NULL ){
665  return NULL;
666  }
667 
668  return decisionTreeNode->deepCopy();
669 }
670 
672  return dynamic_cast< DecisionTreeNode* >( tree );
673 }
674 
676 
677  if( decisionTreeNode != NULL ){
678  delete decisionTreeNode;
679  decisionTreeNode = NULL;
680  }
681  this->decisionTreeNode = node.deepCopy();
682 
683  return true;
684 }
685 
//Recursively grows the decision tree from trainingData and returns the root node of the
//subtree (or NULL on failure). The caller's dataset is consumed: it is cleared after the
//left/right split to limit peak memory usage. `features` is passed by value so each branch
//keeps its own copy of the remaining candidate features.
686 DecisionTreeNode* DecisionTree::buildTree(ClassificationData &trainingData,DecisionTreeNode *parent,vector< UINT > features,const vector< UINT > &classLabels, UINT nodeID){
687 
688  const UINT M = trainingData.getNumSamples();
689  const UINT N = trainingData.getNumDimensions();
690 
 //NOTE(review): nodeID is passed by value, so the increments below do not propagate back
 //to the caller; sibling subtrees may therefore generate overlapping node IDs in
 //nodeClusters — TODO confirm against how getPredictedNodeID() is used for null rejection.
691  //Update the nodeID
692  nodeID++;
693 
694  //Get the depth
695  UINT depth = 0;
696 
697  if( parent != NULL )
698  depth = parent->getDepth() + 1;
699 
700  //If there are no training data then return NULL
701  if( trainingData.getNumSamples() == 0 )
702  return NULL;
703 
 //Create the new node as a fresh instance of the user-supplied template node type
704  //Create the new node
705  DecisionTreeNode *node = dynamic_cast< DecisionTreeNode* >( decisionTreeNode->createNewInstance() );
706 
707  if( node == NULL )
708  return NULL;
709 
710  //Set the parent
711  node->initNode( parent, depth, nodeID );
712 
 //Stopping criteria: a single remaining class, no features left, too few samples,
 //or the maximum depth has been reached — emit a leaf node
713  //If all the training data belongs to the same class or there are no features left then create a leaf node and return
714  if( trainingData.getNumClasses() == 1 || features.size() == 0 || M < minNumSamplesPerNode || depth >= maxDepth ){
715 
716  //Set the node
717  node->setLeafNode( trainingData.getNumSamples(), trainingData.getClassProbabilities( classLabels ) );
718 
719  //Build the null cluster if null rejection is enabled
720  if( useNullRejection ){
721  nodeClusters[ nodeID ] = trainingData.getMean();
722  }
723 
724  Classifier::trainingLog << "Reached leaf node. Depth: " << depth << " NumSamples: " << trainingData.getNumSamples() << endl;
725 
726  return node;
727  }
728 
729  //Compute the best spilt point
730  UINT featureIndex = 0;
731  double minError = 0;
732 
 //If no valid split can be found the partially-built node must be released
733  if( !node->computeBestSpilt( trainingMode, numSplittingSteps, trainingData, features, classLabels, featureIndex, minError ) ){
734  delete node;
735  return NULL;
736  }
737 
738  Classifier::trainingLog << "Depth: " << depth << " FeatureIndex: " << featureIndex << " MinError: " << minError << endl;
739 
740  //Remove the selected feature so we will not use it again
741  if( removeFeaturesAtEachSpilt ){
742  for(size_t i=0; i<features.size(); i++){
743  if( features[i] == featureIndex ){
744  features.erase( features.begin()+i );
745  break;
746  }
747  }
748  }
749 
 //Partition the samples using the node's split: predict()==true goes right, otherwise left
750  //Split the data into a left and right dataset
751  ClassificationData lhs(N);
752  ClassificationData rhs(N);
753 
754  //Reserve the memory to speed up the allocation of large datasets
755  lhs.reserve( M );
756  rhs.reserve( M );
757 
758  for(UINT i=0; i<M; i++){
759  if( node->predict( trainingData[i].getSample() ) ){
760  rhs.addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
761  }else lhs.addSample(trainingData[i].getClassLabel(), trainingData[i].getSample());
762  }
763 
764  //Clear the parent dataset so we do not run out of memory with very large datasets (with very deep trees)
765  trainingData.clear();
766 
767  //Get the new node IDs for the children
768  UINT leftNodeID = ++nodeID;
769  UINT rightNodeID = ++nodeID;
770 
771  //Run the recursive tree building on the children
772  node->setLeftChild( buildTree( lhs, node, features, classLabels, leftNodeID ) );
773  node->setRightChild( buildTree( rhs, node, features, classLabels, rightNodeID ) );
774 
 //The child datasets are still populated here (the recursion cleared only its own copy's
 //parent), so their means can be used as the null-rejection clusters for the child IDs
775  //Build the null clusters for the rhs and lhs nodes if null rejection is enabled
776  if( useNullRejection ){
777  nodeClusters[ leftNodeID ] = lhs.getMean();
778  nodeClusters[ rightNodeID ] = rhs.getMean();
779  }
780 
781  return node;
782 }
783 
784 double DecisionTree::getNodeDistance( const VectorDouble &x, const UINT nodeID ){
785 
786  //Use the node ID to find the node cluster
787  std::map< UINT,VectorDouble >::iterator iter = nodeClusters.find( nodeID );
788 
789  //If we failed to find a match, return NAN
790  if( iter == nodeClusters.end() ) return NAN;
791 
792  //Compute the distance between the input and the node cluster
793  return getNodeDistance( x, iter->second );
794 }
795 
796 double DecisionTree::getNodeDistance( const VectorDouble &x, const VectorDouble &y ){
797 
798  double distance = 0;
799  const size_t N = x.size();
800 
801  for(size_t i=0; i<N; i++){
802  distance += MLBase::SQR( x[i] - y[i] );
803  }
804 
805  //Return the squared Euclidean distance instead of actual Euclidean distance as this is faster and just as useful
806  return distance;
807 }
808 
810 
811  string word;
812 
813  file >> word;
814  if(word != "NumFeatures:"){
815  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find NumFeatures!" << endl;
816  return false;
817  }
818  file >> numInputDimensions;
819 
820  file >> word;
821  if(word != "NumClasses:"){
822  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find NumClasses!" << endl;
823  return false;
824  }
825  file >> numClasses;
826 
827  file >> word;
828  if(word != "UseScaling:"){
829  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find UseScaling!" << endl;
830  return false;
831  }
832  file >> useScaling;
833 
834  file >> word;
835  if(word != "UseNullRejection:"){
836  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find UseNullRejection!" << endl;
837  return false;
838  }
839  file >> useNullRejection;
840 
842  if( useScaling ){
843  //Resize the ranges buffer
844  ranges.resize( numInputDimensions );
845 
846  file >> word;
847  if(word != "Ranges:"){
848  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the Ranges!" << endl;
849  return false;
850  }
851  for(UINT n=0; n<ranges.size(); n++){
852  file >> ranges[n].minValue;
853  file >> ranges[n].maxValue;
854  }
855  }
856 
857  file >> word;
858  if(word != "NumSplittingSteps:"){
859  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumSplittingSteps!" << endl;
860  return false;
861  }
862  file >> numSplittingSteps;
863 
864  file >> word;
865  if(word != "MinNumSamplesPerNode:"){
866  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MinNumSamplesPerNode!" << endl;
867  return false;
868  }
869  file >> minNumSamplesPerNode;
870 
871  file >> word;
872  if(word != "MaxDepth:"){
873  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MaxDepth!" << endl;
874  return false;
875  }
876  file >> maxDepth;
877 
878  file >> word;
879  if(word != "RemoveFeaturesAtEachSpilt:"){
880  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << endl;
881  return false;
882  }
883  file >> removeFeaturesAtEachSpilt;
884 
885  file >> word;
886  if(word != "TrainingMode:"){
887  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TrainingMode!" << endl;
888  return false;
889  }
890  file >> trainingMode;
891 
892  file >> word;
893  if(word != "TreeBuilt:"){
894  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TreeBuilt!" << endl;
895  return false;
896  }
897  file >> trained;
898 
899  if( trained ){
900  file >> word;
901  if(word != "Tree:"){
902  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the Tree!" << endl;
903  return false;
904  }
905 
906  //Create a new DTree
907  tree = new DecisionTreeNode;
908 
909  if( tree == NULL ){
910  clear();
911  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to create new DecisionTreeNode!" << endl;
912  return false;
913  }
914 
915  tree->setParent( NULL );
916  if( !tree->loadFromFile( file ) ){
917  clear();
918  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to load tree from file!" << endl;
919  return false;
920  }
921  }
922 
923  return true;
924 }
925 
//Loads a model saved in the legacy V2 file format: reads the tree-growing options,
//then (if a tree was built) reconstructs the tree from the stream and resets the
//realtime prediction buffers. Returns false on any malformed header token.
926 bool DecisionTree::loadLegacyModelFromFile_v2( fstream &file ){
927 
928  string word;
929 
930  //Load the base settings from the file
 //NOTE(review): the opening guard line (original line 931, presumably
 //`if( !Classifier::loadBaseSettingsFromFile(file) ){`) appears to be missing from this
 //copy of the source — the error/return below is otherwise orphaned. TODO confirm
 //against the upstream GRT sources.
932  Classifier::errorLog << "loadModelFromFile(string filename) - Failed to load base settings from file!" << endl;
933  return false;
934  }
935 
 //Each option is stored as a "Key:" token followed by its value
936  file >> word;
937  if(word != "NumSplittingSteps:"){
938  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumSplittingSteps!" << endl;
939  return false;
940  }
941  file >> numSplittingSteps;
942 
943  file >> word;
944  if(word != "MinNumSamplesPerNode:"){
945  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MinNumSamplesPerNode!" << endl;
946  return false;
947  }
948  file >> minNumSamplesPerNode;
949 
950  file >> word;
951  if(word != "MaxDepth:"){
952  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MaxDepth!" << endl;
953  return false;
954  }
955  file >> maxDepth;
956 
957  file >> word;
958  if(word != "RemoveFeaturesAtEachSpilt:"){
959  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << endl;
960  return false;
961  }
962  file >> removeFeaturesAtEachSpilt;
963 
964  file >> word;
965  if(word != "TrainingMode:"){
966  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TrainingMode!" << endl;
967  return false;
968  }
969  file >> trainingMode;
970 
971  file >> word;
972  if(word != "TreeBuilt:"){
973  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TreeBuilt!" << endl;
974  return false;
975  }
976  file >> trained;
977 
978  if( trained ){
979  file >> word;
980  if(word != "Tree:"){
981  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the Tree!" << endl;
982  return false;
983  }
984 
 //Legacy V2 files always store a plain DecisionTreeNode (no custom node types)
985  //Create a new DTree
986  tree = new DecisionTreeNode;
987 
988  if( tree == NULL ){
989  clear();
990  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to create new DecisionTreeNode!" << endl;
991  return false;
992  }
993 
994  tree->setParent( NULL );
995  if( !tree->loadFromFile( file ) ){
996  clear();
997  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to load tree from file!" << endl;
998  return false;
999  }
1000 
1001  //Recompute the null rejection thresholds
 //NOTE(review): the statement for the comment above (original line 1002, presumably a call
 //to recomputeNullRejectionThresholds()) appears to be missing from this copy — TODO confirm.
1003 
1004  //Resize the prediction results to make sure it is setup for realtime prediction
1005  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
1006  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1007  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
1008  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
1009  }
1010 
1011  return true;
1012 }
1013 
1014 bool DecisionTree::loadLegacyModelFromFile_v3( fstream &file ){
1015 
1016  string word;
1017 
1018  //Load the base settings from the file
1020  Classifier::errorLog << "loadModelFromFile(string filename) - Failed to load base settings from file!" << endl;
1021  return false;
1022  }
1023 
1024  file >> word;
1025  if(word != "NumSplittingSteps:"){
1026  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumSplittingSteps!" << endl;
1027  return false;
1028  }
1029  file >> numSplittingSteps;
1030 
1031  file >> word;
1032  if(word != "MinNumSamplesPerNode:"){
1033  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MinNumSamplesPerNode!" << endl;
1034  return false;
1035  }
1036  file >> minNumSamplesPerNode;
1037 
1038  file >> word;
1039  if(word != "MaxDepth:"){
1040  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the MaxDepth!" << endl;
1041  return false;
1042  }
1043  file >> maxDepth;
1044 
1045  file >> word;
1046  if(word != "RemoveFeaturesAtEachSpilt:"){
1047  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << endl;
1048  return false;
1049  }
1050  file >> removeFeaturesAtEachSpilt;
1051 
1052  file >> word;
1053  if(word != "TrainingMode:"){
1054  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TrainingMode!" << endl;
1055  return false;
1056  }
1057  file >> trainingMode;
1058 
1059  file >> word;
1060  if(word != "TreeBuilt:"){
1061  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the TreeBuilt!" << endl;
1062  return false;
1063  }
1064  file >> trained;
1065 
1066  if( trained ){
1067  file >> word;
1068  if(word != "Tree:"){
1069  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the Tree!" << endl;
1070  return false;
1071  }
1072 
1073  //Create a new DTree
1074  tree = new DecisionTreeNode;
1075 
1076  if( tree == NULL ){
1077  clear();
1078  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to create new DecisionTreeNode!" << endl;
1079  return false;
1080  }
1081 
1082  tree->setParent( NULL );
1083  if( !tree->loadFromFile( file ) ){
1084  clear();
1085  Classifier::errorLog << "loadModelFromFile(fstream &file) - Failed to load tree from file!" << endl;
1086  return false;
1087  }
1088 
1089  //Load the null rejection data if needed
1090  if( useNullRejection ){
1091 
1092  UINT numNodes = 0;
1093  classClusterMean.resize( numClasses );
1094  classClusterStdDev.resize( numClasses );
1095 
1096  file >> word;
1097  if(word != "ClassClusterMean:"){
1098  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the ClassClusterMean header!" << endl;
1099  return false;
1100  }
1101  for(UINT k=0; k<numClasses; k++){
1102  file >> classClusterMean[k];
1103  }
1104 
1105  file >> word;
1106  if(word != "ClassClusterStdDev:"){
1107  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the ClassClusterStdDev header!" << endl;
1108  return false;
1109  }
1110  for(UINT k=0; k<numClasses; k++){
1111  file >> classClusterStdDev[k];
1112  }
1113 
1114  file >> word;
1115  if(word != "NumNodes:"){
1116  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NumNodes header!" << endl;
1117  return false;
1118  }
1119  file >> numNodes;
1120 
1121  file >> word;
1122  if(word != "NodeClusters:"){
1123  Classifier::errorLog << "loadModelFromFile(string filename) - Could not find the NodeClusters header!" << endl;
1124  return false;
1125  }
1126 
1127  UINT nodeID = 0;
1128  VectorDouble cluster( numInputDimensions );
1129  for(UINT i=0; i<numNodes; i++){
1130 
1131  //load the nodeID
1132  file >> nodeID;
1133 
1134  for(UINT j=0; j<numInputDimensions; j++){
1135  file >> cluster[j];
1136  }
1137 
1138  //Add the cluster to the cluster nodes map
1139  nodeClusters[ nodeID ] = cluster;
1140  }
1141 
1142  //Recompute the null rejection thresholds
1144  }
1145 
1146  //Resize the prediction results to make sure it is setup for realtime prediction
1147  maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
1148  bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
1149  classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
1150  classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
1151  }
1152 
1153  return true;
1154 }
1155 
DecisionTree(const DecisionTreeNode &decisionTreeNode=DecisionTreeClusterNode(), const UINT minNumSamplesPerNode=5, const UINT maxDepth=10, const bool removeFeaturesAtEachSpilt=false, const UINT trainingMode=BEST_ITERATIVE_SPILT, const UINT numSplittingSteps=100, const bool useScaling=false)
This class implements a basic Decision Tree classifier. Decision Trees are conceptually simple classifiers...
virtual Node * deepCopyNode() const
Definition: Node.cpp:267
UINT getDepth() const
Definition: Node.cpp:299
virtual bool deepCopyFrom(const Classifier *classifier)
UINT getPredictedNodeID() const
Definition: Node.cpp:307
VectorDouble getMean() const
vector< UINT > getClassLabels() const
static unsigned int getMaxIndex(const std::vector< double > &x)
Definition: Util.cpp:276
bool copyBaseVariables(const Classifier *classifier)
Definition: Classifier.cpp:91
virtual ~DecisionTree(void)
Definition: AdaBoost.cpp:25
virtual bool loadFromFile(fstream &file)
Definition: Node.cpp:173
bool setLeafNode(const UINT nodeSize, const VectorDouble &classProbabilities)
virtual bool getModel(ostream &stream) const
bool loadBaseSettingsFromFile(fstream &file)
Definition: Classifier.cpp:301
virtual bool saveToFile(fstream &file) const
Definition: Node.cpp:131
virtual bool getModel(ostream &stream) const
Definition: Node.cpp:111
virtual bool recomputeNullRejectionThresholds()
string getNodeType() const
Definition: Node.cpp:295
virtual bool loadModelFromFile(fstream &file)
virtual bool train_(ClassificationData &trainingData)
bool saveBaseSettingsToFile(fstream &file) const
Definition: Classifier.cpp:254
double scale(const double &x, const double &minSource, const double &maxSource, const double &minTarget, const double &maxTarget, const bool constrain=false)
Definition: MLBase.h:339
bool loadLegacyModelFromFile_v1(fstream &file)
DecisionTreeNode * deepCopy() const
bool scale(const double minTarget, const double maxTarget)
DecisionTreeNode * deepCopyTree() const
bool setDecisionTreeNode(const DecisionTreeNode &node)
DecisionTreeNode * deepCopyDecisionTreeNode() const
virtual bool clear()
Definition: Classifier.cpp:140
virtual bool predict(const VectorDouble &x)
Definition: Node.cpp:59
Node * createNewInstance() const
Definition: Node.cpp:38
virtual bool clear()
virtual bool predict(const VectorDouble &x, VectorDouble &classLikelihoods)
DecisionTree & operator=(const DecisionTree &rhs)
bool getTrained() const
Definition: MLBase.cpp:223
vector< MinMax > getRanges() const
virtual bool clear()
Definition: Node.cpp:69
virtual bool saveModelToFile(fstream &file) const
string getClassifierType() const
Definition: Classifier.cpp:159
virtual bool predict_(VectorDouble &inputVector)
static Node * createInstanceFromString(string const &nodeType)
Definition: Node.cpp:28
const DecisionTreeNode * getTree() const
virtual bool computeBestSpilt(const UINT &trainingMode, const UINT &numSplittingSteps, const ClassificationData &trainingData, const vector< UINT > &features, const vector< UINT > &classLabels, UINT &featureIndex, double &minError)