31 #ifndef GRT_DECISION_TREE_NODE_HEADER
32 #define GRT_DECISION_TREE_NODE_HEADER
34 #include "../../CoreAlgorithms/Tree/Node.h"
35 #include "../../CoreAlgorithms/Tree/Tree.h"
36 #include "../../DataStructures/ClassificationData.h"
/**
 * Predicts class likelihoods for a single input vector.
 *
 * @param x                 input vector (read-only).
 * @param classLikelihoods  output vector, taken by non-const reference so the
 *                          call may write it; presumably filled with one
 *                          likelihood per class (NOTE(review): confirm against
 *                          the definition, which is not visible here).
 * @return status flag. NOTE(review): presumably true on success and false on
 *         failure; confirm against the definition.
 */
65 virtual bool predict(
const VectorDouble &x,VectorDouble &classLikelihoods);
/**
 * Searches for the best split of the training data at this node.
 *
 * Note: "Spilt" is the spelling used throughout this API; it is kept as-is
 * because renaming it would break overriding subclasses and callers.
 *
 * @param trainingMode       selects the split-search strategy. NOTE(review):
 *                           presumably it chooses between
 *                           computeBestSpiltBestIterativeSpilt and
 *                           computeBestSpiltBestRandomSpilt, which take the same
 *                           remaining arguments; confirm.
 * @param numSplittingSteps  number of candidate steps to evaluate during the search.
 * @param trainingData       labelled samples reaching this node.
 * @param features           indices of the features eligible for splitting.
 * @param classLabels        class labels present in the training data.
 * @param featureIndex       [out] receives the chosen feature index (non-const reference).
 * @param minError           [out] receives the error of the chosen split (non-const reference).
 * @return status flag. NOTE(review): presumably true if a split was found.
 */
81 virtual bool computeBestSpilt(
const UINT &trainingMode,
const UINT &numSplittingSteps,
const ClassificationData &trainingData,
const vector< UINT > &features,
const vector< UINT > &classLabels, UINT &featureIndex,
double &minError );
/**
 * Writes a description of this node to an output stream.
 * The method is const, so it does not modify the node.
 *
 * @param stream  destination stream (NOTE(review): the content and format
 *                written are defined elsewhere; confirm).
 * @return status flag. NOTE(review): presumably true on success.
 */
98 virtual bool getModel(ostream &stream)
const;
/**
 * Configures this node as a leaf.
 *
 * @param nodeSize            size value for the node. NOTE(review): presumably
 *                            stored in the `nodeSize` member that
 *                            saveParametersToFile writes; confirm.
 * @param classProbabilities  per-class probabilities for the leaf. NOTE(review):
 *                            presumably copied into the `classProbabilities` member.
 * @return status flag. NOTE(review): presumably true on success.
 */
144 bool setLeafNode(
const UINT nodeSize,
const VectorDouble &classProbabilities );
/**
 * Default (base-class) iterative split search.
 *
 * The base implementation only logs an error, signalling that a derived node
 * type was expected to override it. The rest of the body (its return statement
 * and closing brace) is not part of this excerpt. NOTE(review): presumably it
 * returns false.
 *
 * NOTE(review): the log text says "overwritten" where "overridden" is meant.
 * It is left unchanged here because log output may be matched elsewhere.
 *
 * @param numSplittingSteps  number of candidate steps to evaluate.
 * @param trainingData       labelled samples reaching this node.
 * @param features           indices of the features eligible for splitting.
 * @param classLabels        class labels present in the training data.
 * @param featureIndex       [out] chosen feature index (set by overriding implementations).
 * @param minError           [out] error of the chosen split (set by overriding implementations).
 */
