GestureRecognitionToolkit  Version: 1.0 Revision: 04-03-15
The Gesture Recognition Toolkit (GRT) is a cross-platform, open-source C++ machine learning library for real-time gesture recognition.
BernoulliRBM.h
Go to the documentation of this file.
1 
33 #ifndef GRT_BERNOULLI_RBM_HEADER
34 #define GRT_BERNOULLI_RBM_HEADER
35 
36 #include "../../Util/MatrixDouble.h"
37 #include "../../CoreModules/MLBase.h"
38 
39 namespace GRT{
40 
41 class BernoulliRBM : public MLBase{
42 
43 public:
44  BernoulliRBM(const UINT numHiddenUnits = 100,const UINT maxNumEpochs = 1000,const double learningRate = 1,const double learningRateUpdate = 1,const double momentum = 0.5,const bool useScaling = true,const bool randomiseTrainingOrder = true);
45 
46  virtual ~BernoulliRBM();
47 
56  bool predict_(VectorDouble &inputData);
57 
67  bool predict_(VectorDouble &inputData,VectorDouble &outputData);
68 
79  bool predict_(const MatrixDouble &inputData,MatrixDouble &outputData,const UINT rowIndex);
80 
87  virtual bool train_(MatrixDouble &data);
88 
95  virtual bool reset();
96 
102  virtual bool clear();
103 
110  virtual bool saveModelToFile(fstream &file) const;
111 
118  virtual bool loadModelFromFile(fstream &file);
119 
120  bool reconstruct(const VectorDouble &input,VectorDouble &output);
121 
122  virtual bool print() const;
123 
124  bool getRandomizeWeightsForTraining() const;
125  UINT getNumVisibleUnits() const;
126  UINT getNumHiddenUnits() const;
127  VectorDouble getOutputData() const;
128  const MatrixDouble& getWeights() const;
129 
130  bool setNumHiddenUnits(const UINT numHiddenUnits);
131  bool setMomentum(const double momentum);
132  bool setLearningRateUpdate(const double learningRateUpdate);
133  bool setRandomizeWeightsForTraining(const bool randomizeWeightsForTraining);
134  bool setBatchSize(const UINT batchSize);
135  bool setBatchStepSize(const UINT batchStepSize);
136 
137  //Tell the compiler we are using the base class train method to stop hidden virtual function warnings
140  using MLBase::train;
141  using MLBase::predict;
142  using MLBase::train_;
143  using MLBase::predict_;
144 
145 protected:
146  bool loadLegacyModelFromFile(fstream &file);
147 
148  inline double sigmoid(const double &x) {
149  return 1.0 / (1.0 + exp(-x));
150  }
151 
152  inline double sigmoidRandom(const double &x){
153  return (1.0 / (1.0 + exp(-x)) > rand.getRandomNumberUniform(0.0,1.0)) ? 1.0 : 0.0;
154  }
155 
156  bool randomizeWeightsForTraining;
157  UINT numVisibleUnits;
158  UINT numHiddenUnits;
159  UINT batchSize;
160  UINT batchStepSize;
161  double momentum;
162  double learningRateUpdate;
163  MatrixDouble weightsMatrix;
164  VectorDouble visibleLayerBias;
165  VectorDouble hiddenLayerBias;
166  VectorDouble ph_mean;
167  VectorDouble ph_sample;
168  VectorDouble nv_means;
169  VectorDouble nv_samples;
170  VectorDouble nh_means;
171  VectorDouble nh_samples;
172  VectorDouble outputData;
173  vector<MinMax> ranges;
174  Random rand;
175 
176  struct BatchIndexs{
177  UINT startIndex;
178  UINT endIndex;
179  UINT batchSize;
180  };
181  typedef struct BatchIndexs BatchIndexs;
182 
183 };
184 
185 } //End of namespace GRT
186 
187 #endif //GRT_BERNOULLI_RBM_HEADER
188 
virtual bool train_(MatrixDouble &data)
virtual bool saveModelToFile(string filename) const
Definition: MLBase.cpp:135
virtual bool clear()
virtual bool print() const
virtual bool loadModelFromFile(string filename)
Definition: MLBase.cpp:157
Definition: AdaBoost.cpp:25
virtual bool train(ClassificationData trainingData)
Definition: MLBase.cpp:80
bool predict_(VectorDouble &inputData)
virtual bool predict(VectorDouble inputVector)
Definition: MLBase.cpp:104
virtual bool predict_(VectorDouble &inputVector)
Definition: MLBase.cpp:106
double getRandomNumberUniform(double minRange=0.0, double maxRange=1.0)
Definition: Random.h:197
virtual bool reset()
virtual bool saveModelToFile(fstream &file) const
virtual bool loadModelFromFile(fstream &file)
bool loadLegacyModelFromFile(fstream &file)
virtual bool train_(ClassificationData &trainingData)
Definition: MLBase.cpp:82