GestureRecognitionToolkit  Version: 1.0 Revision: 04-03-15
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source, c++ machine learning library for real-time gesture recognition.
ClassificationData.cpp
1 /*
2 GRT MIT License
3 Copyright (c) <2012> <Nicholas Gillian, Media Lab, MIT>
4 
5 Permission is hereby granted, free of charge, to any person obtaining a copy of this software
6 and associated documentation files (the "Software"), to deal in the Software without restriction,
7 including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
9 subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in all copies or substantial
12 portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
15 LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
17 WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
18 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 */
20 
21 #include "ClassificationData.h"
22 
23 using namespace GRT;
24 
25 ClassificationData::ClassificationData(const UINT numDimensions,const string datasetName,const string infoText){
26  this->datasetName = datasetName;
27  this->numDimensions = numDimensions;
28  this->infoText = infoText;
29  totalNumSamples = 0;
30  crossValidationSetup = false;
31  useExternalRanges = false;
32  allowNullGestureClass = true;
33  if( numDimensions > 0 ) setNumDimensions( numDimensions );
34  infoLog.setProceedingText("[ClassificationData]");
35  debugLog.setProceedingText("[DEBUG ClassificationData]");
36  errorLog.setProceedingText("[ERROR ClassificationData]");
37  warningLog.setProceedingText("[WARNING ClassificationData]");
38 }
39 
41  *this = rhs;
42 }
43 
45 }
46 
48  if( this != &rhs){
49  this->datasetName = rhs.datasetName;
50  this->infoText = rhs.infoText;
51  this->numDimensions = rhs.numDimensions;
52  this->totalNumSamples = rhs.totalNumSamples;
53  this->kFoldValue = rhs.kFoldValue;
54  this->crossValidationSetup = rhs.crossValidationSetup;
55  this->useExternalRanges = rhs.useExternalRanges;
56  this->allowNullGestureClass = rhs.allowNullGestureClass;
57  this->externalRanges = rhs.externalRanges;
58  this->classTracker = rhs.classTracker;
59  this->data = rhs.data;
60  this->crossValidationIndexs = rhs.crossValidationIndexs;
61  this->infoLog = rhs.infoLog;
62  this->debugLog = rhs.debugLog;
63  this->errorLog = rhs.errorLog;
64  this->warningLog = rhs.warningLog;
65  }
66  return *this;
67 }
68 
70  totalNumSamples = 0;
71  data.clear();
72  classTracker.clear();
73  crossValidationSetup = false;
74  crossValidationIndexs.clear();
75 }
76 
77 bool ClassificationData::setNumDimensions(const UINT numDimensions){
78 
79  if( numDimensions > 0 ){
80  //Clear any previous training data
81  clear();
82 
83  //Set the dimensionality of the data
84  this->numDimensions = numDimensions;
85 
86  //Clear the external ranges
87  useExternalRanges = false;
88  externalRanges.clear();
89 
90  return true;
91  }
92 
93  errorLog << "setNumDimensions(const UINT numDimensions) - The number of dimensions of the dataset must be greater than zero!" << endl;
94  return false;
95 }
96 
97 bool ClassificationData::setDatasetName(const string datasetName){
98 
99  //Make sure there are no spaces in the string
100  if( datasetName.find(" ") == string::npos ){
101  this->datasetName = datasetName;
102  return true;
103  }
104 
105  errorLog << "setDatasetName(const string datasetName) - The dataset name cannot contain any spaces!" << endl;
106  return false;
107 }
108 
bool ClassificationData::setInfoText(const string infoText){
    //Stores the free-form description for this dataset; always succeeds
    this->infoText = infoText;
    return true;
}
113 
114 bool ClassificationData::setClassNameForCorrespondingClassLabel(const string className,const UINT classLabel){
115 
116  for(UINT i=0; i<classTracker.size(); i++){
117  if( classTracker[i].classLabel == classLabel ){
118  classTracker[i].className = className;
119  return true;
120  }
121  }
122 
123  errorLog << "setClassNameForCorrespondingClassLabel(const string className,const UINT classLabel) - Failed to find class with label: " << classLabel << endl;
124  return false;
125 }
126 
bool ClassificationData::setAllowNullGestureClass(const bool allowNullGestureClass){
    //Controls whether samples may use the reserved null-gesture label;
    //addSample(...) consults this flag before accepting GRT_DEFAULT_NULL_CLASS_LABEL
    this->allowNullGestureClass = allowNullGestureClass;
    return true;
}
131 
132 bool ClassificationData::addSample(const UINT classLabel,const VectorDouble &sample){
133 
134  if( sample.size() != numDimensions ){
135  errorLog << "addSample(const UINT classLabel, VectorDouble &sample) - the size of the new sample (" << sample.size() << ") does not match the number of dimensions of the dataset (" << numDimensions << ")" << endl;
136  return false;
137  }
138 
139  //The class label must be greater than zero (as zero is used for the null rejection class label
140  if( classLabel == GRT_DEFAULT_NULL_CLASS_LABEL && !allowNullGestureClass ){
141  errorLog << "addSample(const UINT classLabel, VectorDouble &sample) - the class label can not be 0!" << endl;
142  return false;
143  }
144 
145  //The dataset has changed so flag that any previous cross validation setup will now not work
146  crossValidationSetup = false;
147  crossValidationIndexs.clear();
148 
149  ClassificationSample newSample(classLabel,sample);
150  data.push_back( newSample );
151  totalNumSamples++;
152 
153  if( classTracker.size() == 0 ){
154  ClassTracker tracker(classLabel,1);
155  classTracker.push_back(tracker);
156  }else{
157  bool labelFound = false;
158  for(UINT i=0; i<classTracker.size(); i++){
159  if( classLabel == classTracker[i].classLabel ){
160  classTracker[i].counter++;
161  labelFound = true;
162  break;
163  }
164  }
165  if( !labelFound ){
166  ClassTracker tracker(classLabel,1);
167  classTracker.push_back(tracker);
168  }
169  }
170 
171  //Update the class labels
172  sortClassLabels();
173 
174  return true;
175 }
176 
177 bool ClassificationData::removeSample( const UINT index ){
178 
179  if( totalNumSamples == 0 ){
180  warningLog << "removeSample( const UINT index ) - Failed to remove sample, the training dataset is empty!" << endl;
181  return false;
182  }
183 
184  if( index >= totalNumSamples ){
185  warningLog << "removeSample( const UINT index ) - Failed to remove sample, the index is out of bounds! Number of training samples: " << totalNumSamples << " index: " << index << endl;
186  return false;
187  }
188 
189  //The dataset has changed so flag that any previous cross validation setup will now not work
190  crossValidationSetup = false;
191  crossValidationIndexs.clear();
192 
193  //Find the corresponding class ID for the last training example
194  UINT classLabel = data[ index ].getClassLabel();
195 
196  //Remove the training example from the buffer
197  data.erase( data.begin()+index );
198 
199  totalNumSamples = (UINT)data.size();
200 
201  //Remove the value from the counter
202  for(size_t i=0; i<classTracker.size(); i++){
203  if( classTracker[i].classLabel == classLabel ){
204  classTracker[i].counter--;
205  break;
206  }
207  }
208 
209  return true;
210 }
211 
213 
214  if( totalNumSamples == 0 ){
215  warningLog << "removeLastSample() - Failed to remove sample, the training dataset is empty!" << endl;
216  return false;
217  }
218 
219  return removeSample( totalNumSamples-1 );
220 }
221 
222 bool ClassificationData::reserve(const UINT N){
223 
224  data.reserve( N );
225 
226  if( data.capacity() >= N ) return true;
227 
228  return false;
229 }
230 
232  return removeClass( classLabel );
233 }
234 
235 bool ClassificationData::addClass(const UINT classLabel,const std::string className){
236 
237  //Check to make sure the class label does not exist
238  for(size_t i=0; i<classTracker.size(); i++){
239  if( classTracker[i].classLabel == classLabel ){
240  warningLog << "addClass(const UINT classLabel,const std::string className) - Failed to add class, it already exists! Class label: " << classLabel << endl;
241  return false;
242  }
243  }
244 
245  //Add the class label to the class tracker
246  classTracker.push_back( ClassTracker(classLabel,0,className) );
247 
248  //Sort the class labels
249  sortClassLabels();
250 
251  return true;
252 }
253 
254 UINT ClassificationData::removeClass(const UINT classLabel){
255 
256  UINT numExamplesRemoved = 0;
257  UINT numExamplesToRemove = 0;
258 
259  //The dataset has changed so flag that any previous cross validation setup will now not work
260  crossValidationSetup = false;
261  crossValidationIndexs.clear();
262 
263  //Find out how many training examples we need to remove
264  for(UINT i=0; i<classTracker.size(); i++){
265  if( classTracker[i].classLabel == classLabel ){
266  numExamplesToRemove = classTracker[i].counter;
267  classTracker.erase(classTracker.begin()+i);
268  break;
269  }
270  }
271 
272  //Remove the samples with the matching class ID
273  if( numExamplesToRemove > 0 ){
274  UINT i=0;
275  while( numExamplesRemoved < numExamplesToRemove ){
276  if( data[i].getClassLabel() == classLabel ){
277  data.erase(data.begin()+i);
278  numExamplesRemoved++;
279  }else if( ++i == data.size() ) break;
280  }
281  }
282 
283  totalNumSamples = (UINT)data.size();
284 
285  return numExamplesRemoved;
286 }
287 
bool ClassificationData::relabelAllSamplesWithClassLabel(const UINT oldClassLabel,const UINT newClassLabel){
    //Relabels every sample tagged with oldClassLabel to newClassLabel, merging the
    //class-tracker counts if newClassLabel already exists. Returns false if no sample
    //uses oldClassLabel. NOTE(review): this does not invalidate the cross validation
    //setup, unlike the other mutating methods — confirm that is intentional.
    bool oldClassLabelFound = false;
    bool newClassLabelAllReadyExists = false;
    UINT indexOfOldClassLabel = 0;
    UINT indexOfNewClassLabel = 0;

    //Locate both labels in the class tracker in a single pass
    for(UINT i=0; i<classTracker.size(); i++){
        if( classTracker[i].classLabel == oldClassLabel ){
            indexOfOldClassLabel = i;
            oldClassLabelFound = true;
        }
        if( classTracker[i].classLabel == newClassLabel ){
            indexOfNewClassLabel = i;
            newClassLabelAllReadyExists = true;
        }
    }

    //If the old class label was not found then we can't do anything
    if( !oldClassLabelFound ){
        return false;
    }

    //Relabel the matching samples in place
    for(UINT i=0; i<totalNumSamples; i++){
        if( data[i].getClassLabel() == oldClassLabel ){
            data[i].setClassLabel(newClassLabel);
        }
    }

    //Update the class tracker
    if( newClassLabelAllReadyExists ){
        //Add the old sample count to the new sample count
        classTracker[ indexOfNewClassLabel ].counter += classTracker[ indexOfOldClassLabel ].counter;
    }else{
        //Create a new class tracker, inheriting the old class name and count.
        //push_back appends at the end, so indexOfOldClassLabel stays valid for the erase below.
        classTracker.push_back( ClassTracker(newClassLabel,classTracker[ indexOfOldClassLabel ].counter,classTracker[ indexOfOldClassLabel ].className) );
    }

    //Erase the old class tracker entry
    classTracker.erase( classTracker.begin() + indexOfOldClassLabel );

    //Sort the class labels
    sortClassLabels();

    return true;
}
335 
336 bool ClassificationData::setExternalRanges(const vector< MinMax > &externalRanges, const bool useExternalRanges){
337 
338  if( externalRanges.size() != numDimensions ) return false;
339 
340  this->externalRanges = externalRanges;
341  this->useExternalRanges = useExternalRanges;
342 
343  return true;
344 }
345 
346 bool ClassificationData::enableExternalRangeScaling(const bool useExternalRanges){
347  if( externalRanges.size() == numDimensions ){
348  this->useExternalRanges = useExternalRanges;
349  return true;
350  }
351  return false;
352 }
353 
354 bool ClassificationData::scale(const double minTarget,const double maxTarget){
355  vector< MinMax > ranges = getRanges();
356  return scale(ranges,minTarget,maxTarget);
357 }
358 
359 bool ClassificationData::scale(const vector<MinMax> &ranges,const double minTarget,const double maxTarget){
360  if( ranges.size() != numDimensions ) return false;
361 
362  //Scale the training data
363  for(UINT i=0; i<totalNumSamples; i++){
364  for(UINT j=0; j<numDimensions; j++){
365  data[i][j] = Util::scale(data[i][j],ranges[j].minValue,ranges[j].maxValue,minTarget,maxTarget);
366  }
367  }
368 
369  return true;
370 }
371 
372 bool ClassificationData::save(const string &filename) const{
373 
374  //Check if the file should be saved as a csv file
375  if( Util::stringEndsWith( filename, ".csv" ) ){
376  return saveDatasetToCSVFile( filename );
377  }
378 
379  //Otherwise save it as a custom GRT file
380  return saveDatasetToFile( filename );
381 }
382 
bool ClassificationData::load(const string &filename){

    //Check if the file should be loaded as a csv file
    if( Util::stringEndsWith( filename, ".csv" ) ){
        return loadDatasetFromCSVFile( filename );
    }

    //Otherwise load it as a custom GRT file
    return loadDatasetFromFile( filename );
}
393 
bool ClassificationData::saveDatasetToFile(const string &filename) const{
    //Writes the dataset in the custom GRT v1.0 text format: a header section
    //(name, info text, dimensions, counts, class tracker, external ranges)
    //followed by one tab-separated line per sample. Returns false only if the
    //file cannot be opened for writing.

    std::fstream file;
    file.open(filename.c_str(), std::ios::out);

    if( !file.is_open() ){
        return false;
    }

    //Header: format identifier followed by the dataset metadata
    file << "GRT_LABELLED_CLASSIFICATION_DATA_FILE_V1.0\n";
    file << "DatasetName: " << datasetName << endl;
    file << "InfoText: " << infoText << endl;
    file << "NumDimensions: " << numDimensions << endl;
    file << "TotalNumExamples: " << totalNumSamples << endl;
    file << "NumberOfClasses: " << classTracker.size() << endl;
    file << "ClassIDsAndCounters: " << endl;

    //One line per class: label, sample count and class name, tab separated
    for(UINT i=0; i<classTracker.size(); i++){
        file << classTracker[i].classLabel << "\t" << classTracker[i].counter << "\t" << classTracker[i].className << endl;
    }

    file << "UseExternalRanges: " << useExternalRanges << endl;

    //External ranges are only written when enabled, one min/max pair per dimension
    if( useExternalRanges ){
        for(UINT i=0; i<externalRanges.size(); i++){
            file << externalRanges[i].minValue << "\t" << externalRanges[i].maxValue << endl;
        }
    }

    file << "Data:\n";

    //One line per sample: class label followed by the tab separated feature values
    for(UINT i=0; i<totalNumSamples; i++){
        file << data[i].getClassLabel();
        for(UINT j=0; j<numDimensions; j++){
            file << "\t" << data[i][j];
        }
        file << endl;
    }

    file.close();
    return true;
}
436 
bool ClassificationData::loadDatasetFromFile(const string &filename){
    //Parses a dataset saved in the custom GRT v1.0 text format (the inverse of
    //saveDatasetToFile). The parser is strictly sequential: each expected header
    //token must appear in order or the load fails with an error message.
    //Any existing data is cleared before loading, so a failed load leaves the
    //dataset empty.

    std::fstream file;
    file.open(filename.c_str(), std::ios::in);
    UINT numClasses = 0;
    clear();

    if( !file.is_open() ){
        errorLog << "loadDatasetFromFile(const string &filename) - could not open file!" << endl;
        return false;
    }

    string word;

    //Check to make sure this is a file with the Training File Format
    file >> word;
    if(word != "GRT_LABELLED_CLASSIFICATION_DATA_FILE_V1.0"){
        errorLog << "loadDatasetFromFile(const string &filename) - could not find file header!" << endl;
        file.close();
        return false;
    }

    //Get the name of the dataset
    file >> word;
    if(word != "DatasetName:"){
        errorLog << "loadDatasetFromFile(const string &filename) - failed to find DatasetName header!" << endl;
        errorLog << word << endl;
        file.close();
        return false;
    }
    file >> datasetName;

    file >> word;
    if(word != "InfoText:"){
        errorLog << "loadDatasetFromFile(const string &filename) - failed to find InfoText header!" << endl;
        file.close();
        return false;
    }

    //Load the info text word by word until the next header is reached.
    //NOTE(review): a trailing space is appended after every word, so the loaded
    //infoText will not round-trip byte-for-byte with the saved one. If the file
    //ends before "NumDimensions:" appears this loop spins on a failed stream
    //reading the same word — assumed well-formed input; confirm upstream.
    file >> word;
    infoText = "";
    while( word != "NumDimensions:" ){
        infoText += word + " ";
        file >> word;
    }

    //Get the number of dimensions in the training data
    if( word != "NumDimensions:" ){
        errorLog << "loadDatasetFromFile(const string &filename) - failed to find NumDimensions header!" << endl;
        file.close();
        return false;
    }
    file >> numDimensions;

    //Get the total number of training examples (both legacy and current header names are accepted)
    file >> word;
    if( word != "TotalNumTrainingExamples:" && word != "TotalNumExamples:" ){
        errorLog << "loadDatasetFromFile(const string &filename) - failed to find TotalNumTrainingExamples header!" << endl;
        file.close();
        return false;
    }
    file >> totalNumSamples;

    //Get the total number of classes in the training data
    file >> word;
    if(word != "NumberOfClasses:"){
        errorLog << "loadDatasetFromFile(string filename) - failed to find NumberOfClasses header!" << endl;
        file.close();
        return false;
    }
    file >> numClasses;

    //Resize the class counter buffer and load the counters
    classTracker.resize(numClasses);

    //Get the per-class label/counter/name table
    file >> word;
    if(word != "ClassIDsAndCounters:"){
        errorLog << "loadDatasetFromFile(const string &filename) - failed to find ClassIDsAndCounters header!" << endl;
        file.close();
        return false;
    }

    for(UINT i=0; i<classTracker.size(); i++){
        file >> classTracker[i].classLabel;
        file >> classTracker[i].counter;
        file >> classTracker[i].className;
    }

    //Check if the dataset should be scaled using external ranges
    file >> word;
    if(word != "UseExternalRanges:"){
        errorLog << "loadDatasetFromFile(const string &filename) - failed to find UseExternalRanges header!" << endl;
        file.close();
        return false;
    }
    file >> useExternalRanges;

    //If we are using external ranges then load them, one min/max pair per dimension
    if( useExternalRanges ){
        externalRanges.resize(numDimensions);
        for(UINT i=0; i<externalRanges.size(); i++){
            file >> externalRanges[i].minValue;
            file >> externalRanges[i].maxValue;
        }
    }

    //Get the main training data (both legacy and current header names are accepted)
    file >> word;
    if( word != "LabelledTrainingData:" && word != "Data:"){
        errorLog << "loadDatasetFromFile(const string &filename) - failed to find LabelledTrainingData header!" << endl;
        file.close();
        return false;
    }

    //Pre-size the sample buffer, then read each sample: label followed by numDimensions values
    ClassificationSample tempSample( numDimensions );
    data.resize( totalNumSamples, tempSample );

    for(UINT i=0; i<totalNumSamples; i++){
        UINT classLabel = 0;
        VectorDouble sample(numDimensions,0);
        file >> classLabel;
        for(UINT j=0; j<numDimensions; j++){
            file >> sample[j];
        }
        data[i].set(classLabel, sample);
    }

    file.close();

    //Sort the class labels
    sortClassLabels();

    return true;
}
573 
574 bool ClassificationData::saveDatasetToCSVFile(const string &filename) const{
575 
576  std::fstream file;
577  file.open(filename.c_str(), std::ios::out );
578 
579  if( !file.is_open() ){
580  return false;
581  }
582 
583  //Write the data to the CSV file
584  for(UINT i=0; i<totalNumSamples; i++){
585  file << data[i].getClassLabel();
586  for(UINT j=0; j<numDimensions; j++){
587  file << "," << data[i][j];
588  }
589  file << endl;
590  }
591 
592  file.close();
593 
594  return true;
595 }
596 
bool ClassificationData::loadDatasetFromCSVFile(const string &filename,const UINT classLabelColumnIndex){
    //Loads a dataset from a CSV file where one column (classLabelColumnIndex)
    //holds the class label and every other column holds a feature value.
    //Any existing data is cleared first; the dataset name is reset to NOT_SET.
    //NOTE(review): classLabelColumnIndex is not validated against the column
    //count — an out-of-range index would read past the row; confirm FileParser's
    //operator[] behaviour.

    numDimensions = 0;
    datasetName = "NOT_SET";
    infoText = "";

    //Clear any previous data
    clear();

    //Parse the CSV file (second argument presumably ignores blank lines or
    //similar — TODO confirm against FileParser::parseCSVFile)
    FileParser parser;

    if( !parser.parseCSVFile(filename,true) ){
        errorLog << "loadDatasetFromCSVFile(const string &filename,const UINT classLabelColumnIndex) - Failed to parse CSV file!" << endl;
        return false;
    }

    //Every row must have the same number of columns
    if( !parser.getConsistentColumnSize() ){
        errorLog << "loadDatasetFromCSVFile(const string &filename,const UINT classLabelColumnIndexe) - The CSV file does not have a consistent number of columns!" << endl;
        return false;
    }

    //At least two columns are needed: one label column plus one feature column
    if( parser.getColumnSize() <= 1 ){
        errorLog << "loadDatasetFromCSVFile(const string &filename,const UINT classLabelColumnIndex) - The CSV file does not have enough columns! It should contain at least two columns!" << endl;
        return false;
    }

    //Set the number of dimensions (all columns except the label column)
    numDimensions = parser.getColumnSize()-1;

    //Reserve the memory for the data
    reserve( parser.getRowSize() );

    UINT classLabel = 0;
    UINT j = 0;
    UINT n = 0;
    VectorDouble sample(numDimensions);
    for(UINT i=0; i<parser.getRowSize(); i++){
        //Get the class label
        classLabel = Util::stringToInt( parser[i][classLabelColumnIndex] );

        //Copy every column except the label column into the sample, preserving column order
        j=0;
        n=0;
        while( j != numDimensions ){
            if( n != classLabelColumnIndex ){
                sample[j++] = Util::stringToDouble( parser[i][n] );
            }
            n++;
        }

        //Add the labelled sample to the dataset; a failed add is only a warning
        //so one bad row does not abort the whole load
        if( !addSample(classLabel, sample) ){
            warningLog << "loadDatasetFromCSVFile(const string &filename,const UINT classLabelColumnIndex) - Could not add sample " << i << " to the dataset!" << endl;
        }
    }

    //Sort the class labels
    sortClassLabels();

    return true;
}
659 
661 
662  cout << getStatsAsString();
663 
664  return true;
665 }
666 
668 
669  sort(classTracker.begin(),classTracker.end(),ClassTracker::sortByClassLabelAscending);
670 
671  return true;
672 }
673 
ClassificationData ClassificationData::partition(const UINT trainingSizePercentage,const bool useStratifiedSampling){

    //Partitions the dataset into a training dataset (which is kept by this instance of the ClassificationData) and
    //a testing/validation dataset (which is return as a new instance of the ClassificationData). The trainingSizePercentage
    //therefore sets the size of the data which remains in this instance and the remaining percentage of data is then added to
    //the testing/validation dataset.
    //With useStratifiedSampling the split is performed per class so each class
    //keeps approximately trainingSizePercentage of its samples; otherwise the
    //split is over the shuffled dataset as a whole.

    //The dataset has changed so flag that any previous cross validation setup will now not work
    crossValidationSetup = false;
    crossValidationIndexs.clear();

    ClassificationData trainingSet(numDimensions);
    ClassificationData testSet(numDimensions);
    trainingSet.setAllowNullGestureClass( allowNullGestureClass );
    testSet.setAllowNullGestureClass( allowNullGestureClass );
    //NOTE(review): indexs is only used by the non-stratified branch but is
    //allocated for both paths
    vector< UINT > indexs( totalNumSamples );

    //Create the random partion indexs
    Random random;
    UINT randomIndex = 0;

    if( useStratifiedSampling ){
        //Break the data into seperate classes
        vector< vector< UINT > > classData( getNumClasses() );

        //Add the indexs to their respective classes
        for(UINT i=0; i<totalNumSamples; i++){
            classData[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
        }

        //Randomize the order of the indexs in each of the class index buffers
        //NOTE(review): assumes getRandomNumberInt(0,numSamples) returns a value
        //in [0,numSamples) — if the upper bound is inclusive this can index one
        //past the end; confirm against the Random class
        for(UINT k=0; k<getNumClasses(); k++){
            UINT numSamples = (UINT)classData[k].size();
            for(UINT x=0; x<numSamples; x++){
                //Pick a random index
                randomIndex = random.getRandomNumberInt(0,numSamples);

                //Swap the indexs
                SWAP(classData[k][ x ], classData[k][ randomIndex ]);
            }
        }

        //Work out how much memory to reserve in each output set
        UINT numTrainingSamples = 0;
        UINT numTestSamples = 0;

        for(UINT k=0; k<getNumClasses(); k++){
            UINT numTrainingExamples = (UINT) floor( double(classData[k].size()) / 100.0 * double(trainingSizePercentage) );
            UINT numTestExamples = ((UINT)classData[k].size())-numTrainingExamples;
            numTrainingSamples += numTrainingExamples;
            numTestSamples += numTestExamples;
        }

        trainingSet.reserve( numTrainingSamples );
        testSet.reserve( numTestSamples );

        //Loop over each class and add the data to the trainingSet and testSet
        for(UINT k=0; k<getNumClasses(); k++){
            UINT numTrainingExamples = (UINT) floor( double(classData[k].size()) / 100.0 * double(trainingSizePercentage) );

            //Add the data to the training and test sets
            for(UINT i=0; i<numTrainingExamples; i++){
                trainingSet.addSample( data[ classData[k][i] ].getClassLabel(), data[ classData[k][i] ].getSample() );
            }
            for(UINT i=numTrainingExamples; i<classData[k].size(); i++){
                testSet.addSample( data[ classData[k][i] ].getClassLabel(), data[ classData[k][i] ].getSample() );
            }
        }
    }else{

        const UINT numTrainingExamples = (UINT) floor( double(totalNumSamples) / 100.0 * double(trainingSizePercentage) );
        //Shuffle the sample indices before splitting (same inclusivity caveat as above)
        Random random;
        UINT randomIndex = 0;
        for(UINT i=0; i<totalNumSamples; i++) indexs[i] = i;
        for(UINT x=0; x<totalNumSamples; x++){
            //Pick a random index
            randomIndex = random.getRandomNumberInt(0,totalNumSamples);

            //Swap the indexs
            SWAP(indexs[ x ],indexs[ randomIndex ]);
        }

        //Reserve the memory
        trainingSet.reserve( numTrainingExamples );
        testSet.reserve( totalNumSamples-numTrainingExamples );

        //Add the data to the training and test sets
        for(UINT i=0; i<numTrainingExamples; i++){
            trainingSet.addSample( data[ indexs[i] ].getClassLabel(), data[ indexs[i] ].getSample() );
        }
        for(UINT i=numTrainingExamples; i<totalNumSamples; i++){
            testSet.addSample( data[ indexs[i] ].getClassLabel(), data[ indexs[i] ].getSample() );
        }
    }

    //Overwrite the training data in this instance with the training data of the trainingSet
    *this = trainingSet;

    //Sort the class labels in this dataset
    sortClassLabels();

    //Sort the class labels of the test dataset
    testSet.sortClassLabels();

    return testSet;
}
781 
783 
784  if( labelledData.getNumDimensions() != numDimensions ){
785  errorLog << "merge(const ClassificationData &labelledData) - The number of dimensions in the labelledData (" << labelledData.getNumDimensions() << ") does not match the number of dimensions of this dataset (" << numDimensions << ")" << endl;
786  return false;
787  }
788 
789  //The dataset has changed so flag that any previous cross validation setup will now not work
790  crossValidationSetup = false;
791  crossValidationIndexs.clear();
792 
793  //Reserve the memory
794  reserve( getNumSamples() + labelledData.getNumSamples() );
795 
796  //Add the data from the labelledData to this instance
797  for(UINT i=0; i<labelledData.getNumSamples(); i++){
798  addSample(labelledData[i].getClassLabel(), labelledData[i].getSample());
799  }
800 
801  //Set the class names from the dataset
802  vector< ClassTracker > classTracker = labelledData.getClassTracker();
803  for(UINT i=0; i<classTracker.size(); i++){
804  setClassNameForCorrespondingClassLabel(classTracker[i].className, classTracker[i].classLabel);
805  }
806 
807  //Sort the class labels
808  sortClassLabels();
809 
810  return true;
811 }
812 
813 bool ClassificationData::spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling){
814 
815  crossValidationSetup = false;
816  crossValidationIndexs.clear();
817 
818  //K can not be zero
819  if( K > totalNumSamples ){
820  errorLog << "spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling) - K can not be zero!" << endl;
821  return false;
822  }
823 
824  //K can not be larger than the number of examples
825  if( K > totalNumSamples ){
826  errorLog << "spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling) - K can not be larger than the total number of samples in the dataset!" << endl;
827  return false;
828  }
829 
830  //K can not be larger than the number of examples in a specific class if the stratified sampling option is true
831  if( useStratifiedSampling ){
832  for(UINT c=0; c<classTracker.size(); c++){
833  if( K > classTracker[c].counter ){
834  errorLog << "spiltDataIntoKFolds(const UINT K,const bool useStratifiedSampling) - K can not be larger than the number of samples in any given class!" << endl;
835  return false;
836  }
837  }
838  }
839 
840  //Setup the dataset for k-fold cross validation
841  kFoldValue = K;
842  vector< UINT > indexs( totalNumSamples );
843 
844  //Work out how many samples are in each fold, the last fold might have more samples than the others
845  UINT numSamplesPerFold = (UINT) floor( totalNumSamples/double(K) );
846 
847  //Add the random indexs to each fold
848  crossValidationIndexs.resize(K);
849 
850  //Create the random partion indexs
851  Random random;
852  UINT randomIndex = 0;
853 
854  if( useStratifiedSampling ){
855  //Break the data into seperate classes
856  vector< vector< UINT > > classData( getNumClasses() );
857 
858  //Add the indexs to their respective classes
859  for(UINT i=0; i<totalNumSamples; i++){
860  classData[ getClassLabelIndexValue( data[i].getClassLabel() ) ].push_back( i );
861  }
862 
863  //Randomize the order of the indexs in each of the class index buffers
864  for(UINT c=0; c<getNumClasses(); c++){
865  UINT numSamples = (UINT)classData[c].size();
866  for(UINT x=0; x<numSamples; x++){
867  //Pick a random indexs
868  randomIndex = random.getRandomNumberInt(0,numSamples);
869 
870  //Swap the indexs
871  SWAP(classData[c][ x ] , classData[c][ randomIndex ]);
872  }
873  }
874 
875  //Loop over each of the k folds, at each fold add a sample from each class
876  vector< UINT >::iterator iter;
877  for(UINT c=0; c<getNumClasses(); c++){
878  iter = classData[ c ].begin();
879  UINT k = 0;
880  while( iter != classData[c].end() ){
881  crossValidationIndexs[ k ].push_back( *iter );
882  iter++;
883  k++;
884  k = k % K;
885  }
886  }
887 
888  }else{
889  //Randomize the order of the data
890  for(UINT i=0; i<totalNumSamples; i++) indexs[i] = i;
891  for(UINT x=0; x<totalNumSamples; x++){
892  //Pick a random index
893  randomIndex = random.getRandomNumberInt(0,totalNumSamples);
894 
895  //Swap the indexs
896  SWAP(indexs[ x ] , indexs[ randomIndex ]);
897  }
898 
899  UINT counter = 0;
900  UINT foldIndex = 0;
901  for(UINT i=0; i<totalNumSamples; i++){
902  //Add the index to the current fold
903  crossValidationIndexs[ foldIndex ].push_back( indexs[i] );
904 
905  //Move to the next fold if ready
906  if( ++counter == numSamplesPerFold && foldIndex < K-1 ){
907  foldIndex++;
908  counter = 0;
909  }
910  }
911  }
912 
913  crossValidationSetup = true;
914  return true;
915 
916 }
917 
919 
920  ClassificationData trainingData;
921  trainingData.setNumDimensions( numDimensions );
922  trainingData.setAllowNullGestureClass( allowNullGestureClass );
923 
924  if( !crossValidationSetup ){
925  errorLog << "getTrainingFoldData(const UINT foldIndex) - Cross Validation has not been setup! You need to call the spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) function first before calling this function!" << endl;
926  return trainingData;
927  }
928 
929  if( foldIndex >= kFoldValue ) return trainingData;
930 
931  //Add the class labels to make sure they all exist
932  for(UINT k=0; k<getNumSamples(); k++){
933  trainingData.addClass( classTracker[k].classLabel, classTracker[k].className );
934  }
935 
936  //Add the data to the training set, this will consist of all the data that is NOT in the foldIndex
937  UINT index = 0;
938  for(UINT k=0; k<kFoldValue; k++){
939  if( k != foldIndex ){
940  for(UINT i=0; i<crossValidationIndexs[k].size(); i++){
941 
942  index = crossValidationIndexs[k][i];
943  trainingData.addSample( data[ index ].getClassLabel(), data[ index ].getSample() );
944  }
945  }
946  }
947 
948  //Sort the class labels
949  trainingData.sortClassLabels();
950 
951  return trainingData;
952 }
953 
955 
956  ClassificationData testData;
957  testData.setNumDimensions( numDimensions );
958  testData.setAllowNullGestureClass( allowNullGestureClass );
959 
960  if( !crossValidationSetup ) return testData;
961 
962  if( foldIndex >= kFoldValue ) return testData;
963 
964  //Add the class labels to make sure they all exist
965  for(UINT k=0; k<getNumSamples(); k++){
966  testData.addClass( classTracker[k].classLabel, classTracker[k].className );
967  }
968 
969  testData.reserve( (UINT)crossValidationIndexs[ foldIndex ].size() );
970 
971  //Add the data to the test fold
972  UINT index = 0;
973  for(UINT i=0; i<crossValidationIndexs[ foldIndex ].size(); i++){
974 
975  index = crossValidationIndexs[ foldIndex ][i];
976  testData.addSample( data[ index ].getClassLabel(), data[ index ].getSample() );
977  }
978 
979  //Sort the class labels
980  testData.sortClassLabels();
981 
982  return testData;
983 }
984 
986 
987  ClassificationData classData;
988  classData.setNumDimensions( this->numDimensions );
989  classData.setAllowNullGestureClass( allowNullGestureClass );
990 
991  //Reserve the memory for the class data
992  for(UINT i=0; i<classTracker.size(); i++){
993  if( classTracker[i].classLabel == classLabel ){
994  classData.reserve( classTracker[i].counter );
995  break;
996  }
997  }
998 
999  for(UINT i=0; i<totalNumSamples; i++){
1000  if( data[i].getClassLabel() == classLabel ){
1001  classData.addSample(classLabel, data[i].getSample());
1002  }
1003  }
1004 
1005  return classData;
1006 }
1007 
1009 
1010  Random rand;
1011  ClassificationData newDataset;
1012  newDataset.setNumDimensions( getNumDimensions() );
1013  newDataset.setAllowNullGestureClass( allowNullGestureClass );
1014  newDataset.setExternalRanges( externalRanges, useExternalRanges );
1015 
1016  if( numSamples == 0 ) numSamples = totalNumSamples;
1017 
1018  newDataset.reserve( numSamples );
1019 
1020  //Add all the class labels to the new dataset to ensure the dataset has a list of all the labels
1021  for(UINT k=0; k<getNumClasses(); k++){
1022  newDataset.addClass( classTracker[k].classLabel );
1023  }
1024 
1025  //Randomly select the training samples to add to the new data set
1026  UINT randomIndex;
1027  for(UINT i=0; i<numSamples; i++){
1028  randomIndex = rand.getRandomNumberInt(0, totalNumSamples);
1029  newDataset.addSample(data[randomIndex].getClassLabel(), data[randomIndex].getSample());
1030  }
1031 
1032  //Sort the class labels so they are in order
1033  newDataset.sortClassLabels();
1034 
1035  return newDataset;
1036 }
1037 
1039 
1040  //Turns the classification into a regression data to enable regression algorithms like the MLP to be used as a classifier
1041  //This sets the number of targets in the regression data equal to the number of classes in the classification data
1042  //The output of each regression training sample will then be all 0's, except for the index matching the classLabel, which will be 1
1043  //For this to work, the labelled classification data cannot have any samples with a classLabel of 0!
1044  RegressionData regressionData;
1045 
1046  if( totalNumSamples == 0 ){
1047  return regressionData;
1048  }
1049 
1050  const UINT numInputDimensions = numDimensions;
1051  const UINT numTargetDimensions = getNumClasses();
1052  regressionData.setInputAndTargetDimensions(numInputDimensions, numTargetDimensions);
1053 
1054  for(UINT i=0; i<totalNumSamples; i++){
1055  VectorDouble targetVector(numTargetDimensions,0);
1056 
1057  //Set the class index in the target vector to 1 and all other values in the target vector to 0
1058  UINT classLabel = data[i].getClassLabel();
1059 
1060  if( classLabel > 0 ){
1061  targetVector[ classLabel-1 ] = 1;
1062  }else{
1063  regressionData.clear();
1064  return regressionData;
1065  }
1066 
1067  regressionData.addSample(data[i].getSample(),targetVector);
1068  }
1069 
1070  return regressionData;
1071 }
1072 
1074 
1075  UnlabelledData unlabelledData;
1076 
1077  if( totalNumSamples == 0 ){
1078  return unlabelledData;
1079  }
1080 
1081  unlabelledData.setNumDimensions( numDimensions );
1082 
1083  for(UINT i=0; i<totalNumSamples; i++){
1084  unlabelledData.addSample( data[i].getSample() );
1085  }
1086 
1087  return unlabelledData;
1088 }
1089 
1091  UINT minClassLabel = numeric_limits< UINT >::max();
1092 
1093  for(UINT i=0; i<classTracker.size(); i++){
1094  if( classTracker[i].classLabel < minClassLabel ){
1095  minClassLabel = classTracker[i].classLabel;
1096  }
1097  }
1098 
1099  return minClassLabel;
1100 }
1101 
1102 
1104  UINT maxClassLabel = 0;
1105 
1106  for(UINT i=0; i<classTracker.size(); i++){
1107  if( classTracker[i].classLabel > maxClassLabel ){
1108  maxClassLabel = classTracker[i].classLabel;
1109  }
1110  }
1111 
1112  return maxClassLabel;
1113 }
1114 
1116  for(UINT k=0; k<classTracker.size(); k++){
1117  if( classTracker[k].classLabel == classLabel ){
1118  return k;
1119  }
1120  }
1121  warningLog << "getClassLabelIndexValue(UINT classLabel) - Failed to find class label: " << classLabel << " in class tracker!" << endl;
1122  return 0;
1123 }
1124 
1126 
1127  for(UINT i=0; i<classTracker.size(); i++){
1128  if( classTracker[i].classLabel == classLabel ){
1129  return classTracker[i].className;
1130  }
1131  }
1132 
1133  return "CLASS_LABEL_NOT_FOUND";
1134 }
1135 
1137  string statsText;
1138  statsText += "DatasetName:\t" + datasetName + "\n";
1139  statsText += "DatasetInfo:\t" + infoText + "\n";
1140  statsText += "Number of Dimensions:\t" + Util::toString( numDimensions ) + "\n";
1141  statsText += "Number of Samples:\t" + Util::toString( totalNumSamples ) + "\n";
1142  statsText += "Number of Classes:\t" + Util::toString( getNumClasses() ) + "\n";
1143  statsText += "ClassStats:\n";
1144 
1145  for(UINT k=0; k<getNumClasses(); k++){
1146  statsText += "ClassLabel:\t" + Util::toString( classTracker[k].classLabel );
1147  statsText += "\tNumber of Samples:\t" + Util::toString(classTracker[k].counter);
1148  statsText += "\tClassName:\t" + classTracker[k].className + "\n";
1149  }
1150 
1151  vector< MinMax > ranges = getRanges();
1152 
1153  statsText += "Dataset Ranges:\n";
1154  for(UINT j=0; j<ranges.size(); j++){
1155  statsText += "[" + Util::toString( j+1 ) + "] Min:\t" + Util::toString( ranges[j].minValue ) + "\tMax: " + Util::toString( ranges[j].maxValue ) + "\n";
1156  }
1157 
1158  return statsText;
1159 }
1160 
1161 vector<MinMax> ClassificationData::getRanges() const{
1162 
1163  //If the dataset should be scaled using the external ranges then return the external ranges
1164  if( useExternalRanges ) return externalRanges;
1165 
1166  vector< MinMax > ranges(numDimensions);
1167 
1168  //Otherwise return the min and max values for each column in the dataset
1169  if( totalNumSamples > 0 ){
1170  for(UINT j=0; j<numDimensions; j++){
1171  ranges[j].minValue = data[0][0];
1172  ranges[j].maxValue = data[0][0];
1173  for(UINT i=0; i<totalNumSamples; i++){
1174  if( data[i][j] < ranges[j].minValue ){ ranges[j].minValue = data[i][j]; } //Search for the min value
1175  else if( data[i][j] > ranges[j].maxValue ){ ranges[j].maxValue = data[i][j]; } //Search for the max value
1176  }
1177  }
1178  }
1179  return ranges;
1180 }
1181 
1182 vector< UINT > ClassificationData::getClassLabels() const{
1183  vector< UINT > classLabels( getNumClasses(), 0 );
1184 
1185  if( getNumClasses() == 0 ) return classLabels;
1186 
1187  for(UINT i=0; i<getNumClasses(); i++){
1188  classLabels[i] = classTracker[i].classLabel;
1189  }
1190 
1191  return classLabels;
1192 }
1193 
1195  vector< UINT > classSampleCounts( getNumClasses(), 0 );
1196 
1197  if( getNumSamples() == 0 ) return classSampleCounts;
1198 
1199  for(UINT i=0; i<getNumClasses(); i++){
1200  classSampleCounts[i] = classTracker[i].counter;
1201  }
1202 
1203  return classSampleCounts;
1204 }
1205 
1206 VectorDouble ClassificationData::getMean() const{
1207 
1208  VectorDouble mean(numDimensions,0);
1209 
1210  for(UINT j=0; j<numDimensions; j++){
1211  for(UINT i=0; i<totalNumSamples; i++){
1212  mean[j] += data[i][j];
1213  }
1214  mean[j] /= double(totalNumSamples);
1215  }
1216 
1217  return mean;
1218 }
1219 
1220 VectorDouble ClassificationData::getStdDev() const{
1221 
1222  VectorDouble mean = getMean();
1223  VectorDouble stdDev(numDimensions,0);
1224 
1225  for(UINT j=0; j<numDimensions; j++){
1226  for(UINT i=0; i<totalNumSamples; i++){
1227  stdDev[j] += SQR(data[i][j]-mean[j]);
1228  }
1229  stdDev[j] = sqrt( stdDev[j] / double(totalNumSamples-1) );
1230  }
1231 
1232  return stdDev;
1233 }
1234 
1235 MatrixDouble ClassificationData::getClassHistogramData(UINT classLabel,UINT numBins) const{
1236 
1237  const UINT M = getNumSamples();
1238  const UINT N = getNumDimensions();
1239 
1240  vector< MinMax > ranges = getRanges();
1241  vector< double > binRange(N);
1242  for(UINT i=0; i<ranges.size(); i++){
1243  binRange[i] = (ranges[i].maxValue-ranges[i].minValue)/double(numBins);
1244  }
1245 
1246  MatrixDouble histData(N,numBins);
1247  histData.setAllValues(0);
1248 
1249  double norm = 0;
1250  for(UINT i=0; i<M; i++){
1251  if( data[i].getClassLabel() == classLabel ){
1252  for(UINT j=0; j<N; j++){
1253  UINT binIndex = 0;
1254  bool binFound = false;
1255  for(UINT k=0; k<numBins-1; k++){
1256  if( data[i][j] >= ranges[i].minValue + (binRange[j]*k) && data[i][j] >= ranges[i].minValue + (binRange[j]*(k+1)) ){
1257  binIndex = k;
1258  binFound = true;
1259  break;
1260  }
1261  }
1262  if( !binFound ) binIndex = numBins-1;
1263  histData[j][binIndex]++;
1264  }
1265  norm++;
1266  }
1267  }
1268 
1269  if( norm == 0 ) return histData;
1270 
1271  //Is this the best way to normalize a multidimensional histogram???
1272  for(UINT i=0; i<histData.getNumRows(); i++){
1273  for(UINT j=0; j<histData.getNumCols(); j++){
1274  histData[i][j] /= norm;
1275  }
1276  }
1277 
1278  return histData;
1279 }
1280 
1282 
1283  MatrixDouble mean(getNumClasses(),numDimensions);
1284  VectorDouble counter(getNumClasses(),0);
1285 
1286  mean.setAllValues( 0 );
1287 
1288  for(UINT i=0; i<totalNumSamples; i++){
1289  UINT classIndex = getClassLabelIndexValue( data[i].getClassLabel() );
1290  for(UINT j=0; j<numDimensions; j++){
1291  mean[classIndex][j] += data[i][j];
1292  }
1293  counter[ classIndex ]++;
1294  }
1295 
1296  for(UINT k=0; k<getNumClasses(); k++){
1297  for(UINT j=0; j<numDimensions; j++){
1298  mean[k][j] = counter[j] > 0 ? mean[k][j]/counter[j] : 0;
1299  }
1300  }
1301 
1302  return mean;
1303 }
1304 
1306 
1307  MatrixDouble mean = getClassMean();
1308  MatrixDouble stdDev(getNumClasses(),numDimensions);
1309  VectorDouble counter(getNumClasses(),0);
1310 
1311  stdDev.setAllValues( 0 );
1312 
1313  for(UINT i=0; i<totalNumSamples; i++){
1314  UINT classIndex = getClassLabelIndexValue( data[i].getClassLabel() );
1315  for(UINT j=0; j<numDimensions; j++){
1316  stdDev[classIndex][j] += SQR(data[i][j]-mean[classIndex][j]);
1317  }
1318  counter[ classIndex ]++;
1319  }
1320 
1321  for(UINT k=0; k<getNumClasses(); k++){
1322  for(UINT j=0; j<numDimensions; j++){
1323  stdDev[k][j] = sqrt( stdDev[k][j] / double(counter[k]-1) );
1324  }
1325  }
1326 
1327  return stdDev;
1328 }
1329 
1331 
1332  VectorDouble mean = getMean();
1333  MatrixDouble covariance(numDimensions,numDimensions);
1334 
1335  for(UINT j=0; j<numDimensions; j++){
1336  for(UINT k=0; k<numDimensions; k++){
1337  for(UINT i=0; i<totalNumSamples; i++){
1338  covariance[j][k] += (data[i][j]-mean[j]) * (data[i][k]-mean[k]) ;
1339  }
1340  covariance[j][k] /= double(totalNumSamples-1);
1341  }
1342  }
1343 
1344  return covariance;
1345 }
1346 
1347 vector< MatrixDouble > ClassificationData::getHistogramData(UINT numBins) const{
1348  const UINT K = getNumClasses();
1349  vector< MatrixDouble > histData(K);
1350 
1351  for(UINT k=0; k<K; k++){
1352  histData[k] = getClassHistogramData( classTracker[k].classLabel, numBins );
1353  }
1354 
1355  return histData;
1356 }
1357 
1358 VectorDouble ClassificationData::getClassProbabilities() const {
1359  return getClassProbabilities( getClassLabels() );
1360 }
1361 
1362 VectorDouble ClassificationData::getClassProbabilities( const vector< UINT > &classLabels ) const {
1363  const UINT K = (UINT)classLabels.size();
1364  const UINT N = getNumClasses();
1365  double sum = 0;
1366  VectorDouble x(K,0);
1367  for(UINT k=0; k<K; k++){
1368  for(UINT n=0; n<N; n++){
1369  if( classLabels[k] == classTracker[n].classLabel ){
1370  x[k] = classTracker[n].counter;
1371  sum += classTracker[n].counter;
1372  break;
1373  }
1374  }
1375  }
1376 
1377  //Normalize the class probabilities
1378  if( sum > 0 ){
1379  for(UINT k=0; k<K; k++){
1380  x[k] /= sum;
1381  }
1382  }
1383 
1384  return x;
1385 }
1386 
1387 vector< UINT > ClassificationData::getClassDataIndexes(UINT classLabel) const{
1388 
1389  const UINT M = getNumSamples();
1390  const UINT K = getNumClasses();
1391  UINT N = 0;
1392 
1393  //Get the number of samples in the class
1394  for(UINT k=0; k<K; k++){
1395  if( classTracker[k].classLabel == classLabel){
1396  N = classTracker[k].counter;
1397  break;
1398  }
1399  }
1400 
1401  UINT index = 0;
1402  vector< UINT > classIndexes(N);
1403  for(UINT i=0; i<M; i++){
1404  if( data[i].getClassLabel() == classLabel ){
1405  classIndexes[index++] = i;
1406  }
1407  }
1408 
1409  return classIndexes;
1410 }
1411 
1413 
1414  const UINT M = getNumSamples();
1415  const UINT N = getNumDimensions();
1416  MatrixDouble d(M,N);
1417 
1418  for(UINT i=0; i<M; i++){
1419  for(UINT j=0; j<N; j++){
1420  d[i][j] = data[i][j];
1421  }
1422  }
1423 
1424  return d;
1425 }
1426 
1427 bool ClassificationData::generateGaussDataset( const std::string filename, const UINT numSamples, const UINT numClasses, const UINT numDimensions, const double range, const double sigma ){
1428 
1429  Random random;
1430 
1431  //Generate a simple model that will be used to generate the main dataset
1432  MatrixDouble model(numClasses,numDimensions);
1433  for(UINT k=0; k<numClasses; k++){
1434  for(UINT j=0; j<numDimensions; j++){
1435  model[k][j] = random.getRandomNumberUniform(-range,range);
1436  }
1437  }
1438 
1439  //Use the model above to generate the main dataset
1440  ClassificationData data;
1441  data.setNumDimensions( numDimensions );
1442 
1443  for(UINT i=0; i<numSamples; i++){
1444 
1445  //Randomly select which class this sample belongs to
1446  UINT k = random.getRandomNumberInt( 0, numClasses );
1447 
1448  //Generate a sample using the model (+ some Gaussian noise)
1449  vector< double > sample( numDimensions );
1450  for(UINT j=0; j<numDimensions; j++){
1451  sample[j] = model[k][j] + random.getRandomNumberGauss(0,sigma);
1452  }
1453 
1454  //By default in the GRT, the class label should not be 0, so add 1
1455  UINT classLabel = k + 1;
1456 
1457  //Add the labeled sample to the dataset
1458  data.addSample( classLabel, sample );
1459  }
1460 
1461  //Save the dataset to a CSV file
1462  return data.save( filename );
1463 }
1464 
VectorDouble getStdDev() const
double getRandomNumberGauss(double mu=0.0, double sigma=1.0)
Definition: Random.h:208
bool save(const string &filename) const
static std::string toString(const int &i)
Definition: Util.cpp:65
bool removeSample(const UINT index)
The ClassificationData is the main data structure for recording, labeling, managing, saving, and loading training data for supervised learning problems.
bool setAllowNullGestureClass(bool allowNullGestureClass)
bool setAllValues(const T &value)
Definition: Matrix.h:335
VectorDouble getMean() const
vector< UINT > getClassLabels() const
bool loadDatasetFromFile(const string &filename)
Definition: AdaBoost.cpp:25
ClassificationData getTrainingFoldData(const UINT foldIndex) const
bool setInputAndTargetDimensions(const UINT numInputDimensions, const UINT numTargetDimensions)
static double scale(const double &x, const double &minSource, const double &maxSource, const double &minTarget, const double &maxTarget, const bool constrain=false)
Definition: Util.cpp:44
bool setNumDimensions(UINT numDimensions)
ClassificationData & operator=(const ClassificationData &rhs)
unsigned int getNumCols() const
Definition: Matrix.h:538
bool addSample(const VectorDouble &sample)
UINT removeClass(const UINT classLabel)
MatrixDouble getClassStdDev() const
bool merge(const ClassificationData &labelledData)
vector< ClassTracker > getClassTracker() const
static double stringToDouble(const std::string &s)
Definition: Util.cpp:124
UINT getClassLabelIndexValue(const UINT classLabel) const
bool scale(const double minTarget, const double maxTarget)
static bool generateGaussDataset(const std::string filename, const UINT numSamples=10000, const UINT numClasses=10, const UINT numDimensions=3, const double range=10, const double sigma=1)
RegressionData reformatAsRegressionData() const
vector< UINT > getClassDataIndexes(const UINT classLabel) const
vector< UINT > getNumSamplesPerClass() const
ClassificationData(UINT numDimensions=0, string datasetName="NOT_SET", string infoText="")
bool saveDatasetToCSVFile(const string &filename) const
int getRandomNumberInt(int minRange, int maxRange)
Definition: Random.h:87
double getRandomNumberUniform(double minRange=0.0, double maxRange=1.0)
Definition: Random.h:197
MatrixDouble getDataAsMatrixDouble() const
UnlabelledData reformatAsUnlabelledData() const
MatrixDouble getCovarianceMatrix() const
bool setExternalRanges(const vector< MinMax > &externalRanges, const bool useExternalRanges=false)
unsigned int getNumRows() const
Definition: Matrix.h:531
vector< MinMax > getRanges() const
UINT eraseAllSamplesWithClassLabel(const UINT classLabel)
static int stringToInt(const std::string &s)
Definition: Util.cpp:117
bool addSample(const VectorDouble &inputVector, const VectorDouble &targetVector)
static bool stringEndsWith(const std::string &str, const std::string &ending)
Definition: Util.cpp:141
bool enableExternalRangeScaling(const bool useExternalRanges)
vector< MatrixDouble > getHistogramData(const UINT numBins) const
bool setInfoText(string infoText)
bool loadDatasetFromCSVFile(const string &filename, const UINT classLabelColumnIndex=0)
ClassificationData getTestFoldData(const UINT foldIndex) const
string getClassNameForCorrespondingClassLabel(const UINT classLabel) const
ClassificationData partition(const UINT partitionPercentage, const bool useStratifiedSampling=false)
bool spiltDataIntoKFolds(const UINT K, const bool useStratifiedSampling=false)
MatrixDouble getClassMean() const
bool addSample(UINT classLabel, const VectorDouble &sample)
ClassificationData getBootstrappedDataset(UINT numSamples=0) const
bool load(const string &filename)
bool setNumDimensions(const UINT numDimensions)
bool saveDatasetToFile(const string &filename) const
bool relabelAllSamplesWithClassLabel(const UINT oldClassLabel, const UINT newClassLabel)
bool setClassNameForCorrespondingClassLabel(string className, UINT classLabel)
bool addClass(const UINT classLabel, const std::string className="NOT_SET")
MatrixDouble getClassHistogramData(const UINT classLabel, const UINT numBins) const
bool setDatasetName(string datasetName)
ClassificationData getClassData(const UINT classLabel) const