Decision Tree

Description

This class implements a basic Decision Tree classifier. Decision Trees are conceptually simple classifiers that perform well even on complex classification tasks. A Decision Tree partitions the feature space into a set of rectangular regions, classifying a new datum by finding the region it belongs to.

The Decision Tree algorithm is part of the GRT classification modules.

Advantages

The Decision Tree algorithm is a good choice for the classification of static postures and other non-temporal pattern recognition tasks. The main advantage of a Decision Tree is that the resulting model can easily be interpreted.

Disadvantages

The main limitation of the Decision Tree algorithm is that very large models will frequently overfit the training data. To prevent overfitting, you should experiment with the maximum depth of the Decision Tree.

Training Data Format

You should use the LabelledClassificationData data structure to train the Decision Tree classifier.

Example Code

This example demonstrates how to initialize, train, and use the Decision Tree algorithm for classification. The example loads the data shown in the image below and uses it to train a Decision Tree model. The data is a recording of a Wii-mote being held in 5 different orientations: the top graph shows the raw x, y, and z accelerometer data from the recording, while the bottom graph shows the class label recorded for each sample (the 5 different classes are visible in the label data). You can download the dataset in the Code & Resources section below.

Training Data
WiiAccelerometerData.jpg: the raw accelerometer data (top) and the recorded class labels (bottom) for the 5 Wii-mote orientations.
/*
 GRT DecisionTree

 This example demonstrates how to initialize, train, and use the DecisionTree algorithm for classification.

 Decision Trees are conceptually simple classifiers that work well on even complex classification tasks.  
 Decision Trees partition the feature space into a set of rectangular regions, classifying a new datum by
 finding which region it belongs to.

 In this example we create an instance of a DecisionTree algorithm and then train a model using some pre-recorded training data.
 The trained DecisionTree model is then used to predict the class label of some test data.

 This example shows you how to:
 - Create and initialize the DecisionTree algorithm
 - Load some LabelledClassificationData from a file and partition the training data into a training dataset and a test dataset
 - Train a DecisionTree model using the training dataset
 - Test the DecisionTree model using the test dataset
 - Manually compute the accuracy of the classifier
*/


#include "GRT.h"
#include <iostream>
using namespace GRT;
using namespace std;

int main(int argc, const char * argv[])
{
    //Create a new DecisionTree instance
    DecisionTree dTree;

    //Set the number of steps that will be used to choose the best splitting values
    //More steps will give you a better model, but will take longer to train
    dTree.setNumSplittingSteps( 100 );

    //Set the maximum depth of the tree
    dTree.setMaxDepth( 10 );

    //Set the minimum number of samples allowed per node
    dTree.setMinNumSamplesPerNode( 10 );

    //Load some training data to train the classifier
    LabelledClassificationData trainingData;

    if( !trainingData.loadDatasetFromFile("DecisionTreeTrainingData.txt") ){
        cout << "Failed to load training data!\n";
        return EXIT_FAILURE;
    }

    //Use 20% of the training dataset to create a test dataset
    LabelledClassificationData testData = trainingData.partition( 80 );

    //Train the classifier
    if( !dTree.train( trainingData ) ){
        cout << "Failed to train classifier!\n";
        return EXIT_FAILURE;
    }

    //Print the tree
    dTree.print();

    //Save the model to a file
    if( !dTree.saveModelToFile("DecisionTreeModel.txt") ){
        cout << "Failed to save the classifier model!\n";
        return EXIT_FAILURE;
    }

    //Load the model from a file
    if( !dTree.loadModelFromFile("DecisionTreeModel.txt") ){
        cout << "Failed to load the classifier model!\n";
        return EXIT_FAILURE;
    }

    //Test the accuracy of the model on the test data
    double accuracy = 0;
    for(UINT i=0; i<testData.getNumSamples(); i++){
        //Get the i'th test sample
        UINT classLabel = testData[i].getClassLabel();
        VectorDouble inputVector = testData[i].getSample();

        //Perform a prediction using the classifier
        bool predictSuccess = dTree.predict( inputVector );

        if( !predictSuccess ){
        cout << "Failed to perform prediction for test sample: " << i << "\n";
            return EXIT_FAILURE;
        }

        //Get the predicted class label
        UINT predictedClassLabel = dTree.getPredictedClassLabel();
        VectorDouble classLikelihoods = dTree.getClassLikelihoods();
        VectorDouble classDistances = dTree.getClassDistances();

        //Update the accuracy
        if( classLabel == predictedClassLabel ) accuracy++;

        cout << "TestSample: " << i <<  " ClassLabel: " << classLabel << " PredictedClassLabel: " << predictedClassLabel << endl;
    }

    cout << "Test Accuracy: " << accuracy/double(testData.getNumSamples())*100.0 << "%" << endl;

    return EXIT_SUCCESS;
}

Code & Resources

DecisionTreeExample.cpp DecisionTreeTrainingData.txt

Documentation

You can find the documentation for this class at DecisionTree documentation.