165 virtual bool computeBestSpiltBestIterativeSpilt(
const UINT &numSplittingSteps,
const ClassificationData &trainingData,
const vector< UINT > &features,
const vector< UINT > &classLabels, UINT &featureIndex,
double &minError ){
// Base class has no split logic: report that the subclass did not override this.
167 errorLog <<
"computeBestSpiltBestIterativeSpilt(...) - Base class not overwritten!" << endl;
/**
 * Default (base-class) random split search.
 *
 * The base implementation only logs an error, signalling that a derived node
 * type was expected to override it. The rest of the body (its return statement
 * and closing brace) is not part of this excerpt. NOTE(review): presumably it
 * returns false.
 *
 * NOTE(review): the log text says "overwritten" where "overridden" is meant.
 * It is left unchanged here because log output may be matched elsewhere.
 *
 * @param numSplittingSteps  number of random candidates to evaluate.
 * @param trainingData       labelled samples reaching this node.
 * @param features           indices of the features eligible for splitting.
 * @param classLabels        class labels present in the training data.
 * @param featureIndex       [out] chosen feature index (set by overriding implementations).
 * @param minError           [out] error of the chosen split (set by overriding implementations).
 */
172 virtual bool computeBestSpiltBestRandomSpilt(
const UINT &numSplittingSteps,
const ClassificationData &trainingData,
const vector< UINT > &features,
const vector< UINT > &classLabels, UINT &featureIndex,
double &minError ){
// Base class has no split logic: report that the subclass did not override this.
174 errorLog <<
"computeBestSpiltBestRandomSpilt(...) - Base class not overwritten!" << endl;
// Body excerpt of saveParametersToFile(fstream &file): serialises this node's
// parameters as whitespace-separated "Header: value" text. The header tokens
// written here ("NodeSize:", "NumClasses:", "ClassProbabilities:") are the ones
// loadParametersFromFile checks for. Lines between the visible statements
// (e.g. the early return after the error log) are not part of this excerpt.
//
// Refuse to write to a stream that is not open.
188 if( !file.is_open() )
190 errorLog <<
"saveParametersToFile(fstream &file) - File is not open!" << endl;
// Header lines. NumClasses is derived from the size of classProbabilities,
// so the count and the values written below always agree.
195 file <<
"NodeSize: " << nodeSize << endl;
196 file <<
"NumClasses: " << classProbabilities.size() << endl;
197 file <<
"ClassProbabilities: ";
// Values are tab-separated on the ClassProbabilities line. The size() > 0 guard
// keeps size()-1 from wrapping around, since size() is unsigned. What is written
// after the last value is not visible here.
198 if( classProbabilities.size() > 0 ){
199 for(UINT i=0; i<classProbabilities.size(); i++){
200 file << classProbabilities[i];
201 if( i < classProbabilities.size()-1 ) file <<
"\t";
// Body excerpt of loadParametersFromFile(fstream &file): reads back the
// "NodeSize:", "NumClasses:" and "ClassProbabilities:" sections written by
// saveParametersToFile. The statements that read tokens into `word`,
// `nodeSize` and `numClasses`, and the early returns after each error log,
// fall in lines not shown in this excerpt.
//
// Refuse to read from a stream that is not open.
217 if( !file.is_open() )
219 errorLog <<
"loadParametersFromFile(fstream &file) - File is not open!" << endl;
// Drop any previously loaded probabilities before parsing.
223 classProbabilities.clear();
// Each section starts with a fixed header token; a mismatch is logged as a format error.
230 if( word !=
"NodeSize:" ){
231 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to find NodeSize header!" << endl;
237 if( word !=
"NumClasses:" ){
238 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to find NumClasses header!" << endl;
// NOTE(review): the vector is resized before the ClassProbabilities header is
// validated, so a failure below leaves the node holding numClasses
// default-valued entries rather than its prior (cleared) state.
243 classProbabilities.resize( numClasses );
246 if( word !=
"ClassProbabilities:" ){
247 errorLog <<
"loadParametersFromFile(fstream &file) - Failed to find ClassProbabilities header!" << endl;
// Read one value per class. The numClasses > 0 guard is redundant, because the
// loop body does not run when numClasses is 0.
// NOTE(review): no stream-state check is visible after these extractions; a
// truncated or malformed file would go undetected here.
250 if( numClasses > 0 ){
251 for(UINT i=0; i<numClasses; i++){
252 file >> classProbabilities[i];
/**
 * Finds the position of a class label within a list of class labels, using a
 * linear scan that compares each entry for equality.
 *
 * @param classLabel   label to look up.
 * @param classLabels  list of labels to search.
 * @return NOTE(review): the return statements are not part of this excerpt.
 *         Presumably the matching index i is returned. Confirm what is
 *         returned when classLabel is absent: if that value is itself a valid
 *         index (such as 0), callers cannot tell "not found" from "found at
 *         position 0".
 */
259 UINT getClassLabelIndexValue(UINT classLabel,
const vector< UINT > &classLabels)
const{
// Cache the size as UINT so the loop index and bound share a type.
260 const UINT N = (UINT)classLabels.size();
261 for(UINT i=0; i<N; i++){
262 if( classLabel == classLabels[i] )
// Per-class probabilities held by this node. Written out by
// saveParametersToFile (its size is also emitted as NumClasses) and
// repopulated by loadParametersFromFile.
269 VectorDouble classProbabilities;
// Static registration object for this node type. NOTE(review): presumably
// RegisterNode<T> (declared outside this excerpt) registers DecisionTreeNode
// with a node factory during static initialisation; confirm.
271 static RegisterNode< DecisionTreeNode > registerModule;
276 #endif //GRT_DECISION_TREE_NODE_HEADER
// Summary of DecisionTreeNode member-function signatures (declarations only;
// no bodies are part of this list). "Spilt" in computeBestSpilt is the API's
// established spelling.
virtual bool saveParametersToFile(fstream &file) const
bool setLeafNode(const UINT nodeSize, const VectorDouble &classProbabilities)
DecisionTreeNode * deepCopy() const
virtual ~DecisionTreeNode()
virtual Node * deepCopyNode() const
bool setNodeSize(const UINT nodeSize)
virtual bool predict(const VectorDouble &x)
virtual bool predict(const VectorDouble &x, VectorDouble &classLikelihoods)
virtual bool getModel(ostream &stream) const
virtual bool loadParametersFromFile(fstream &file)
VectorDouble getClassProbabilities() const
UINT getNumClasses() const
bool setClassProbabilities(const VectorDouble &classProbabilities)
virtual bool computeBestSpilt(const UINT &trainingMode, const UINT &numSplittingSteps, const ClassificationData &trainingData, const vector< UINT > &features, const vector< UINT > &classLabels, UINT &featureIndex, double &minError)