26 RegisterNode< RegressionTreeNode > RegressionTreeNode::registerModule(
"RegressionTreeNode");
29 RegisterRegressifierModule< RegressionTree > RegressionTree::registerModule(
"RegressionTree");
31 RegressionTree::RegressionTree(
const UINT numSplittingSteps,
const UINT minNumSamplesPerNode,
const UINT maxDepth,
const bool removeFeaturesAtEachSpilt,
const UINT trainingMode,
const bool useScaling,
const double minRMSErrorPerNode)
34 this->numSplittingSteps = numSplittingSteps;
35 this->minNumSamplesPerNode = minNumSamplesPerNode;
36 this->maxDepth = maxDepth;
37 this->removeFeaturesAtEachSpilt = removeFeaturesAtEachSpilt;
38 this->trainingMode = trainingMode;
39 this->useScaling = useScaling;
41 Regressifier::classType =
"RegressionTree";
42 regressifierType = Regressifier::classType;
43 Regressifier::debugLog.setProceedingText(
"[DEBUG RegressionTree]");
44 Regressifier::errorLog.setProceedingText(
"[ERROR RegressionTree]");
45 Regressifier::trainingLog.setProceedingText(
"[TRAINING RegressionTree]");
46 Regressifier::warningLog.setProceedingText(
"[WARNING RegressionTree]");
52 Regressifier::classType =
"RegressionTree";
53 regressifierType = Regressifier::classType;
54 Regressifier::debugLog.setProceedingText(
"[DEBUG RegressionTree]");
55 Regressifier::errorLog.setProceedingText(
"[ERROR RegressionTree]");
56 Regressifier::trainingLog.setProceedingText(
"[TRAINING RegressionTree]");
57 Regressifier::warningLog.setProceedingText(
"[WARNING RegressionTree]");
76 this->numSplittingSteps = rhs.numSplittingSteps;
77 this->minNumSamplesPerNode = rhs.minNumSamplesPerNode;
78 this->maxDepth = rhs.maxDepth;
79 this->removeFeaturesAtEachSpilt = rhs.removeFeaturesAtEachSpilt;
80 this->trainingMode = rhs.trainingMode;
91 if( regressifier == NULL )
return false;
105 this->numSplittingSteps = ptr->numSplittingSteps;
106 this->minNumSamplesPerNode = ptr->minNumSamplesPerNode;
107 this->maxDepth = ptr->maxDepth;
108 this->removeFeaturesAtEachSpilt = ptr->removeFeaturesAtEachSpilt;
109 this->trainingMode = ptr->trainingMode;
128 Regressifier::errorLog <<
"train_(RegressionData &trainingData) - Training data has zero samples!" << endl;
132 numInputDimensions = N;
133 numOutputDimensions = T;
140 trainingData.
scale(0, 1);
144 vector< UINT > features(N);
145 for(UINT i=0; i<N; i++){
151 tree = buildTree( trainingData, NULL, features, nodeID );
155 Regressifier::errorLog <<
"train_(RegressionData &trainingData) - Failed to build tree!" << endl;
168 Regressifier::errorLog <<
"predict_(VectorDouble &inputVector) - Model Not Trained!" << endl;
173 Regressifier::errorLog <<
"predict_(VectorDouble &inputVector) - Tree pointer is null!" << endl;
177 if( inputVector.size() != numInputDimensions ){
178 Regressifier::errorLog <<
"predict_(VectorDouble &inputVector) - The size of the input vector (" << inputVector.size() <<
") does not match the num features in the model (" << numInputDimensions << endl;
183 for(UINT n=0; n<numInputDimensions; n++){
184 inputVector[n] =
scale(inputVector[n], inputVectorRanges[n].minValue, inputVectorRanges[n].maxValue, 0, 1);
188 if( !tree->
predict( inputVector, regressionData ) ){
189 Regressifier::errorLog <<
"predict_(VectorDouble &inputVector) - Failed to predict!" << endl;
212 return tree->
print();
220 Regressifier::errorLog <<
"saveModelToFile(fstream &file) - The file is not open!" << endl;
225 file <<
"GRT_REGRESSION_TREE_MODEL_FILE_V1.0\n";
229 Regressifier::errorLog <<
"saveModelToFile(fstream &file) - Failed to save classifier base settings to file!" << endl;
233 file <<
"NumSplittingSteps: " << numSplittingSteps << endl;
234 file <<
"MinNumSamplesPerNode: " << minNumSamplesPerNode << endl;
235 file <<
"MaxDepth: " << maxDepth << endl;
236 file <<
"RemoveFeaturesAtEachSpilt: " << removeFeaturesAtEachSpilt << endl;
237 file <<
"TrainingMode: " << trainingMode << endl;
238 file <<
"TreeBuilt: " << (tree != NULL ? 1 : 0) << endl;
243 Regressifier::errorLog <<
"saveModelToFile(fstream &file) - Failed to save tree to file!" << endl;
257 Regressifier::errorLog <<
"loadModelFromFile(string filename) - Could not open file to load model" << endl;
265 if(word !=
"GRT_REGRESSION_TREE_MODEL_FILE_V1.0"){
266 Regressifier::errorLog <<
"loadModelFromFile(string filename) - Could not find Model File Header" << endl;
272 Regressifier::errorLog <<
"loadModelFromFile(string filename) - Failed to load base settings from file!" << endl;
277 if(word !=
"NumSplittingSteps:"){
278 Regressifier::errorLog <<
"loadModelFromFile(string filename) - Could not find the NumSplittingSteps!" << endl;
281 file >> numSplittingSteps;
284 if(word !=
"MinNumSamplesPerNode:"){
285 Regressifier::errorLog <<
"loadModelFromFile(string filename) - Could not find the MinNumSamplesPerNode!" << endl;
288 file >> minNumSamplesPerNode;
291 if(word !=
"MaxDepth:"){
292 Regressifier::errorLog <<
"loadModelFromFile(string filename) - Could not find the MaxDepth!" << endl;
298 if(word !=
"RemoveFeaturesAtEachSpilt:"){
299 Regressifier::errorLog <<
"loadModelFromFile(string filename) - Could not find the RemoveFeaturesAtEachSpilt!" << endl;
302 file >> removeFeaturesAtEachSpilt;
305 if(word !=
"TrainingMode:"){
306 Regressifier::errorLog <<
"loadModelFromFile(string filename) - Could not find the TrainingMode!" << endl;
309 file >> trainingMode;
312 if(word !=
"TreeBuilt:"){
313 Regressifier::errorLog <<
"loadModelFromFile(string filename) - Could not find the TreeBuilt!" << endl;
321 Regressifier::errorLog <<
"loadModelFromFile(string filename) - Could not find the Tree!" << endl;
330 Regressifier::errorLog <<
"loadModelFromFile(fstream &file) - Failed to create new RegressionTreeNode!" << endl;
334 tree->setParent( NULL );
337 Regressifier::errorLog <<
"loadModelFromFile(fstream &file) - Failed to load tree from file!" << endl;
372 VectorDouble regressionData(T);
393 node->initNode( parent, depth, nodeID );
396 if( features.size() == 0 || M < minNumSamplesPerNode || depth >= maxDepth ){
399 node->setIsLeafNode(
true );
402 computeNodeRegressionData( trainingData, regressionData );
407 Regressifier::trainingLog <<
"Reached leaf node. Depth: " << depth <<
" NumSamples: " << trainingData.
getNumSamples() << endl;
413 UINT featureIndex = 0;
414 double threshold = 0;
416 if( !computeBestSpilt( trainingData, features, featureIndex, threshold, minError ) ){
421 Regressifier::trainingLog <<
"Depth: " << depth <<
" FeatureIndex: " << featureIndex <<
" Threshold: " << threshold <<
" MinError: " << minError << endl;
426 computeNodeRegressionData( trainingData, regressionData );
429 node->
set( trainingData.
getNumSamples(), featureIndex, threshold, regressionData );
431 Regressifier::trainingLog <<
"Reached leaf node. Depth: " << depth <<
" NumSamples: " << M << endl;
437 node->
set( trainingData.
getNumSamples(), featureIndex, threshold, regressionData );
440 if( removeFeaturesAtEachSpilt ){
441 for(
size_t i=0; i<features.size(); i++){
442 if( features[i] == featureIndex ){
443 features.erase( features.begin()+i );
450 RegressionData lhs(N,T);
451 RegressionData rhs(N,T);
453 for(UINT i=0; i<M; i++){
454 if( node->
predict( trainingData[i].getInputVector() ) ){
455 rhs.addSample(trainingData[i].getInputVector(), trainingData[i].getTargetVector());
456 }
else lhs.addSample(trainingData[i].getInputVector(), trainingData[i].getTargetVector());
460 node->setLeftChild( buildTree( lhs, node, features, nodeID ) );
461 node->setRightChild( buildTree( rhs, node, features, nodeID ) );
466 bool RegressionTree::computeBestSpilt(
const RegressionData &trainingData,
const vector< UINT > &features, UINT &featureIndex,
double &threshold,
double &minError ){
468 switch( trainingMode ){
469 case BEST_ITERATIVE_SPILT:
470 return computeBestSpiltBestIterativeSpilt( trainingData, features, featureIndex, threshold, minError );
472 case BEST_RANDOM_SPLIT:
476 Regressifier::errorLog <<
"Uknown trainingMode!" << endl;
484 bool RegressionTree::computeBestSpiltBestIterativeSpilt(
const RegressionData &trainingData,
const vector< UINT > &features, UINT &featureIndex,
double &threshold,
double &minError ){
486 const UINT M = trainingData.getNumSamples();
487 const UINT N = (UINT)features.size();
489 if( N == 0 )
return false;
491 minError = numeric_limits<double>::max();
492 UINT bestFeatureIndex = 0;
494 double bestThreshold = 0;
499 vector< UINT > groupIndex(M);
500 VectorDouble groupCounter(2,0);
501 VectorDouble groupMean(2,0);
502 VectorDouble groupMSE(2,0);
503 vector< MinMax > ranges = trainingData.getInputRanges();
506 for(UINT n=0; n<N; n++){
507 minRange = ranges[n].minValue;
508 maxRange = ranges[n].maxValue;
509 step = (maxRange-minRange)/
double(numSplittingSteps);
510 threshold = minRange;
511 featureIndex = features[n];
512 while( threshold <= maxRange ){
515 for(UINT i=0; i<M; i++){
516 groupID = trainingData[i].getInputVector()[featureIndex] >= threshold ? 1 : 0;
517 groupIndex[i] = groupID;
518 groupMean[ groupID ] += trainingData[i].getInputVector()[featureIndex];
519 groupCounter[ groupID ]++;
521 groupMean[0] /= groupCounter[0] > 0 ? groupCounter[0] : 1;
522 groupMean[1] /= groupCounter[1] > 0 ? groupCounter[1] : 1;
525 for(UINT i=0; i<M; i++){
526 groupMSE[ groupIndex[i] ] += Regressifier::SQR( groupMean[ groupIndex[i] ] - trainingData[ i ].getInputVector()[features[n]] );
528 groupMSE[0] /= groupCounter[0] > 0 ? groupCounter[0] : 1;
529 groupMSE[1] /= groupCounter[1] > 0 ? groupCounter[1] : 1;
531 error = sqrt( groupMSE[0] + groupMSE[1] );
534 if( error < minError ){
536 bestThreshold = threshold;
537 bestFeatureIndex = featureIndex;
546 featureIndex = bestFeatureIndex;
547 threshold = bestThreshold;
625 bool RegressionTree::computeNodeRegressionData(
const RegressionData &trainingData, VectorDouble ®ressionData ){
627 const UINT M = trainingData.getNumSamples();
628 const UINT N = trainingData.getNumInputDimensions();
629 const UINT T = trainingData.getNumTargetDimensions();
632 Regressifier::errorLog <<
"computeNodeRegressionData(...) - Failed to compute regression data, there are zero training samples!" << endl;
637 regressionData.clear();
638 regressionData.resize( T, 0 );
641 for(
unsigned int j=0; j<N; j++){
642 for(
unsigned int i=0; i<M; i++){
643 regressionData[j] += trainingData[i].getTargetVector()[j];
645 regressionData[j] /= M;
virtual Node * deepCopyNode() const
bool copyBaseVariables(const Regressifier *regressifier)
This class implements a basic Regression Tree.
virtual bool loadModelFromFile(fstream &file)
RegressionTree & operator=(const RegressionTree &rhs)
double minRMSErrorPerNode
virtual bool loadFromFile(fstream &file)
bool set(const UINT nodeSize, const UINT featureIndex, const double threshold, const VectorDouble ®ressionData)
double getMinRMSErrorPerNode() const
RegressionTree(const UINT numSplittingSteps=100, const UINT minNumSamplesPerNode=5, const UINT maxDepth=10, const bool removeFeaturesAtEachSpilt=false, const UINT trainingMode=BEST_ITERATIVE_SPILT, const bool useScaling=false, const double minRMSErrorPerNode=0.01)
virtual ~RegressionTree(void)
virtual bool saveToFile(fstream &file) const
bool loadBaseSettingsFromFile(fstream &file)
vector< MinMax > getInputRanges() const
bool saveBaseSettingsToFile(fstream &file) const
virtual bool predict(const VectorDouble &x)
const RegressionTreeNode * getTree() const
virtual bool saveModelToFile(fstream &file) const
virtual bool train_(RegressionData &trainingData)
virtual bool predict_(VectorDouble &inputVector)
UINT getNumSamples() const
double scale(const double &x, const double &minSource, const double &maxSource, const double &minTarget, const double &maxTarget, const bool constrain=false)
virtual bool deepCopyFrom(const Regressifier *regressifier)
string getRegressifierType() const
bool setMinRMSErrorPerNode(const double minRMSErrorPerNode)
virtual bool print() const
vector< MinMax > getTargetRanges() const
virtual bool predict(const VectorDouble &x)
RegressionTreeNode * deepCopyTree() const
UINT getNumTargetDimensions() const
UINT getNumInputDimensions() const
bool scale(const double minTarget, const double maxTarget)
virtual bool print() const