// Definition of the ANBC::registerModule member (type RegisterClassifierModule<ANBC>),
// constructed with the string "ANBC".
28 RegisterClassifierModule< ANBC > ANBC::registerModule(
"ANBC");
// Main constructor.
// NOTE(review): this listing omits some original lines, so the braces that
// delimit the body are not visible here.
// Visible behaviour: stores the three arguments on the instance, marks null
// rejection as supported, clears the weightsDataSet flag, sets the classifier
// type and mode, and gives each log stream an ANBC-specific prefix.
30 ANBC::ANBC(
bool useScaling,
bool useNullRejection,
double nullRejectionCoeff)
// Copy the arguments into the members of the same name (hence the this->).
32 this->useScaling = useScaling;
33 this->useNullRejection = useNullRejection;
34 this->nullRejectionCoeff = nullRejectionCoeff;
35 supportsNullRejection =
true;
// No per-class weights have been supplied yet.
36 weightsDataSet =
false;
// classType is declared outside this listing.
38 classifierType = classType;
39 classifierMode = STANDARD_CLASSIFIER_MODE;
// Prefix every message emitted through the four log streams.
40 debugLog.setProceedingText(
"[DEBUG ANBC]");
41 errorLog.setProceedingText(
"[ERROR ANBC]");
42 trainingLog.setProceedingText(
"[TRAINING ANBC]");
43 warningLog.setProceedingText(
"[WARNING ANBC]");
// Copy constructor.
// Visible behaviour: sets the classifier type/mode and log prefixes exactly as
// the main constructor does, then copies weightsDataSet, weightsData and models
// from rhs.
// NOTE(review): original lines 54-63 are not part of this listing, so any
// further copying done there (e.g. of base-class state) cannot be confirmed here.
46 ANBC::ANBC(
const ANBC &rhs){
48 classifierType = classType;
49 classifierMode = STANDARD_CLASSIFIER_MODE;
50 debugLog.setProceedingText(
"[DEBUG ANBC]");
51 errorLog.setProceedingText(
"[ERROR ANBC]");
52 trainingLog.setProceedingText(
"[TRAINING ANBC]");
53 warningLog.setProceedingText(
"[WARNING ANBC]");
// ANBC-specific state taken from the source object.
64 this->weightsDataSet = rhs.weightsDataSet;
65 this->weightsData = rhs.weightsData;
66 this->models = rhs.models;
// Fragment of a copy-from-pointer routine; its signature and the declaration
// of `ptr` are on lines not shown in this listing.
// A NULL classifier pointer is rejected by returning false.
76 if( classifier == NULL )
return false;
// Copy the same three ANBC-specific members the copy constructor copies.
// NOTE(review): presumably `ptr` is `classifier` cast to an ANBC pointer after
// a type check - confirm against the omitted lines.
82 this->weightsDataSet = ptr->weightsDataSet;
83 this->weightsData = ptr->weightsData;
84 this->models = ptr->models;
// The result of copying the base-class variables is the overall return value.
87 return copyBaseVariables( classifier );
// Fragment of the training routine (its own log messages name it
// train_(ClassificationData &labelledTrainingData)). The signature, the
// definitions of N, K, data and classData, and every return statement are on
// lines not shown in this listing.
// Reported when the training data contains no samples.
102 errorLog <<
"train_(ClassificationData &labelledTrainingData) - Training data has zero samples!" << endl;
// If weights were supplied earlier, their dimensionality must equal N.
106 if( weightsDataSet ){
107 if( weightsData.getNumDimensions() != N ){
108 errorLog <<
"train_(ClassificationData &labelledTrainingData) - The number of dimensions in the weights data (" << weightsData.getNumDimensions() <<
") is not equal to the number of dimensions of the training data (" << N <<
")" << endl;
113 numInputDimensions = N;
116 classLabels.resize(K);
// Remember the ranges of the training data; predict_ uses them to scale inputs.
117 ranges = labelledTrainingData.
getRanges();
// Rescale the training data to [0, 1].
// NOTE(review): the condition guarding this call is not visible; presumably it
// is useScaling - confirm.
122 labelledTrainingData.
scale(0, 1);
// Train one model per class.
126 for(UINT k=0; k<numClasses; k++){
129 UINT classLabel = labelledTrainingData.
getClassTracker()[k].classLabel;
132 classLabels[k] = classLabel;
// Per-dimension weights for this class: taken from weightsData when a sample
// with a matching class label exists.
135 VectorDouble weights(numInputDimensions);
136 if( weightsDataSet ){
137 bool weightsFound =
false;
138 for(UINT i=0; i<weightsData.getNumSamples(); i++){
139 if( weightsData[i].getClassLabel() == classLabel ){
140 weights = weightsData[i].getSample();
147 errorLog <<
"train_(ClassificationData &labelledTrainingData) - Failed to find the weights for class " << classLabel << endl;
// Every weight set to 1.0.
// NOTE(review): presumably this is the branch taken when no weights were
// supplied; the enclosing else is not visible here.
152 for(UINT j=0; j<numInputDimensions; j++) weights[j] = 1.0;
// Copy this class's samples element by element into `data`.
160 for(UINT i=0; i<data.getNumRows(); i++){
161 for(UINT j=0; j<data.getNumCols(); j++){
162 data[i][j] = classData[i][j];
// gamma is set from nullRejectionCoeff before the model is trained.
167 models[k].gamma = nullRejectionCoeff;
168 if( !models[k].train(classLabel,data,weights) ){
169 errorLog <<
"train_(ClassificationData &labelledTrainingData) - Failed to train model for class: " << classLabel << endl;
// Post-training sanity checks: the model must report a non-zero N ...
172 if( models[k].N == 0 ){
173 errorLog <<
"train_(ClassificationData &labelledTrainingData) - N == 0!" << endl;
// ... and no dimension may have a standard deviation of exactly zero.
177 for(UINT j=0; j<numInputDimensions; j++){
178 if( models[k].sigma[j] == 0 ){
179 errorLog <<
"train_(ClassificationData &labelledTrainingData) - The standard deviation of column " << j+1 <<
" is zero! Check the training data" << endl;
// Mirror each model's threshold into nullRejectionThresholds.
191 nullRejectionThresholds.resize(numClasses);
192 for(UINT k=0; k<numClasses; k++) {
193 nullRejectionThresholds[k] = models[k].threshold;
// Classifies inputVector: each per-class model is evaluated, the class with the
// largest model output is selected, and predictedClassLabel, maxLikelihood,
// classLikelihoods and classDistances are updated.
// inputVector is taken by non-const reference and is overwritten by the
// scaling loop below.
// NOTE(review): several original lines (conditions, braces, returns) are not
// part of this listing.
201 bool ANBC::predict_(VectorDouble &inputVector){
204 errorLog <<
"predict_(VectorDouble &inputVector) - ANBC Model Not Trained!" << endl;
// Reset the outputs before predicting.
208 predictedClassLabel = 0;
209 maxLikelihood = -10000;
211 if( !trained )
return false;
// The input must have exactly as many elements as the model has dimensions.
// NOTE(review): the message below opens a "(" before numInputDimensions that
// is never closed.
213 if( inputVector.size() != numInputDimensions ){
214 errorLog <<
"predict_(VectorDouble &inputVector) - The size of the input vector (" << inputVector.size() <<
") does not match the num features in the model (" << numInputDimensions << endl;
// Scale each element in place from the stored training range to
// [MIN_SCALE_VALUE, MAX_SCALE_VALUE].
// NOTE(review): the guard for this loop is not visible; presumably useScaling.
219 for(UINT n=0; n<numInputDimensions; n++){
220 inputVector[n] = scale(inputVector[n], ranges[n].minValue, ranges[n].maxValue, MIN_SCALE_VALUE, MAX_SCALE_VALUE);
224 if( classLikelihoods.size() != numClasses ) classLikelihoods.resize(numClasses,0);
225 if( classDistances.size() != numClasses ) classDistances.resize(numClasses,0);
227 double classLikelihoodsSum = 0;
// Despite its name, minDist tracks the LARGEST classDistances value seen so
// far; it starts at a very large negative number.
228 double minDist = -99e+99;
229 for(UINT k=0; k<numClasses; k++){
230 classDistances[k] = models[k].predict( inputVector );
233 classLikelihoods[k] = classDistances[k];
// An inf/NaN model output contributes a likelihood of 0; otherwise the output
// is exponentiated and accumulated for normalisation.
// NOTE(review): the else that separates these two paths is not visible here.
236 if( grt_isinf(classLikelihoods[k]) || grt_isnan(classLikelihoods[k]) ){
237 classLikelihoods[k] = 0;
239 classLikelihoods[k] = exp( classLikelihoods[k] );
240 classLikelihoodsSum += classLikelihoods[k];
// predictedClassLabel temporarily holds the winning class INDEX k; it is
// converted to a class label further down.
243 if( classDistances[k] > minDist ){
244 minDist = classDistances[k];
245 predictedClassLabel = k;
// A zero sum means no class produced a usable likelihood: report the null class.
251 if( classLikelihoodsSum == 0 ){
252 predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
// Normalise the likelihoods by their sum.
258 for(UINT k=0; k<numClasses; k++){
259 classLikelihoods[k] /= classLikelihoodsSum;
261 maxLikelihood = classLikelihoods[predictedClassLabel];
// With null rejection enabled the winning value must reach that model's
// threshold, otherwise the null label is reported. Either way the winning
// index is replaced by the model's class label.
263 if( useNullRejection ){
265 if( minDist >= models[predictedClassLabel].threshold ) predictedClassLabel = models[predictedClassLabel].classLabel;
266 else predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
267 }
else predictedClassLabel = models[predictedClassLabel].classLabel;
// Recomputes every model's threshold from the current nullRejectionCoeff and
// mirrors the results into nullRejectionThresholds.
// NOTE(review): any precondition check and the return statements are on lines
// not shown in this listing.
272 bool ANBC::recomputeNullRejectionThresholds(){
// Make sure there is one threshold slot per class.
275 if( nullRejectionThresholds.size() != numClasses )
276 nullRejectionThresholds.resize(numClasses);
277 for(UINT k=0; k<numClasses; k++) {
278 models[k].recomputeThresholdValue(nullRejectionCoeff);
279 nullRejectionThresholds[k] = models[k].threshold;
// Writes the model to an already-open stream in the V2.0 text format: a header
// line, the classifier base settings, then one section per class model.
// NOTE(review): the is-open check, the vector labels and the return statements
// are on lines not shown in this listing.
302 bool ANBC::saveModelToFile(fstream &file)
const{
306 errorLog <<
"saveModelToFile(fstream &file) - The file is not open!" << endl;
// File header; loadModelFromFile compares against this exact string.
311 file<<
"GRT_ANBC_MODEL_FILE_V2.0\n";
// The base-class settings are written by Classifier.
314 if( !Classifier::saveBaseSettingsToFile(file) ){
315 errorLog <<
"saveModelToFile(fstream &file) - Failed to save classifier base settings to file!" << endl;
// One section per class model. Model_ID is written 1-based (k+1).
321 for(UINT k=0; k<numClasses; k++){
322 file<<
"*************_MODEL_*************\n";
323 file<<
"Model_ID: "<<k+1<<endl;
324 file<<
"N: "<<models[k].N<<endl;
325 file<<
"ClassLabel: "<<models[k].classLabel<<endl;
326 file<<
"Threshold: "<<models[k].threshold<<endl;
327 file<<
"Gamma: "<<models[k].gamma<<endl;
328 file<<
"TrainingMu: "<<models[k].trainingMu<<endl;
329 file<<
"TrainingSigma: "<<models[k].trainingSigma<<endl;
// mu, sigma and weights are each written as models[k].N tab-prefixed values.
332 for(UINT j=0; j<models[k].N; j++){
333 file <<
"\t" << models[k].mu[j];
337 for(UINT j=0; j<models[k].N; j++){
338 file <<
"\t" << models[k].sigma[j];
342 for(UINT j=0; j<models[k].N; j++){
343 file <<
"\t" << models[k].weights[j];
// Loads a model from an already-open stream. A V1.0 header is handed to
// loadLegacyModelFromFile; otherwise the V2.0 header is required, the base
// settings are read by Classifier, and one section is parsed per class.
// NOTE(review): the `file >> word` reads, the declarations of word/value and
// the return statements are on lines not shown in this listing.
// NOTE(review): the log messages name the function "loadModelFromFile(string
// filename)" although the visible signature takes an fstream&.
351 bool ANBC::loadModelFromFile(fstream &file){
354 numInputDimensions = 0;
361 errorLog <<
"loadModelFromFile(string filename) - Could not open file to load model" << endl;
// Legacy files are delegated to the V1.0 loader.
369 if( word ==
"GRT_ANBC_MODEL_FILE_V1.0" ){
370 return loadLegacyModelFromFile( file );
374 if(word !=
"GRT_ANBC_MODEL_FILE_V2.0"){
375 errorLog <<
"loadModelFromFile(string filename) - Could not find Model File Header" << endl;
380 if( !Classifier::loadBaseSettingsFromFile(file) ){
381 errorLog <<
"loadModelFromFile(string filename) - Failed to load base settings from file!" << endl;
// NOTE(review): numClasses is presumably populated by loadBaseSettingsFromFile - confirm.
388 models.resize(numClasses);
// Parse one model section per class; each keyword is checked before its value is read.
391 for(UINT k=0; k<numClasses; k++){
394 if(word !=
"*************_MODEL_*************"){
395 errorLog <<
"loadModelFromFile(string filename) - Could not find header for the "<<k+1<<
"th model" << endl;
400 if(word !=
"Model_ID:"){
401 errorLog <<
"loadModelFromFile(string filename) - Could not find model ID for the "<<k+1<<
"th model" << endl;
// NOTE(review): the next two diagnostics go to cout, unlike every other
// message in this function, which uses errorLog.
407 cout<<
"ANBC: Model ID does not match the current class ID for the "<<k+1<<
"th model" << endl;
413 cout<<
"ANBC: Could not find N for the "<<k+1<<
"th model" << endl;
419 if(word !=
"ClassLabel:"){
420 errorLog <<
"loadModelFromFile(string filename) - Could not find ClassLabel for the "<<k+1<<
"th model" << endl;
423 file >> models[k].classLabel;
// NOTE(review): no classLabels.resize is visible in this function; presumably
// loadBaseSettingsFromFile sizes classLabels - confirm.
424 classLabels[k] = models[k].classLabel;
427 if(word !=
"Threshold:"){
428 errorLog <<
"loadModelFromFile(string filename) - Could not find the threshold for the "<<k+1<<
"th model" << endl;
431 file >> models[k].threshold;
434 if(word !=
"Gamma:"){
435 errorLog <<
"loadModelFromFile(string filename) - Could not find the gamma parameter for the "<<k+1<<
"th model" << endl;
438 file >> models[k].gamma;
441 if(word !=
"TrainingMu:"){
442 errorLog <<
"loadModelFromFile(string filename) - Could not find the training mu parameter for the "<<k+1<<
"th model" << endl;
445 file >> models[k].trainingMu;
448 if(word !=
"TrainingSigma:"){
449 errorLog <<
"loadModelFromFile(string filename) - Could not find the training sigma parameter for the "<<k+1<<
"th model" << endl;
452 file >> models[k].trainingSigma;
// NOTE(review): the vectors are sized with numInputDimensions but the loops
// below index up to models[k].N. If the N stored in the file can exceed
// numInputDimensions these writes would go out of range - confirm that N is
// validated on the lines not shown here.
455 models[k].mu.resize(numInputDimensions);
456 models[k].sigma.resize(numInputDimensions);
457 models[k].weights.resize(numInputDimensions);
462 errorLog <<
"loadModelFromFile(string filename) - Could not find the Mu vector for the "<<k+1<<
"th model" << endl;
467 for(UINT j=0; j<models[k].N; j++){
470 models[k].mu[j] = value;
474 if(word !=
"Sigma:"){
475 errorLog <<
"loadModelFromFile(string filename) - Could not find the Sigma vector for the "<<k+1<<
"th model" << endl;
480 for(UINT j=0; j<models[k].N; j++){
483 models[k].sigma[j] = value;
487 if(word !=
"Weights:"){
488 errorLog <<
"loadModelFromFile(string filename) - Could not find the Weights vector for the "<<k+1<<
"th model" << endl;
493 for(UINT j=0; j<models[k].N; j++){
496 models[k].weights[j] = value;
// The thresholds are recomputed from the current nullRejectionCoeff after loading.
501 recomputeNullRejectionThresholds();
// Reset the prediction outputs to their default null values.
504 maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
505 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
506 classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
507 classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
// Returns a copy of nullRejectionThresholds, or an empty VectorDouble when the
// model has not been trained.
513 VectorDouble ANBC::getNullRejectionThresholds()
const{
514 if( !trained )
return VectorDouble();
515 return nullRejectionThresholds;
// Sets the null rejection coefficient. Only strictly positive values are
// accepted; on acceptance the per-model thresholds are recomputed immediately.
// NOTE(review): the return statements are on lines not shown in this listing.
518 bool ANBC::setNullRejectionCoeff(
double nullRejectionCoeff){
520 if( nullRejectionCoeff > 0 ){
// The parameter shadows the member, hence the this->.
521 this->nullRejectionCoeff = nullRejectionCoeff;
522 recomputeNullRejectionThresholds();
// Fragment of a weights setter; its signature and any validation are on lines
// not shown in this listing.
// Sets the weightsDataSet flag and stores the supplied weights in the member.
531 weightsDataSet =
true;
532 this->weightsData = weightsData;
// Loads a legacy-format model from an already-open stream (loadModelFromFile
// calls this when it finds the V1.0 header). Unlike the V2.0 path, the
// dimensions, flags and ranges are read here directly rather than through
// Classifier::loadBaseSettingsFromFile.
// NOTE(review): the `file >> word` reads, the declarations of word/value and
// the return statements are on lines not shown in this listing.
// NOTE(review): the log messages name the function "loadANBCModelFromFile(string
// filename)", which does not match the visible name or signature.
538 bool ANBC::loadLegacyModelFromFile( fstream &file ){
543 if(word !=
"NumFeatures:"){
544 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find NumFeatures " << endl;
547 file >> numInputDimensions;
550 if(word !=
"NumClasses:"){
551 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find NumClasses" << endl;
557 if(word !=
"UseScaling:"){
558 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find UseScaling" << endl;
564 if(word !=
"UseNullRejection:"){
565 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find UseNullRejection" << endl;
568 file >> useNullRejection;
// One min/max pair is read per input dimension.
// NOTE(review): the condition guarding the ranges section is not visible;
// presumably it is only read when scaling is enabled - confirm.
573 ranges.resize(numInputDimensions);
576 if(word !=
"Ranges:"){
577 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Ranges" << endl;
580 for(UINT n=0; n<ranges.size(); n++){
581 file >> ranges[n].minValue;
582 file >> ranges[n].maxValue;
// Size the per-class containers before parsing the model sections.
587 models.resize(numClasses);
588 classLabels.resize(numClasses);
// Parse one model section per class; each keyword is checked before its value is read.
591 for(UINT k=0; k<numClasses; k++){
594 if(word !=
"*************_MODEL_*************"){
595 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find header for the "<<k+1<<
"th model" << endl;
600 if(word !=
"Model_ID:"){
601 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find model ID for the "<<k+1<<
"th model" << endl;
// NOTE(review): the next two diagnostics go to cout, unlike every other
// message in this function, which uses errorLog.
607 cout<<
"ANBC: Model ID does not match the current class ID for the "<<k+1<<
"th model" << endl;
613 cout<<
"ANBC: Could not find N for the "<<k+1<<
"th model" << endl;
619 if(word !=
"ClassLabel:"){
620 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find ClassLabel for the "<<k+1<<
"th model" << endl;
623 file >> models[k].classLabel;
624 classLabels[k] = models[k].classLabel;
627 if(word !=
"Threshold:"){
628 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the threshold for the "<<k+1<<
"th model" << endl;
631 file >> models[k].threshold;
634 if(word !=
"Gamma:"){
635 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the gamma parameter for the "<<k+1<<
"th model" << endl;
638 file >> models[k].gamma;
641 if(word !=
"TrainingMu:"){
642 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the training mu parameter for the "<<k+1<<
"th model" << endl;
645 file >> models[k].trainingMu;
648 if(word !=
"TrainingSigma:"){
649 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the training sigma parameter for the "<<k+1<<
"th model" << endl;
652 file >> models[k].trainingSigma;
// NOTE(review): the vectors are sized with numInputDimensions but the loops
// below index up to models[k].N. If the N stored in the file can exceed
// numInputDimensions these writes would go out of range - confirm that N is
// validated on the lines not shown here.
655 models[k].mu.resize(numInputDimensions);
656 models[k].sigma.resize(numInputDimensions);
657 models[k].weights.resize(numInputDimensions);
662 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Mu vector for the "<<k+1<<
"th model" << endl;
667 for(UINT j=0; j<models[k].N; j++){
670 models[k].mu[j] = value;
674 if(word !=
"Sigma:"){
675 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Sigma vector for the "<<k+1<<
"th model" << endl;
680 for(UINT j=0; j<models[k].N; j++){
683 models[k].sigma[j] = value;
687 if(word !=
"Weights:"){
688 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the Weights vector for the "<<k+1<<
"th model" << endl;
693 for(UINT j=0; j<models[k].N; j++){
696 models[k].weights[j] = value;
// The legacy format closes each model section with a footer line; no footer
// check is visible in loadModelFromFile.
700 if(word !=
"*********************************"){
701 errorLog <<
"loadANBCModelFromFile(string filename) - Could not find the model footer for the "<<k+1<<
"th model" << endl;
// The thresholds are recomputed from the current nullRejectionCoeff after loading.
710 recomputeNullRejectionThresholds();
// Reset the prediction outputs to their default null values.
713 maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
714 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
715 classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
716 classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
UINT getNumDimensions() const
UINT getNumSamples() const
vector< ClassTracker > getClassTracker() const
UINT getNumClasses() const
This class implements the Adaptive Naive Bayes Classifier algorithm. The Adaptive Naive Bayes Classif...
bool scale(const double minTarget, const double maxTarget)
vector< MinMax > getRanges() const
string getClassifierType() const
ClassificationData getClassData(const UINT classLabel) const