36 this->numSteps = numSteps;
37 this->positiveClassificationThreshold = positiveClassificationThreshold;
38 this->minAlphaSearchRange = minAlphaSearchRange;
39 this->maxAlphaSearchRange = maxAlphaSearchRange;
45 trainingLog.setProceedingText(
"[DEBUG RadialBasisFunction]");
46 warningLog.setProceedingText(
"[WARNING RadialBasisFunction]");
47 errorLog.setProceedingText(
"[ERROR RadialBasisFunction]");
60 this->numSteps = rhs.numSteps;
61 this->alpha = rhs.alpha;
62 this->gamma = rhs.gamma;
63 this->positiveClassificationThreshold = rhs.positiveClassificationThreshold;
64 this->minAlphaSearchRange = rhs.minAlphaSearchRange;
65 this->maxAlphaSearchRange = rhs.maxAlphaSearchRange;
66 this->rbfCentre = rhs.rbfCentre;
73 if( weakClassifer == NULL )
return false;
91 errorLog <<
"train(ClassificationData &trainingData, VectorDouble &weights) - There should only be 2 classes in the training data, but there are : " << trainingData.
getNumClasses() << endl;
97 errorLog <<
"train(ClassificationData &trainingData, VectorDouble &weights) - There number of examples in the training data (" << trainingData.
getNumSamples() <<
") does not match the lenght of the weights vector (" << weights.size() <<
")" << endl;
106 double maxWeight = 0;
107 vector< UINT > bestWeights;
108 for(UINT i=0; i<M; i++){
109 if( trainingData[i].getClassLabel() == WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL ){
110 if( weights[i] > maxWeight ){
111 maxWeight = weights[i];
113 bestWeights.push_back(i);
114 }
else if( weights[i] == maxWeight ){
115 bestWeights.push_back( i );
121 const UINT N = (UINT)bestWeights.size();
124 errorLog <<
"train(ClassificationData &trainingData, VectorDouble &weights) - There are no positive class weigts!" << endl;
128 for(UINT i=0; i<N; i++){
130 rbfCentre[j] += trainingData[ bestWeights[i] ][j];
136 rbfCentre[j] /= double(N);
140 double step = (maxAlphaSearchRange-minAlphaSearchRange)/numSteps;
141 double bestAlpha = 0;
142 double minError = numeric_limits<double>::max();
144 alpha = minAlphaSearchRange;
145 while( alpha <= maxAlphaSearchRange ){
148 gamma = -1.0/(2.0*SQR(alpha));
152 for(UINT i=0; i<M; i++){
153 bool positiveSample = trainingData[ i ].getClassLabel() == WEAK_CLASSIFIER_POSITIVE_CLASS_LABEL;
154 double v = rbf(trainingData[ i ].getSample(),rbfCentre);
156 if( (v >= positiveClassificationThreshold && !positiveSample) || (v<positiveClassificationThreshold && positiveSample) ){
162 if( error < minError ){
176 gamma = -1.0/(2.0*SQR(alpha));
179 cout <<
"BestAlpha: " << bestAlpha <<
" Error: " << minError << endl;
185 if( rbf(x,rbfCentre) >= positiveClassificationThreshold )
return 1;
189 double RadialBasisFunction::rbf(
const VectorDouble &a,
const VectorDouble &b){
190 const UINT N = (UINT)a.size();
193 for(UINT i=0; i<N; i++){
196 return exp( gamma * r );
203 errorLog <<
"saveModelToFile(fstream &file) - The file is not open!" << endl;
209 file <<
"Trained: "<<
trained << endl;
213 file <<
"NumSteps: " << numSteps << endl;
214 file <<
"PositiveClassificationThreshold: " << positiveClassificationThreshold << endl;
215 file <<
"Alpha: " << alpha << endl;
216 file <<
"MinAlphaSearchRange: " << minAlphaSearchRange << endl;
217 file <<
"MaxAlphaSearchRange: " << maxAlphaSearchRange << endl;
221 file << rbfCentre[i] <<
"\t";
222 }
else file << 0 <<
"\t";
234 errorLog <<
"loadModelFromFile(fstream &file) - The file is not open!" << endl;
241 if( word !=
"WeakClassifierType:" ){
242 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read WeakClassifierType header!" << endl;
248 errorLog <<
"loadModelFromFile(fstream &file) - The weakClassifierType:" << word <<
" does not match: " <<
weakClassifierType << endl;
253 if( word !=
"Trained:" ){
254 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read Trained header!" << endl;
260 if( word !=
"NumInputDimensions:" ){
261 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read NumInputDimensions header!" << endl;
268 if( word !=
"NumSteps:" ){
269 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read NumSteps header!" << endl;
275 if( word !=
"PositiveClassificationThreshold:" ){
276 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read PositiveClassificationThreshold header!" << endl;
279 file >> positiveClassificationThreshold;
282 if( word !=
"Alpha:" ){
283 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read Alpha header!" << endl;
289 if( word !=
"MinAlphaSearchRange:" ){
290 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read MinAlphaSearchRange header!" << endl;
293 file >> minAlphaSearchRange;
296 if( word !=
"MaxAlphaSearchRange:" ){
297 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read MaxAlphaSearchRange header!" << endl;
300 file >> maxAlphaSearchRange;
303 if( word !=
"RBF:" ){
304 errorLog <<
"loadModelFromFile(fstream &file) - Failed to read RBF header!" << endl;
307 rbfCentre.resize(numInputDimensions);
310 file >> rbfCentre[i];
314 gamma = -1.0/(2.0*SQR(alpha));
332 return positiveClassificationThreshold;
340 return minAlphaSearchRange;
344 return maxAlphaSearchRange;
static RegisterWeakClassifierModule< RadialBasisFunction > registerModule
This is used to register the RadialBasisFunction with the WeakClassifier base class.
virtual bool deepCopyFrom(const WeakClassifier *weakClassifer)
UINT numInputDimensions
The number of input dimensions to the weak classifier.
virtual bool train(ClassificationData &trainingData, VectorDouble &weights)
UINT getNumDimensions() const
UINT getNumSamples() const
string weakClassifierType
A string that represents the weak classifier type, e.g. DecisionStump.
UINT getNumClasses() const
virtual ~RadialBasisFunction()
virtual void print() const
VectorDouble getRBFCentre() const
string getWeakClassifierType() const
RadialBasisFunction & operator=(const RadialBasisFunction &rhs)
bool trained
A flag to show if the weak classifier model has been trained.
bool copyBaseVariables(const WeakClassifier *weakClassifer)
RadialBasisFunction(UINT numSteps=100, double positiveClassificationThreshold=0.9, double minAlphaSearchRange=0.001, double maxAlphaSearchRange=1.0)
This class implements a Radial Basis Function Weak Classifier. The Radial Basis Function (RBF) class ...
double getMinAlphaSearchRange() const
double getPositiveClassificationThreshold() const
virtual bool saveModelToFile(fstream &file) const
virtual double predict(const VectorDouble &x)
double getMaxAlphaSearchRange() const
virtual bool loadModelFromFile(fstream &file)