25 LDA::LDA(
bool useScaling,
bool useNullRejection,
double nullRejectionCoeff)
27 this->useScaling = useScaling;
28 this->useNullRejection = useNullRejection;
29 this->nullRejectionCoeff = nullRejectionCoeff;
31 classifierType = classType;
32 classifierMode = STANDARD_CLASSIFIER_MODE;
33 debugLog.setProceedingText(
"[DEBUG LDA]");
34 errorLog.setProceedingText(
"[ERROR LDA]");
35 trainingLog.setProceedingText(
"[TRAINING LDA]");
36 warningLog.setProceedingText(
"[WARNING LDA]");
45 errorLog <<
"SORRY - this module is still under development and can't be used yet!" << endl;
49 numInputDimensions = 0;
56 errorLog <<
"train(LabelledClassificationData trainingData) - There is no training data to train the model!" << endl;
64 MatrixDouble SB = computeBetweenClassScatterMatrix( trainingData );
67 MatrixDouble SW = computeWithinClassScatterMatrix( trainingData );
202 errorLog <<
"predict(vector< double > inputVector) - LDA Model Not Trained!" << endl;
206 predictedClassLabel = 0;
207 maxLikelihood = -10000;
209 if( !trained )
return false;
211 if( inputVector.size() != numInputDimensions ){
212 errorLog <<
"predict(vector< double > inputVector) - The size of the input vector (" << inputVector.size() <<
") does not match the num features in the model (" << numInputDimensions << endl;
217 if( classLikelihoods.size() != numClasses || classDistances.size() != numClasses ){
218 classLikelihoods.resize(numClasses);
219 classDistances.resize(numClasses);
227 for(UINT k=0; k<numClasses; k++){
229 for(UINT j=0; j<numInputDimensions+1; j++){
230 if( j==0 ) classDistances[k] = models[k].weights[j];
231 else classDistances[k] += inputVector[j-1] * models[k].weights[j];
233 classLikelihoods[k] = exp( classDistances[k] );
234 sum += classLikelihoods[k];
236 if( classLikelihoods[k] > maxLikelihood ){
238 maxLikelihood = classLikelihoods[k];
243 for(UINT k=0; k<numClasses; k++){
244 classLikelihoods[k] /= sum;
247 maxLikelihood = classLikelihoods[ bestIndex ];
249 predictedClassLabel = models[ bestIndex ].classLabel;
258 errorLog <<
"saveModelToFile(fstream &file) - Could not open file to save model" << endl;
263 file<<
"GRT_LDA_MODEL_FILE_V1.0\n";
264 file<<
"NumFeatures: "<<numInputDimensions<<endl;
265 file<<
"NumClasses: "<<numClasses<<endl;
266 file <<
"UseScaling: " << useScaling << endl;
267 file<<
"UseNullRejection: " << useNullRejection << endl;
271 file <<
"Ranges: \n";
272 for(UINT n=0; n<ranges.size(); n++){
273 file << ranges[n].minValue <<
"\t" << ranges[n].maxValue << endl;
278 for(UINT k=0; k<numClasses; k++){
279 file<<
"ClassLabel: "<<models[k].classLabel<<endl;
280 file<<
"PriorProbability: "<<models[k].priorProb<<endl;
283 for(UINT j=0; j<models[k].getNumDimensions(); j++){
284 file <<
"\t" << models[k].weights[j];
294 numInputDimensions = 0;
301 errorLog <<
"loadModelFromFile(fstream &file) - The file is not open!" << endl;
309 if(word !=
"GRT_LDA_MODEL_FILE_V1.0"){
310 errorLog <<
"loadModelFromFile(fstream &file) - Could not find Model File Header" << endl;
315 if(word !=
"NumFeatures:"){
316 errorLog <<
"loadModelFromFile(fstream &file) - Could not find NumFeatures " << endl;
319 file >> numInputDimensions;
322 if(word !=
"NumClasses:"){
323 errorLog <<
"loadModelFromFile(fstream &file) - Could not find NumClasses" << endl;
329 if(word !=
"UseScaling:"){
330 errorLog <<
"loadModelFromFile(fstream &file) - Could not find UseScaling" << endl;
336 if(word !=
"UseNullRejection:"){
337 errorLog <<
"loadModelFromFile(fstream &file) - Could not find UseNullRejection" << endl;
340 file >> useNullRejection;
345 ranges.resize(numInputDimensions);
348 if(word !=
"Ranges:"){
349 errorLog <<
"loadModelFromFile(fstream &file) - Could not find the Ranges" << endl;
352 for(UINT n=0; n<ranges.size(); n++){
353 file >> ranges[n].minValue;
354 file >> ranges[n].maxValue;
359 models.resize(numClasses);
360 classLabels.resize(numClasses);
363 for(UINT k=0; k<numClasses; k++){
365 if(word !=
"ClassLabel:"){
366 errorLog <<
"loadModelFromFile(fstream &file) - Could not find ClassLabel for the "<<k+1<<
"th model" << endl;
369 file >> models[k].classLabel;
370 classLabels[k] = models[k].classLabel;
373 if(word !=
"PriorProbability:"){
374 errorLog <<
"loadModelFromFile(fstream &file) - Could not find the PriorProbability for the "<<k+1<<
"th model" << endl;
377 file >> models[k].priorProb;
379 models[k].weights.resize(numInputDimensions+1);
383 if(word !=
"Weights:"){
384 errorLog <<
"loadModelFromFile(fstream &file) - Could not find the Weights vector for the "<<k+1<<
"th model" << endl;
389 for(UINT j=0; j<numInputDimensions+1; j++){
392 models[k].weights[j] = value;
397 maxLikelihood = DEFAULT_NULL_LIKELIHOOD_VALUE;
398 bestDistance = DEFAULT_NULL_DISTANCE_VALUE;
399 classLikelihoods.resize(numClasses,DEFAULT_NULL_LIKELIHOOD_VALUE);
400 classDistances.resize(numClasses,DEFAULT_NULL_DISTANCE_VALUE);
412 VectorDouble totalMean = data.
getMean();
413 sb.setAllValues( 0 );
415 for(UINT k=0; k<numClasses; k++){
419 for(UINT m=0; m<numInputDimensions; m++){
420 for(UINT n=0; n<numInputDimensions; n++){
421 sb[m][n] += (classMean[k][m]-totalMean[m]) * (classMean[k][n]-totalMean[n]) * double(numSamplesInClass);
429 MatrixDouble LDA::computeWithinClassScatterMatrix( ClassificationData &data ){
431 MatrixDouble sw(numInputDimensions,numInputDimensions);
432 sw.setAllValues( 0 );
434 for(UINT k=0; k<numClasses; k++){
437 ClassificationData classData = data.getClassData( data.getClassTracker()[k].classLabel );
438 MatrixDouble scatterMatrix = classData.getCovarianceMatrix();
441 for(UINT m=0; m<numInputDimensions; m++){
442 for(UINT n=0; n<numInputDimensions; n++){
443 sw[m][n] += scatterMatrix[m][n];
LDA(bool useScaling=false, bool useNullRejection=true, double nullRejectionCoeff=10.0)
VectorDouble getMean() const
UINT getNumDimensions() const
UINT getNumSamples() const
This class implements the Linear Discriminant Analysis Classification algorithm.
vector< ClassTracker > getClassTracker() const
UINT getNumClasses() const
virtual bool saveModelToFile(fstream &file) const
virtual bool predict(VectorDouble inputVector)
virtual bool loadModelFromFile(fstream &file)
MatrixDouble getClassMean() const
virtual bool train(ClassificationData trainingData)