//Register the DecisionTreeThresholdNode module with the Node base class
RegisterNode< DecisionTreeThresholdNode > DecisionTreeThresholdNode::registerModule("DecisionTreeThresholdNode");

//Constructor: set the node type so the node can be identified at runtime
nodeType = "DecisionTreeThresholdNode";
//Predict: return true (route the sample to the right branch) if the feature value is at or above the threshold
if( x[ featureIndex ] >= threshold ) return true;
return false;
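For context, a minimal sketch of how this threshold test routes a sample; the feature index, threshold, and input values below are made up for illustration and are not part of the GRT source:

#include <vector>

int main(){
    //Hypothetical node parameters, for illustration only
    const unsigned int featureIndex = 2;
    const double threshold = 0.5;

    //x[2] = 0.8 >= 0.5, so a node with these parameters would return true from predict(),
    //routing the sample down the right branch; false would route it down the left branch
    std::vector<double> x = { 0.1, 0.7, 0.8 };
    bool goRight = ( x[ featureIndex ] >= threshold );
    return goRight ? 0 : 1;
}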
//Print the model at this node, indenting by one tab per level of the tree
string tab = "";
for(UINT i=0; i<depth; i++) tab += "\t";

stream << tab << "depth: " << depth << " nodeSize: " << nodeSize << " featureIndex: " << featureIndex << " threshold: " << threshold << " isLeafNode: " << isLeafNode << endl;
stream << tab << "ClassProbabilities: ";
for(UINT i=0; i<classProbabilities.size(); i++){
    stream << classProbabilities[i] << "\t";
}
stream << endl;

//Recursively print the left and right children, if they exist
if( leftChild != NULL ){
    stream << tab << "LeftChild: " << endl;
    leftChild->getModel( stream );
}

if( rightChild != NULL ){
    stream << tab << "RightChild: " << endl;
    rightChild->getModel( stream );
}
//Copy this node's parameters into the new node
node->isLeafNode = isLeafNode;
node->nodeID = nodeID;
node->predictedNodeID = predictedNodeID;
node->nodeSize = nodeSize;
node->featureIndex = featureIndex;
node->threshold = threshold;
node->classProbabilities = classProbabilities;

//Re-parent the deep-copied left and right children
node->leftChild->setParent( node );
node->rightChild->setParent( node );

return dynamic_cast< Node* >( node );
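The typed deepCopy() declared for this class is, in all likelihood, a thin wrapper around deepCopyNode(); a hedged sketch of such a wrapper (an assumption, not copied from the GRT source):

//Cast the generic Node* deep copy back to the concrete node type
DecisionTreeThresholdNode* DecisionTreeThresholdNode::deepCopy() const{
    return dynamic_cast< DecisionTreeThresholdNode* >( deepCopyNode() );
}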
//Set the node's model parameters
this->nodeSize = nodeSize;
this->featureIndex = featureIndex;
this->threshold = threshold;
this->classProbabilities = classProbabilities;
return true;

bool DecisionTreeThresholdNode::computeBestSpiltBestIterativeSpilt( const UINT &numSplittingSteps, const ClassificationData &trainingData, const vector< UINT > &features, const vector< UINT > &classLabels, UINT &featureIndex, double &minError ){
const UINT M = trainingData.getNumSamples();
const UINT N = (UINT)features.size();
const UINT K = (UINT)classLabels.size();

if( N == 0 ) return false;

minError = numeric_limits<double>::max();
UINT bestFeatureIndex = 0;
double bestThreshold = 0;
double error = 0;
double minRange = 0;
double maxRange = 0;
double step = 0;
double threshold = 0;
double giniIndexL = 0;
double giniIndexR = 0;
double weightL = 0;
double weightR = 0;
vector< UINT > groupIndex(M);
VectorDouble groupCounter(2,0);
vector< MinMax > ranges = trainingData.getRanges();
MatrixDouble classProbabilities(K,2);

//Loop over each feature, stepping a candidate threshold evenly across the feature's range
for(UINT n=0; n<N; n++){
    minRange = ranges[n].minValue;
    maxRange = ranges[n].maxValue;
    step = (maxRange-minRange)/double(numSplittingSteps);
    threshold = minRange;
    featureIndex = features[n];
    while( threshold <= maxRange ){

        //Assign each sample to the left (0) or right (1) group using the candidate threshold
        groupCounter[0] = groupCounter[1] = 0;
        classProbabilities.setAllValues(0);
        for(UINT i=0; i<M; i++){
            groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
            groupCounter[ groupIndex[i] ]++;
            classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
        }

        //Convert the class counts into class probabilities for each group
        for(UINT k=0; k<K; k++){
            classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
            classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
        }

        //Compute the Gini index for each group and weight it by the group size
        giniIndexL = giniIndexR = 0;
        for(UINT k=0; k<K; k++){
            giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
            giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
        }
        weightL = groupCounter[0]/M;
        weightR = groupCounter[1]/M;
        error = (giniIndexL*weightL) + (giniIndexR*weightR);

        //Keep this split if it gives the lowest error so far
        if( error < minError ){
            minError = error;
            bestThreshold = threshold;
            bestFeatureIndex = featureIndex;
        }

        //Move to the next candidate threshold
        threshold += step;
    }
}

//Store the best split found as this node's model
featureIndex = bestFeatureIndex;
set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));

return true;
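The error being minimised above is the size-weighted Gini impurity of the two groups. As a self-contained sketch of that scoring in plain C++ (not the GRT API; the function name and inputs are assumptions for illustration):

#include <vector>

//Score one candidate split: probsL/probsR hold the per-class probabilities of the left/right
//group, numL/numR the number of samples that fell into each group. Lower scores are better.
double weightedGiniError( const std::vector<double> &probsL, const std::vector<double> &probsR,
                          double numL, double numR ){
    double giniL = 0.0, giniR = 0.0;
    for(size_t k=0; k<probsL.size(); k++) giniL += probsL[k] * (1.0 - probsL[k]);
    for(size_t k=0; k<probsR.size(); k++) giniR += probsR[k] * (1.0 - probsR[k]);
    const double total = numL + numR;
    return ( giniL * (numL/total) ) + ( giniR * (numR/total) );
}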
bool DecisionTreeThresholdNode::computeBestSpiltBestRandomSpilt( const UINT &numSplittingSteps, const ClassificationData &trainingData, const vector< UINT > &features, const vector< UINT > &classLabels, UINT &featureIndex, double &minError ){
const UINT M = trainingData.getNumSamples();
const UINT N = (UINT)features.size();
const UINT K = (UINT)classLabels.size();

if( N == 0 ) return false;

minError = numeric_limits<double>::max();
UINT bestFeatureIndex = 0;
double bestThreshold = 0;
double error = 0;
double threshold = 0;
double giniIndexL = 0;
double giniIndexR = 0;
double weightL = 0;
double weightR = 0;
Random random;
vector< UINT > groupIndex(M);
VectorDouble groupCounter(2,0);
vector< MinMax > ranges = trainingData.getRanges();
MatrixDouble classProbabilities(K,2);

//Loop over each feature, drawing numSplittingSteps random candidate thresholds per feature
for(UINT n=0; n<N; n++){
    featureIndex = features[n];
    for(UINT m=0; m<numSplittingSteps; m++){
        //Randomly pick a threshold within this feature's observed range
        threshold = random.getRandomNumberUniform(ranges[n].minValue,ranges[n].maxValue);

        //Assign each sample to the left (0) or right (1) group using the candidate threshold
        groupCounter[0] = groupCounter[1] = 0;
        classProbabilities.setAllValues(0);
        for(UINT i=0; i<M; i++){
            groupIndex[i] = trainingData[ i ][ featureIndex ] >= threshold ? 1 : 0;
            groupCounter[ groupIndex[i] ]++;
            classProbabilities[ getClassLabelIndexValue(trainingData[i].getClassLabel(),classLabels) ][ groupIndex[i] ]++;
        }

        //Convert the class counts into class probabilities for each group
        for(UINT k=0; k<K; k++){
            classProbabilities[k][0] = groupCounter[0]>0 ? classProbabilities[k][0]/groupCounter[0] : 0;
            classProbabilities[k][1] = groupCounter[1]>0 ? classProbabilities[k][1]/groupCounter[1] : 0;
        }

        //Compute the Gini index for each group and weight it by the group size
        giniIndexL = giniIndexR = 0;
        for(UINT k=0; k<K; k++){
            giniIndexL += classProbabilities[k][0] * (1.0-classProbabilities[k][0]);
            giniIndexR += classProbabilities[k][1] * (1.0-classProbabilities[k][1]);
        }
        weightL = groupCounter[0]/M;
        weightR = groupCounter[1]/M;
        error = (giniIndexL*weightL) + (giniIndexR*weightR);

        //Keep this split if it gives the lowest error so far
        if( error < minError ){
            minError = error;
            bestThreshold = threshold;
            bestFeatureIndex = featureIndex;
        }
    }
}

//Store the best split found as this node's model
featureIndex = bestFeatureIndex;
set(M,featureIndex,bestThreshold,trainingData.getClassProbabilities(classLabels));

return true;
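The random variant differs from the iterative variant only in how candidate thresholds are generated: instead of stepping evenly from the feature's minimum to its maximum, it draws numSplittingSteps thresholds uniformly at random from the feature's range. A minimal standalone sketch of the two strategies, with std::uniform_real_distribution standing in for GRT's Random class (names and structure are illustrative assumptions):

#include <random>
#include <vector>

//Evenly spaced candidate thresholds across [minValue, maxValue]
std::vector<double> iterativeThresholds( double minValue, double maxValue, unsigned int steps ){
    std::vector<double> t;
    const double step = (maxValue - minValue) / double(steps);
    for(double v = minValue; v <= maxValue; v += step) t.push_back( v );
    return t;
}

//Uniformly random candidate thresholds within [minValue, maxValue]
std::vector<double> randomThresholds( double minValue, double maxValue, unsigned int steps ){
    std::mt19937 rng( std::random_device{}() );
    std::uniform_real_distribution<double> dist( minValue, maxValue );
    std::vector<double> t;
    for(unsigned int i=0; i<steps; i++) t.push_back( dist(rng) );
    return t;
}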
if( !file.is_open() ){
    errorLog << "saveParametersToFile(fstream &file) - File is not open!" << endl;
    return false;
}
//Save the base DecisionTreeNode parameters
if( !DecisionTreeNode::saveParametersToFile( file ) ){
    errorLog << "saveParametersToFile(fstream &file) - Failed to save DecisionTreeNode parameters to file!" << endl;
    return false;
}
//Save the custom DecisionTreeThresholdNode parameters
file << "FeatureIndex: " << featureIndex << endl;
file << "Threshold: " << threshold << endl;
return true;
if( !file.is_open() ){
    errorLog << "loadParametersFromFile(fstream &file) - File is not open!" << endl;
    return false;
}
//Load the base DecisionTreeNode parameters
if( !DecisionTreeNode::loadParametersFromFile( file ) ){
    errorLog << "loadParametersFromFile(fstream &file) - Failed to load DecisionTreeNode parameters from file!" << endl;
    return false;
}
//Load the custom DecisionTreeThresholdNode parameters
string word;
file >> word;
if( word != "FeatureIndex:" ){
    errorLog << "loadParametersFromFile(fstream &file) - Failed to find FeatureIndex header!" << endl;
    return false;
}
file >> featureIndex;
file >> word;
if( word != "Threshold:" ){
    errorLog << "loadParametersFromFile(fstream &file) - Failed to find Threshold header!" << endl;
    return false;
}
file >> threshold;
return true;
This file implements a DecisionTreeThresholdNode, which is a specific type of node used for a DecisionTree.

DecisionTreeThresholdNode()
virtual ~DecisionTreeThresholdNode()
virtual bool predict(const VectorDouble &x)
virtual bool getModel(ostream &stream) const
virtual Node * deepCopyNode() const
DecisionTreeThresholdNode * deepCopy() const
UINT getFeatureIndex() const
double getThreshold() const
bool set(const UINT nodeSize, const UINT featureIndex, const double threshold, const VectorDouble &classProbabilities)
virtual bool saveParametersToFile(fstream &file) const
virtual bool loadParametersFromFile(fstream &file)
virtual bool print() const
UINT getNumSamples() const
vector< MinMax > getRanges() const
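Finally, a hedged usage sketch tying the declarations above together: configure a node with set() and route a sample with predict(). The include path, feature index, threshold, and class probabilities are assumptions for illustration, not values from the GRT source:

#include <iostream>
#include <GRT/GRT.h>
using namespace GRT;

int main(){
    DecisionTreeThresholdNode node;

    //Made-up model: 100 samples reached this node, which splits on feature 0 at a threshold of 1.5
    VectorDouble classProbabilities(2);
    classProbabilities[0] = 0.7;
    classProbabilities[1] = 0.3;
    node.set( 100, 0, 1.5, classProbabilities );

    //x[0] = 2.0 >= 1.5, so predict() returns true and the sample would be routed to the right child
    VectorDouble x(3);
    x[0] = 2.0; x[1] = 0.2; x[2] = 0.4;
    bool goRight = node.predict( x );

    //Print the node's model to std::cout
    node.getModel( std::cout );

    return goRight ? 0 : 1;
}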