28 baseType = BASE_TYPE_NOT_SET;
29 numInputDimensions = 0;
30 numOutputDimensions = 0;
33 validationSetSize = 20;
36 useValidationSet =
false;
37 randomiseTrainingOrder =
true;
38 rootMeanSquaredTrainingError = 0;
39 totalSquaredTrainingError = 0;
49 errorLog <<
"copyMLBaseVariables(MLBase *mlBase) - mlBase pointer is NULL!" << endl;
54 errorLog <<
"copyMLBaseVariables(MLBase *mlBase) - Failed to copy GRT Base variables!" << endl;
58 this->trained = mlBase->trained;
59 this->useScaling = mlBase->useScaling;
60 this->baseType = mlBase->baseType;
61 this->numInputDimensions = mlBase->numInputDimensions;
62 this->numOutputDimensions = mlBase->numOutputDimensions;
63 this->minNumEpochs = mlBase->minNumEpochs;
64 this->maxNumEpochs = mlBase->maxNumEpochs;
65 this->validationSetSize = mlBase->validationSetSize;
66 this->minChange = mlBase->minChange;
67 this->learningRate = mlBase->learningRate;
68 this->rootMeanSquaredTrainingError = mlBase->rootMeanSquaredTrainingError;
69 this->totalSquaredTrainingError = mlBase->totalSquaredTrainingError;
70 this->useValidationSet = mlBase->useValidationSet;
71 this->randomiseTrainingOrder = mlBase->randomiseTrainingOrder;
72 this->numTrainingIterationsToConverge = mlBase->numTrainingIterationsToConverge;
73 this->trainingResults = mlBase->trainingResults;
74 this->trainingResultsObserverManager = mlBase->trainingResultsObserverManager;
75 this->testResultsObserverManager = mlBase->testResultsObserverManager;
120 numInputDimensions = 0;
121 numOutputDimensions = 0;
122 numTrainingIterationsToConverge = 0;
123 rootMeanSquaredTrainingError = 0;
124 totalSquaredTrainingError = 0;
125 trainingResults.clear();
137 if( !trained )
return false;
140 file.open(filename.c_str(), std::ios::out);
160 file.open(filename.c_str(), std::ios::in);
177 std::ostringstream stream;
194 return numTrainingIterationsToConverge;
208 return validationSetSize;
216 return rootMeanSquaredTrainingError;
220 return totalSquaredTrainingError;
238 if( maxNumEpochs == 0 ){
239 warningLog <<
"setMaxNumEpochs(const UINT maxNumEpochs) - The maxNumEpochs must be greater than 0!" << endl;
242 this->maxNumEpochs = maxNumEpochs;
247 this->minNumEpochs = minNumEpochs;
253 warningLog <<
"setMinChange(const double minChange) - The minChange must be greater than or equal to 0!" << endl;
256 this->minChange = minChange;
261 if( learningRate > 0 ){
262 this->learningRate = learningRate;
270 if( validationSetSize > 0 && validationSetSize < 100 ){
271 this->validationSetSize = validationSetSize;
275 warningLog <<
"setValidationSetSize(const UINT validationSetSize) - The validation size must be in the range [1 99]!" << endl;
281 this->useValidationSet = useValidationSet;
286 this->randomiseTrainingOrder = randomiseTrainingOrder;
291 return trainingResultsObserverManager.registerObserver( observer );
295 return testResultsObserverManager.registerObserver( observer );
299 return trainingResultsObserverManager.removeObserver( observer );
303 return testResultsObserverManager.removeObserver( observer );
307 return trainingResultsObserverManager.removeAllObservers();
311 return testResultsObserverManager.removeAllObservers();
315 return trainingResultsObserverManager.notifyObservers( data );
319 return testResultsObserverManager.notifyObservers( data );
331 return trainingResults;
336 if( !file.is_open() ){
337 errorLog <<
"saveBaseSettingsToFile(fstream &file) - The file is not open!" << endl;
341 file <<
"Trained: " << trained << endl;
342 file <<
"UseScaling: " << useScaling << endl;
343 file <<
"NumInputDimensions: " << numInputDimensions << endl;
344 file <<
"NumOutputDimensions: " << numOutputDimensions << endl;
345 file <<
"NumTrainingIterationsToConverge: " << numTrainingIterationsToConverge << endl;
346 file <<
"MinNumEpochs: " << minNumEpochs << endl;
347 file <<
"MaxNumEpochs: " << maxNumEpochs << endl;
348 file <<
"ValidationSetSize: " << validationSetSize << endl;
349 file <<
"LearningRate: " << learningRate << endl;
350 file <<
"MinChange: " << minChange << endl;
351 file <<
"UseValidationSet: " << useValidationSet << endl;
352 file <<
"RandomiseTrainingOrder: " << randomiseTrainingOrder << endl;
362 if( !file.is_open() ){
363 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - The file is not open!" << endl;
371 if( word !=
"Trained:" ){
372 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read Trained header!" << endl;
379 if( word !=
"UseScaling:" ){
380 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read UseScaling header!" << endl;
387 if( word !=
"NumInputDimensions:" ){
388 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumInputDimensions header!" << endl;
391 file >> numInputDimensions;
395 if( word !=
"NumOutputDimensions:" ){
396 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumOutputDimensions header!" << endl;
399 file >> numOutputDimensions;
403 if( word !=
"NumTrainingIterationsToConverge:" ){
404 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read NumTrainingIterationsToConverge header!" << endl;
407 file >> numTrainingIterationsToConverge;
411 if( word !=
"MinNumEpochs:" ){
412 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MinNumEpochs header!" << endl;
415 file >> minNumEpochs;
419 if( word !=
"MaxNumEpochs:" ){
420 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MaxNumEpochs header!" << endl;
423 file >> maxNumEpochs;
427 if( word !=
"ValidationSetSize:" ){
428 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read ValidationSetSize header!" << endl;
431 file >> validationSetSize;
435 if( word !=
"LearningRate:" ){
436 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read LearningRate header!" << endl;
439 file >> learningRate;
443 if( word !=
"MinChange:" ){
444 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read MinChange header!" << endl;
451 if( word !=
"UseValidationSet:" ){
452 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read UseValidationSet header!" << endl;
455 file >> useValidationSet;
459 if( word !=
"RandomiseTrainingOrder:" ){
460 errorLog <<
"loadBaseSettingsFromFile(fstream &file) - Failed to read RandomiseTrainingOrder header!" << endl;
463 file >> randomiseTrainingOrder;
vector< TrainingResult > getTrainingResults() const
virtual bool saveModelToFile(string filename) const
UINT getMaxNumEpochs() const
virtual bool getModel(ostream &stream) const
double getRootMeanSquaredTrainingError() const
virtual bool print() const
bool setMaxNumEpochs(const UINT maxNumEpochs)
virtual bool loadModelFromFile(string filename)
bool loadBaseSettingsFromFile(fstream &file)
bool saveBaseSettingsToFile(fstream &file) const
bool getModelTrained() const
bool copyMLBaseVariables(const MLBase *mlBase)
virtual bool train(ClassificationData trainingData)
UINT getNumInputFeatures() const
bool setMinChange(const double minChange)
UINT getNumInputDimensions() const
bool enableScaling(bool useScaling)
bool setValidationSetSize(const UINT validationSetSize)
double getLearningRate() const
virtual bool map_(VectorDouble &inputVector)
double getTotalSquaredTrainingError() const
bool setRandomiseTrainingOrder(const bool randomiseTrainingOrder)
bool registerTrainingResultsObserver(Observer< TrainingResult > &observer)
This is the main base class that all GRT machine learning algorithms should inherit from...
bool notifyTrainingResultsObservers(const TrainingResult &data)
virtual bool predict(VectorDouble inputVector)
UINT getMinNumEpochs() const
virtual bool predict_(VectorDouble &inputVector)
bool copyGRTBaseVariables(const GRTBase *GRTBase)
bool setUseValidationSet(const bool useValidationSet)
bool getIsBaseTypeClusterer() const
MLBase * getMLBasePointer()
bool removeAllTestObservers()
UINT getValidationSetSize() const
virtual bool map(VectorDouble inputVector)
bool registerTestResultsObserver(Observer< TestInstanceResult > &observer)
bool getIsBaseTypeRegressifier() const
virtual bool load(const string filename)
UINT getNumOutputDimensions() const
bool removeTrainingResultsObserver(const Observer< TrainingResult > &observer)
UINT getNumTrainingIterationsToConverge() const
bool removeAllTrainingObservers()
bool setLearningRate(double learningRate)
bool getIsBaseTypeClassifier() const
virtual string getModelAsString() const
bool notifyTestResultsObservers(const TestInstanceResult &data)
bool setMinNumEpochs(const UINT minNumEpochs)
virtual bool train_(ClassificationData &trainingData)
bool getScalingEnabled() const
bool removeTestResultsObserver(const Observer< TestInstanceResult > &observer)
virtual bool save(const string filename) const