/*
 This class implements the K-Nearest Neighbor classification algorithm
 (http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm). KNN is a simple
 but powerful classifier, based on finding the closest K training examples in
 the feature space for the new input vector. The KNN algorithm is amongst the
 simplest of all machine learning algorithms: an object is classified by a
 majority vote of its neighbors, with the object being assigned to the class
 most common amongst its k nearest neighbors (k is a positive integer,
 typically small). If k = 1, then the object is simply assigned to the class
 of its nearest neighbor.
*/

//Register the KNN module with the Classifier base class
RegisterClassifierModule< KNN > KNN::registerModule("KNN");
KNN::KNN(unsigned int K,bool useScaling,bool useNullRejection,double nullRejectionCoeff,bool searchForBestKValue,UINT minKSearchValue,UINT maxKSearchValue){
    this->K = K;
    this->distanceMethod = EUCLIDEAN_DISTANCE;
    this->useScaling = useScaling;
    this->useNullRejection = useNullRejection;
    this->nullRejectionCoeff = nullRejectionCoeff;
    this->searchForBestKValue = searchForBestKValue;
    this->minKSearchValue = minKSearchValue;
    this->maxKSearchValue = maxKSearchValue;
    supportsNullRejection = true;
    classType = "KNN";
    classifierType = classType;
    classifierMode = STANDARD_CLASSIFIER_MODE;
    debugLog.setProceedingText("[DEBUG KNN]");
    errorLog.setProceedingText("[ERROR KNN]");
    trainingLog.setProceedingText("[TRAINING KNN]");
    warningLog.setProceedingText("[WARNING KNN]");
}
KNN::KNN(const KNN &rhs){
    classType = "KNN";
    classifierType = classType;
    classifierMode = STANDARD_CLASSIFIER_MODE;
    debugLog.setProceedingText("[DEBUG KNN]");
    errorLog.setProceedingText("[ERROR KNN]");
    trainingLog.setProceedingText("[TRAINING KNN]");
    warningLog.setProceedingText("[WARNING KNN]");
    *this = rhs;
}
bool KNN::deepCopyFrom(const Classifier *classifier){

    if( classifier == NULL ) return false;

    if( this->getClassifierType() == classifier->getClassifierType() ){
        //Copy the KNN-specific settings and training data, then the common classifier variables
        KNN *ptr = (KNN*)classifier;

        this->K = ptr->K;
        this->distanceMethod = ptr->distanceMethod;
        this->searchForBestKValue = ptr->searchForBestKValue;
        this->minKSearchValue = ptr->minKSearchValue;
        this->maxKSearchValue = ptr->maxKSearchValue;
        this->trainingData = ptr->trainingData;
        this->trainingMu = ptr->trainingMu;
        this->trainingSigma = ptr->trainingSigma;

        return copyBaseVariables( classifier );
    }

    return false;
}
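/*
    Illustrative sketch (not part of KNN.cpp): deepCopyFrom is typically
    called through the Classifier base-class interface to duplicate a model.
    The knnA/knnB names are hypothetical.

    GRT::KNN knnA;
    //...train knnA...
    GRT::KNN knnB;
    if( knnB.deepCopyFrom( &knnA ) ){
        //knnB now holds its own copy of knnA's settings and training data
    }
*/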
bool KNN::train_(ClassificationData &trainingData){

    if( trainingData.getNumSamples() == 0 ){
        errorLog << "train_(ClassificationData &trainingData) - Training data has zero samples!" << endl;
        return false;
    }

    //Get the number of features and classes
    numInputDimensions = trainingData.getNumDimensions();
    numClasses = trainingData.getNumClasses();

    //If scaling is enabled, scale the training data between 0 and 1
    if( useScaling ){
        ranges = trainingData.getRanges();
        trainingData.scale(0, 1);
    }

    //Store the training data - the KNN algorithm simply keeps a copy of it for prediction
    this->trainingData = trainingData;

    //Store the class labels
    classLabels.resize( numClasses );
    for(UINT k=0; k<numClasses; k++){
        classLabels[k] = trainingData.getClassTracker()[k].classLabel;
    }

    //If we are not searching for the best K value then train the model with the current K value
    if( !searchForBestKValue ){
        return train_(trainingData,K);
    }
    //Otherwise, search for the best K value in the range [minKSearchValue maxKSearchValue]
    double bestAccuracy = 0;
    vector< IndexedDouble > trainingAccuracyLog;

    for(UINT k=minKSearchValue; k<=maxKSearchValue; k++){
        //Randomly split the data, using 80% to train the model and 20% to test it
        ClassificationData trainingSet(trainingData);
        ClassificationData testSet = trainingSet.partition(80,true);

        if( !train_(trainingSet, k) ){
            errorLog << "Failed to train model for a k value of " << k << endl;
        }else{
            //Compute the classification accuracy on the test set
            double accuracy = 0;
            for(UINT i=0; i<testSet.getNumSamples(); i++){
                VectorDouble sample = testSet[i].getSample();

                if( !predict( sample , k) ){
                    errorLog << "Failed to predict label for test sample with a k value of " << k << endl;
                    return false;
                }

                if( testSet[i].getClassLabel() == predictedClassLabel ){
                    accuracy++;
                }
            }

            accuracy = accuracy / double( testSet.getNumSamples() ) * 100.0;
            trainingAccuracyLog.push_back( IndexedDouble(k,accuracy) );

            trainingLog << "K:\t" << k << "\tAccuracy:\t" << accuracy << endl;

            if( accuracy > bestAccuracy ){
                bestAccuracy = accuracy;
            }
        }
    }

    if( bestAccuracy > 0 ){
        //Sort the accuracy results in descending order of accuracy
        std::sort(trainingAccuracyLog.begin(),trainingAccuracyLog.end(),IndexedDouble::sortIndexedDoubleByValueDescending);

        //Copy the top accuracy result, plus any ties, into a temporary log
        vector< IndexedDouble > tempLog;
        tempLog.push_back( trainingAccuracyLog[0] );

        for(UINT i=1; i<trainingAccuracyLog.size(); i++){
            if( trainingAccuracyLog[i].value == tempLog[0].value ){
                tempLog.push_back( trainingAccuracyLog[i] );
            }
        }

        //Sort the ties by K value (ascending) so the smallest K with the best accuracy wins
        std::sort(tempLog.begin(),tempLog.end(),IndexedDouble::sortIndexedDoubleByIndexAscending);

        trainingLog << "Best K Value: " << tempLog[0].index << "\tAccuracy:\t" << tempLog[0].value << endl;

        //Retrain the model on the full dataset using the best K value
        return train_(trainingData,tempLog[0].index);
    }

    return false;
}
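/*
    Illustrative sketch (not part of KNN.cpp): the best-K search above is
    configured through the public setters defined later in this file, e.g.:

    GRT::KNN knn;
    knn.enableBestKValueSearch( true );
    knn.setMinKSearchValue( 1 );
    knn.setMaxKSearchValue( 20 );
    //Training will now evaluate each K on an internal 80/20 split and keep
    //the smallest K that achieved the best test accuracy
*/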
bool KNN::train_(ClassificationData &trainingData,const UINT K){

    //Set the requested K value
    this->K = K;

    //Flag that the algorithm has been trained so the rejection thresholds can be computed
    trained = true;

    //If null rejection is enabled, compute the per-class distance statistics and thresholds
    if( useNullRejection ){

        //Temporarily disable null rejection while the thresholds are computed
        useNullRejection = false;
        nullRejectionThresholds.clear();

        VectorDouble counter(numClasses,0);
        trainingMu.resize( numClasses, 0 );
        trainingSigma.resize( numClasses, 0 );
        nullRejectionThresholds.resize( numClasses, 0 );

        //Compute the average distance for each class by predicting each training sample
        const unsigned int numTrainingExamples = trainingData.getNumSamples();
        vector< IndexedDouble > predictionResults( numTrainingExamples );
        for(UINT i=0; i<numTrainingExamples; i++){
            predict( trainingData[i].getSample(), K);

            UINT classLabelIndex = 0;
            for(UINT k=0; k<numClasses; k++){
                if( predictedClassLabel == classLabels[k] ){
                    classLabelIndex = k;
                    break;
                }
            }

            predictionResults[ i ].index = classLabelIndex;
            predictionResults[ i ].value = classDistances[ classLabelIndex ];

            trainingMu[ classLabelIndex ] += predictionResults[ i ].value;
            counter[ classLabelIndex ]++;
        }

        for(UINT j=0; j<numClasses; j++){
            trainingMu[j] /= counter[j];
        }

        //Compute the standard deviation of the distances for each class
        for(UINT i=0; i<numTrainingExamples; i++){
            trainingSigma[predictionResults[i].index] += SQR(predictionResults[i].value - trainingMu[predictionResults[i].index]);
        }

        for(UINT j=0; j<numClasses; j++){
            double count = counter[j];
            if( count > 1 ){
                trainingSigma[ j ] = sqrt( trainingSigma[j] / (count-1) );
            }else{
                trainingSigma[ j ] = 1.0;
            }
        }

        //Check for any anomalies in the class statistics
        bool errorFound = false;
        for(UINT j=0; j<numClasses; j++){
            if( trainingMu[j] == 0 ){
                warningLog << "TrainingMu[ " << j << " ] is zero for a K value of " << K << endl;
            }
            if( trainingSigma[j] == 0 ){
                warningLog << "TrainingSigma[ " << j << " ] is zero for a K value of " << K << endl;
            }
            if( grt_isnan( trainingMu[j] ) ){
                errorLog << "TrainingMu[ " << j << " ] is NAN for a K value of " << K << endl;
                errorFound = true;
            }
            if( grt_isnan( trainingSigma[j] ) ){
                errorLog << "TrainingSigma[ " << j << " ] is NAN for a K value of " << K << endl;
                errorFound = true;
            }
        }

        if( errorFound ){
            trained = false;
            return false;
        }

        //Compute the rejection threshold for each class
        for(unsigned int j=0; j<numClasses; j++){
            nullRejectionThresholds[j] = trainingMu[j] + (trainingSigma[j]*nullRejectionCoeff);
        }

        //Restore the null rejection state
        useNullRejection = true;

    }else{
        //Null rejection is disabled, so just reset the thresholds
        nullRejectionThresholds.clear();
        nullRejectionThresholds.resize( numClasses, 0 );
    }

    return true;
}
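/*
    Worked example (illustrative values): the rejection threshold for class j
    is trainingMu[j] + trainingSigma[j] * nullRejectionCoeff. If a class has
    an average neighbour distance of 0.5 with a standard deviation of 0.1,
    then with nullRejectionCoeff = 2.0 the threshold is 0.5 + 0.1*2.0 = 0.7;
    predictions whose average class distance exceeds 0.7 will be rejected in
    predict() and assigned the null class label.
*/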
bool KNN::predict_(VectorDouble &inputVector){

    if( !trained ){
        errorLog << "predict_(VectorDouble &inputVector) - KNN model has not been trained" << endl;
        return false;
    }

    if( inputVector.size() != numInputDimensions ){
        errorLog << "predict_(VectorDouble &inputVector) - the size of the input vector " << inputVector.size() << " does not match the number of features " << numInputDimensions << endl;
        return false;
    }

    //Scale the input vector to the range [0 1] if scaling is enabled
    if( useScaling ){
        for(UINT i=0; i<numInputDimensions; i++){
            inputVector[i] = scale(inputVector[i], ranges[i].minValue, ranges[i].maxValue, 0, 1);
        }
    }

    //Run the prediction using the current K value
    return predict(inputVector,K);
}
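/*
    Illustrative note: scale() is the base-class helper that linearly maps a
    value from [minValue maxValue] to [0 1], i.e.
    (x - minValue) / (maxValue - minValue). For example, x = 7.5 with a
    training range of [5 10] maps to (7.5 - 5) / (10 - 5) = 0.5.
*/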
bool KNN::predict(const VectorDouble &inputVector,const UINT K){

    if( !trained ){
        errorLog << "predict(VectorDouble inputVector,UINT K) - KNN model has not been trained" << endl;
        return false;
    }

    if( inputVector.size() != numInputDimensions ){
        errorLog << "predict(VectorDouble inputVector) - the size of the input vector " << inputVector.size() << " does not match the number of features " << numInputDimensions << endl;
        return false;
    }

    if( K > trainingData.getNumSamples() ){
        errorLog << "predict(VectorDouble inputVector,UINT K) - K Is Greater Than The Number Of Training Samples" << endl;
        return false;
    }

    //Find the K nearest neighbours to the input vector
    const UINT M = trainingData.getNumSamples();
    double dist = 0;
    vector< IndexedDouble > neighbours;

    for(UINT i=0; i<M; i++){
        //Get the class label and feature vector for the i'th training sample
        UINT classLabel = trainingData[i].getClassLabel();
        VectorDouble trainingSample = trainingData[i].getSample();

        //Compute the distance between the input vector and the training sample
        switch( distanceMethod ){
            case EUCLIDEAN_DISTANCE:
                dist = computeEuclideanDistance(inputVector,trainingSample);
                break;
            case COSINE_DISTANCE:
                dist = computeCosineDistance(inputVector,trainingSample);
                break;
            case MANHATTAN_DISTANCE:
                dist = computeManhattanDistance(inputVector, trainingSample);
                break;
            default:
                errorLog << "predict(vector< double > inputVector) - unknown distance measure!" << endl;
                return false;
        }

        if( neighbours.size() < K ){
            //The neighbour buffer is not full yet, so add this sample
            neighbours.push_back( IndexedDouble(classLabel,dist) );
        }else{
            //Find the buffered neighbour with the largest distance
            double maxValue = neighbours[0].value;
            UINT maxIndex = 0;
            for(UINT n=1; n<neighbours.size(); n++){
                if( neighbours[n].value > maxValue ){
                    maxValue = neighbours[n].value;
                    maxIndex = n;
                }
            }

            //If the current distance is smaller than the largest buffered distance, replace that neighbour
            if( dist < maxValue ){
                neighbours[ maxIndex ] = IndexedDouble(classLabel,dist);
            }
        }
    }

    //Reset the class likelihoods and distances
    if( classLikelihoods.size() != numClasses ) classLikelihoods.resize(numClasses);
    if( classDistances.size() != numClasses ) classDistances.resize(numClasses);
    std::fill(classLikelihoods.begin(),classLikelihoods.end(),0);
    std::fill(classDistances.begin(),classDistances.end(),0);

    //Count the class votes and accumulate the class distances over the K neighbours
    for(UINT k=0; k<neighbours.size(); k++){
        UINT classLabel = neighbours[k].index;
        if( classLabel == 0 ){
            errorLog << "predict(VectorDouble inputVector) - Class label of training example can not be zero!" << endl;
            return false;
        }

        //Find the index of the class label
        UINT classLabelIndex = 0;
        for(UINT j=0; j<numClasses; j++){
            if( classLabel == classLabels[j] ){
                classLabelIndex = j;
                break;
            }
        }
        classLikelihoods[ classLabelIndex ] += 1;
        classDistances[ classLabelIndex ] += neighbours[k].value;
    }

    //Find the class with the most votes
    double maxCount = classLikelihoods[0];
    UINT maxIndex = 0;
    for(UINT i=1; i<classLikelihoods.size(); i++){
        if( classLikelihoods[i] > maxCount ){
            maxCount = classLikelihoods[i];
            maxIndex = i;
        }
    }

    //Compute the average distance for each class
    for(UINT i=0; i<numClasses; i++){
        if( classLikelihoods[i] > 0 ) classDistances[i] /= classLikelihoods[i];
        else classDistances[i] = BIG_DISTANCE;
    }

    //Convert the class votes into likelihoods
    for(UINT i=0; i<numClasses; i++){
        classLikelihoods[i] /= double( neighbours.size() );
    }

    maxLikelihood = classLikelihoods[ maxIndex ];

    //If null rejection is enabled, only accept the prediction if its average distance is within the class threshold
    if( useNullRejection ){
        if( classDistances[ maxIndex ] <= nullRejectionThresholds[ maxIndex ] ){
            predictedClassLabel = classLabels[maxIndex];
        }else{
            predictedClassLabel = GRT_DEFAULT_NULL_CLASS_LABEL;
        }
    }else{
        predictedClassLabel = classLabels[maxIndex];
    }

    return true;
}
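/*
    Illustrative usage sketch (not part of KNN.cpp), assuming the getters
    provided by the GRT Classifier base class; knn is a hypothetical trained
    instance:

    GRT::VectorDouble sample( knn.getNumInputDimensions(), 0.0 );
    if( knn.predict( sample ) ){
        UINT label = knn.getPredictedClassLabel();       //the null class label if the sample was rejected
        double likelihood = knn.getMaximumLikelihood();  //fraction of the K neighbours that voted for the label
    }
*/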
bool KNN::clear(){

    //Clear the base class variables and the KNN model
    Classifier::clear();
    trainingData.clear();
    trainingMu.clear();
    trainingSigma.clear();

    return true;
}
bool KNN::saveModelToFile(fstream &file) const{

    if( !file.is_open() ){
        errorLog << "saveModelToFile(fstream &file) - Could not open file to save model!" << endl;
        return false;
    }

    //Write the header info
    file << "GRT_KNN_MODEL_FILE_V2.0\n";

    //Write the classifier settings to the file
    if( !Classifier::saveBaseSettingsToFile(file) ){
        errorLog << "saveModelToFile(fstream &file) - Failed to save classifier base settings to file!" << endl;
        return false;
    }

    file << "K: " << K << endl;
    file << "DistanceMethod: " << distanceMethod << endl;
    file << "SearchForBestKValue: " << searchForBestKValue << endl;
    file << "MinKSearchValue: " << minKSearchValue << endl;
    file << "MaxKSearchValue: " << maxKSearchValue << endl;

    if( useNullRejection ){
        file << "TrainingMu: ";
        for(UINT j=0; j<trainingMu.size(); j++){
            file << trainingMu[j] << "\t";
        }
        file << endl;

        file << "TrainingSigma: ";
        for(UINT j=0; j<trainingSigma.size(); j++){
            file << trainingSigma[j] << "\t";
        }
        file << endl;
    }

    file << "NumTrainingSamples: " << trainingData.getNumSamples() << endl;
    file << "TrainingData: \n";

    //Write each training sample as the class label followed by the tab-separated feature values
    for(UINT i=0; i<trainingData.getNumSamples(); i++){
        file << trainingData[i].getClassLabel() << "\t";
        for(UINT j=0; j<numInputDimensions; j++){
            file << trainingData[i][j] << "\t";
        }
        file << endl;
    }

    return true;
}
bool KNN::loadModelFromFile(fstream &file){

    if( !file.is_open() ){
        errorLog << "loadModelFromFile(fstream &file) - Could not open file to load model!" << endl;
        return false;
    }

    string word;
    file >> word;

    //Check to see if we should load a legacy file
    if( word == "GRT_KNN_MODEL_FILE_V1.0" ){
        return loadLegacyModelFromFile( file );
    }

    //Find the file header
    if(word != "GRT_KNN_MODEL_FILE_V2.0"){
        errorLog << "loadModelFromFile(fstream &file) - Could not find Model File Header!" << endl;
        return false;
    }

    //Load the base settings from the file
    if( !Classifier::loadBaseSettingsFromFile(file) ){
        errorLog << "loadModelFromFile(string filename) - Failed to load base settings from file!" << endl;
        return false;
    }

    file >> word;
    if(word != "K:"){
        errorLog << "loadModelFromFile(fstream &file) - Could not find K!" << endl;
        return false;
    }
    file >> K;

    file >> word;
    if(word != "DistanceMethod:"){
        errorLog << "loadModelFromFile(fstream &file) - Could not find DistanceMethod!" << endl;
        return false;
    }
    file >> distanceMethod;

    file >> word;
    if(word != "SearchForBestKValue:"){
        errorLog << "loadModelFromFile(fstream &file) - Could not find SearchForBestKValue!" << endl;
        return false;
    }
    file >> searchForBestKValue;

    file >> word;
    if(word != "MinKSearchValue:"){
        errorLog << "loadModelFromFile(fstream &file) - Could not find MinKSearchValue!" << endl;
        return false;
    }
    file >> minKSearchValue;

    file >> word;
    if(word != "MaxKSearchValue:"){
        errorLog << "loadModelFromFile(fstream &file) - Could not find MaxKSearchValue!" << endl;
        return false;
    }
    file >> maxKSearchValue;

    if( trained ){

        //Resize the class statistics
        trainingMu.resize(numClasses,0);
        trainingSigma.resize(numClasses,0);

        if( useNullRejection ){
            file >> word;
            if(word != "TrainingMu:"){
                errorLog << "loadModelFromFile(fstream &file) - Could not find TrainingMu!" << endl;
                return false;
            }

            //Load the average class distances
            for(UINT j=0; j<numClasses; j++){
                file >> trainingMu[j];
            }

            file >> word;
            if(word != "TrainingSigma:"){
                errorLog << "loadModelFromFile(fstream &file) - Could not find TrainingSigma!" << endl;
                return false;
            }

            //Load the standard deviations of the class distances
            for(UINT j=0; j<numClasses; j++){
                file >> trainingSigma[j];
            }
        }

        file >> word;
        if(word != "NumTrainingSamples:"){
            errorLog << "loadModelFromFile(fstream &file) - Could not find NumTrainingSamples!" << endl;
            return false;
        }
        unsigned int numTrainingSamples = 0;
        file >> numTrainingSamples;

        file >> word;
        if(word != "TrainingData:"){
            errorLog << "loadModelFromFile(fstream &file) - Could not find TrainingData!" << endl;
            return false;
        }

        //Load the training data
        trainingData.setNumDimensions(numInputDimensions);
        unsigned int classLabel = 0;
        vector< double > sample(numInputDimensions,0);
        for(UINT i=0; i<numTrainingSamples; i++){
            //Read the class label
            file >> classLabel;

            //Read the feature vector
            for(UINT j=0; j<numInputDimensions; j++){
                file >> sample[j];
            }

            //Add the sample to the training data
            trainingData.addSample(classLabel, sample);
        }

        //Reset the prediction buffers so the model is ready for realtime prediction
        maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
        bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
        classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
        classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
    }

    return true;
}
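/*
    Illustrative round-trip sketch (not part of KNN.cpp); the filename is
    hypothetical and knn is a trained instance:

    fstream file;
    file.open( "knn_model.grt", std::ios::out );
    knn.saveModelToFile( file );
    file.close();

    file.open( "knn_model.grt", std::ios::in );
    GRT::KNN restored;
    restored.loadModelFromFile( file );  //also accepts legacy V1.0 model files
    file.close();
*/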
bool KNN::recomputeNullRejectionThresholds(){

    if( !trained ){
        return false;
    }

    nullRejectionThresholds.resize(numClasses,0);

    //Make sure the class statistics are valid before computing the thresholds
    if( trainingMu.size() != numClasses || trainingSigma.size() != numClasses ){
        return false;
    }

    //Compute the rejection threshold for each class
    for(unsigned int j=0; j<numClasses; j++){
        nullRejectionThresholds[j] = trainingMu[j] + (trainingSigma[j]*nullRejectionCoeff);
    }

    return true;
}
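/*
    Illustrative note: the thresholds depend only on the stored trainingMu /
    trainingSigma statistics, so they can be recomputed cheaply after
    changing the coefficient, e.g.:

    knn.setNullRejectionCoeff( 5.0 );  //calls recomputeNullRejectionThresholds() internally
*/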
bool KNN::setK(UINT K){
    if( K > 0 ){
        this->K = K;
        return true;
    }
    return false;
}

bool KNN::setMinKSearchValue(UINT minKSearchValue){
    this->minKSearchValue = minKSearchValue;
    return true;
}

bool KNN::setMaxKSearchValue(UINT maxKSearchValue){
    this->maxKSearchValue = maxKSearchValue;
    return true;
}

bool KNN::enableBestKValueSearch(bool searchForBestKValue){
    this->searchForBestKValue = searchForBestKValue;
    return true;
}

bool KNN::setNullRejectionCoeff(double nullRejectionCoeff){
    if( nullRejectionCoeff > 0 ){
        this->nullRejectionCoeff = nullRejectionCoeff;
        recomputeNullRejectionThresholds();
        return true;
    }
    return false;
}

bool KNN::setDistanceMethod(UINT distanceMethod){
    if( distanceMethod == EUCLIDEAN_DISTANCE || distanceMethod == COSINE_DISTANCE || distanceMethod == MANHATTAN_DISTANCE ){
        this->distanceMethod = distanceMethod;
        return true;
    }
    return false;
}
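/*
    Illustrative sketch: selecting the distance metric, assuming the
    DistanceMethods enum declared in KNN.h:

    GRT::KNN knn;
    knn.setDistanceMethod( GRT::KNN::MANHATTAN_DISTANCE );
*/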
double KNN::computeEuclideanDistance(const VectorDouble &a,const VectorDouble &b){
    double dist = 0;
    for(UINT j=0; j<numInputDimensions; j++){
        dist += SQR( a[j] - b[j] );
    }
    return sqrt( dist );
}

double KNN::computeCosineDistance(const VectorDouble &a,const VectorDouble &b){
    double dist = 0;
    double dotAB = 0;
    double magA = 0;
    double magB = 0;

    for(UINT j=0; j<numInputDimensions; j++){
        dotAB += a[j] * b[j];
        magA += SQR( a[j] );
        magB += SQR( b[j] );
    }

    dist = dotAB / (sqrt(magA) * sqrt(magB));

    return dist;
}

double KNN::computeManhattanDistance(const VectorDouble &a,const VectorDouble &b){
    double dist = 0;
    for(UINT j=0; j<numInputDimensions; j++){
        dist += fabs( a[j] - b[j] );
    }
    return dist;
}
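/*
    Worked example (illustrative) for a = (1,2) and b = (4,6):

    Euclidean:  sqrt((1-4)^2 + (2-6)^2) = sqrt(25) = 5
    Manhattan:  |1-4| + |2-6| = 7
    Cosine:     (1*4 + 2*6) / (sqrt(1+4)*sqrt(16+36)) = 16/sqrt(260) ~ 0.992

    Note that computeCosineDistance returns the cosine similarity (larger
    means more similar), while the neighbour search in predict() keeps the K
    samples with the smallest values, so the cosine option ranks neighbours
    in the opposite sense to the two true distance metrics.
*/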
bool KNN::loadLegacyModelFromFile( fstream &file ){

    string word;

    file >> word;
    if(word != "NumFeatures:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find NumFeatures!" << endl;
        return false;
    }
    file >> numInputDimensions;

    file >> word;
    if(word != "NumClasses:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find NumClasses!" << endl;
        return false;
    }
    file >> numClasses;

    file >> word;
    if(word != "K:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find K!" << endl;
        return false;
    }
    file >> K;

    file >> word;
    if(word != "DistanceMethod:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find DistanceMethod!" << endl;
        return false;
    }
    file >> distanceMethod;

    file >> word;
    if(word != "SearchForBestKValue:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find SearchForBestKValue!" << endl;
        return false;
    }
    file >> searchForBestKValue;

    file >> word;
    if(word != "MinKSearchValue:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find MinKSearchValue!" << endl;
        return false;
    }
    file >> minKSearchValue;

    file >> word;
    if(word != "MaxKSearchValue:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find MaxKSearchValue!" << endl;
        return false;
    }
    file >> maxKSearchValue;

    file >> word;
    if(word != "UseScaling:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find UseScaling!" << endl;
        return false;
    }
    file >> useScaling;

    file >> word;
    if(word != "UseNullRejection:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find UseNullRejection!" << endl;
        return false;
    }
    file >> useNullRejection;

    file >> word;
    if(word != "NullRejectionCoeff:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find NullRejectionCoeff!" << endl;
        return false;
    }
    file >> nullRejectionCoeff;

    //If scaling is enabled, load the ranges
    if( useScaling ){
        ranges.resize( numInputDimensions );

        file >> word;
        if(word != "Ranges:"){
            errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find Ranges!" << endl;
            cout << "Word: " << word << endl;
            return false;
        }
        for(UINT n=0; n<ranges.size(); n++){
            file >> ranges[n].minValue;
            file >> ranges[n].maxValue;
        }
    }

    //Resize the class statistics
    trainingMu.resize(numClasses,0);
    trainingSigma.resize(numClasses,0);

    file >> word;
    if(word != "TrainingMu:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find TrainingMu!" << endl;
        return false;
    }

    //Load the average class distances
    for(UINT j=0; j<numClasses; j++){
        file >> trainingMu[j];
    }

    file >> word;
    if(word != "TrainingSigma:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find TrainingSigma!" << endl;
        return false;
    }

    //Load the standard deviations of the class distances
    for(UINT j=0; j<numClasses; j++){
        file >> trainingSigma[j];
    }

    file >> word;
    if(word != "NumTrainingSamples:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find NumTrainingSamples!" << endl;
        return false;
    }
    unsigned int numTrainingSamples = 0;
    file >> numTrainingSamples;

    file >> word;
    if(word != "TrainingData:"){
        errorLog << "loadLegacyModelFromFile(fstream &file) - Could not find TrainingData!" << endl;
        return false;
    }

    //Load the training data
    trainingData.setNumDimensions(numInputDimensions);
    unsigned int classLabel = 0;
    vector< double > sample(numInputDimensions,0);
    for(UINT i=0; i<numTrainingSamples; i++){
        //Read the class label
        file >> classLabel;

        //Read the feature vector
        for(UINT j=0; j<numInputDimensions; j++){
            file >> sample[j];
        }

        //Add the sample to the training data
        trainingData.addSample(classLabel, sample);
    }

    //Flag that the model has been trained and compute the null rejection thresholds
    trained = true;
    recomputeNullRejectionThresholds();

    return true;
}