TRF Language Model
hrf-sa-train.h
Go to the documentation of this file.
1 #pragma once
2 #include "hrf-ml-train.h"
3 #include "hrf-corpus.h"
4 #include <omp.h>
5 
6 
7 namespace hrf
8 {
9  class SAfunc;
10  class SAtrain;
11 
/**
 * \class AISConfig
 * \brief The AIS configuration: a chain count and an intermediate-distribution count.
 *
 * Both values can be given to the constructor or read from a string with Parse().
 */
class AISConfig
{
public:
    int nChain; ///< chain number
    int nInter; ///< intermediate distribution number

    /**
     * \param p_nChain chain number (default 16)
     * \param p_nInter intermediate distribution number (default 1000)
     */
    AISConfig(int p_nChain = 16, int p_nInter = 1000) : nChain(p_nChain), nInter(p_nInter) {}

    /**
     * \brief Read "<nChain><sep><nInter>" from \p str, where <sep> is one or
     *        more ':' or ',' characters (leading separators are skipped).
     *
     * Each field is converted with atoi(), so a non-numeric field yields 0.
     * A field that is absent (NULL string, empty string, or only one field
     * present) leaves the corresponding member unchanged instead of
     * dereferencing a null token pointer.
     *
     * The input is scanned read-only: no temporary copy and no strtok(),
     * whose hidden static state is unsafe when called from several threads.
     *
     * \param str configuration text, e.g. "16:1000" or "16,1000"; may be NULL.
     */
    void Parse(const char* str) {
        if (str == NULL)
            return;

        const char *cur = str;

        // skip separators in front of the first field
        while (*cur == ':' || *cur == ',')
            ++cur;
        if (*cur == '\0')
            return; // no fields at all
        nChain = atoi(cur); // atoi stops at the first non-digit, i.e. at the separator

        // step over the rest of the first field, then over the separator run
        while (*cur != '\0' && *cur != ':' && *cur != ',')
            ++cur;
        while (*cur == ':' || *cur == ',')
            ++cur;
        if (*cur == '\0')
            return; // second field missing: keep the current nInter
        nInter = atoi(cur);
    }
};
27 
28  /* Save the last sequence of each length in each thread */
 // NOTE(review): original line 32 is absent from this generated listing; the
 // index at the end of the page attributes `Array< Seq * > aSeqs` to
 // hrf-sa-train.h:32, which is why two `public:` labels appear back to back.
29  class ThreadData
30  {
31  public:
33  public:
 // Destructor; declared here only, the definition is not part of this file.
34  ~ThreadData();
 // Declared here only. Takes a maximum length and a model pointer;
 // presumably builds the per-length sequences - confirm in the implementation.
35  void Create(int maxlen, Model *pModel);
36  };
37 
38  /*
39  * \class
40  * \brief augment SA training algorithms
 *
 * SAfunc derives from MLfunc and grants SAtrain friend access. It declares
 * the virtual SetParam / GetGradient / GetValue / GetExtraValues members;
 * GetValue() is defined inline and always returns 0.
 *
 * NOTE(review): this is a generated source listing. Lines that carried
 * documentation comments in the real header are absent here, so several
 * members used below (e.g. m_nMiniBatchSample, m_bPrintTrain, m_bSAMSSample)
 * have no visible declaration. The index at the end of the page lists them.
41  */
42  class SAfunc : public MLfunc
43  {
44  friend class SAtrain;
45  protected:
 // NOTE(review): original lines 46-48 are absent. The page index attributes
 // m_nMiniBatchSample (46), m_nMiniBatchTraining (47) and m_TrainSelect (48)
 // to them.
 // Per the page index: "cache all the h of training sequences."
49  CorpusCache m_TrainCache;
50 
 // NOTE(review): original line 51 is absent; the index attributes m_samplePi to it.
52 
53  private:
 // Sample length counters and the total sample count. Only the declarations
 // are visible here; how they are filled is in the implementation file.
54  Vec<double> m_vAllSampleLenCount;
55  Vec<double> m_vCurSampleLenCount;
56  int m_nTotalSample;
57 
 // Going by the names, empirical feature expectation and variance buffers
 // (see GetEmpiricalFeatExp / GetEmpiricalFeatVar) - TODO confirm.
58  Vec<double> m_vEmpFeatExp;
59  Vec<double> m_vEmpFeatVar;
60 
 // Buffers whose names match the vExp / vExp2 / vLen arguments of
 // GetEmpiricalExp and GetSampleExp - presumably their backing storage; verify.
61  Vec<double> m_vEmpExp;
62  Vec<double> m_vEmpExp2;
63  Vec<double> m_vSampleExp;
64  Vec<double> m_vSampleLen;
65 
 // Only compiled when _CD is NOT defined. Both arrays hold owning pointers:
 // the body at original lines 135-147 calls SAFE_DELETE on every element.
66 #ifndef _CD
67  Array<ThreadData*> m_threadData;
68  Array<Seq*> m_aSeqs;
69 #endif
70  //Array<trf::RandSeq<int>*> m_trainSelectPerLen; ///< save the index of training sequences of each length
71 
72 
 // Matrix counterparts of the vectors above (same name stems).
73  Mat<double> m_matEmpiricalExp;
74  Mat<double> m_matEmpiricalExp2;
75  Mat<double> m_matSampleExp;
76  Mat<double> m_matSampleLen;
77 
78  /* for empirical variance estimation :*/
 // Only compiled when _Var is defined. Note the access level switches to
 // public inside the #ifdef, so m_var_gap is public.
79 #ifdef _Var
80  Vec<double> m_vExpValue;
81  Vec<double> m_vExp2Value;
82  public:
 // Set to 1e-5 by both constructors.
83  double m_var_gap;
84 #endif
85  //Vec<double> m_vEmpiricalVar; ///< current empirical variance E[f^2]-E[f]^2
86 
87  public:
 // NOTE(review): original lines 88-94 are absent. The page index attributes
 // m_AISConfigForZ (88), m_AISConfigForP (89), m_nTrainHiddenSampleTimes (90),
 // m_nSampleHiddenSampleTimes (91), m_nCDSampleTimes (92), m_nSASampleTimes (93)
 // and m_bSAMSSample (94) to them.
95 
96  public:
 // NOTE(review): original lines 97-105 and 107-109 are absent. The page index
 // attributes the File members m_fdbg, m_fparm, m_fgrad, m_fvar, m_fexp,
 // m_fsamp, m_ftrain, m_feat_mean, m_feat_var (97-105) and the flags
 // m_bPrintTrain, m_bPrintValie, m_bPrintTest (107-109) to them.
106 
110 
111  public:
 // Default constructor: both mini-batch sizes start at 100, the three
 // print flags are enabled and SAMS sampling is disabled. Reset() is not
 // called, so no model or corpus is attached here.
112  SAfunc() :m_nMiniBatchSample(100), m_nMiniBatchTraining(100) {
113 #ifdef _Var
114  m_var_gap = 1e-5;
115 #endif
116  m_bPrintTrain = true;
117  m_bPrintValie = true;
118  m_bPrintTest = true;
119 
120  m_bSAMSSample = false;
121  };
 // Forwards all arguments to Reset(), then applies the same flag defaults
 // as the default constructor. pValid and pTest may be NULL (their defaults).
122  SAfunc(Model *pModel, CorpusBase *pTrain, CorpusBase *pValid = NULL, CorpusBase *pTest = NULL, int nMinibatch = 100 )
123  {
124  Reset(pModel, pTrain, pValid, pTest, nMinibatch);
125 #ifdef _Var
126  m_var_gap = 1e-5;
127 #endif
128  m_bPrintTrain = true;
129  m_bPrintValie = true;
130  m_bPrintTest = true;
131 
132  m_bSAMSSample = false;
133  }
 // NOTE(review): the declarator line for the body below (original line 134)
 // is absent from this listing. The body only releases m_aSeqs and
 // m_threadData, so it is presumably the destructor - confirm against the
 // real header.
135  {
136 
137 #ifndef _CD
138  for (int i = 0; i < m_aSeqs.GetNum(); i++)
139  SAFE_DELETE(m_aSeqs[i]);
140  for (int i = 0; i < m_threadData.GetNum(); i++) {
141  SAFE_DELETE(m_threadData[i]);
142  }
143 #endif
144 // for (int i = 0; i < m_trainSelectPerLen.GetNum(); i++) {
145 // SAFE_DELETE(m_trainSelectPerLen[i]);
146 // }
147  }
 // Declared only; called by the second constructor with the same argument list.
149  void Reset(Model *pModel, CorpusBase *pTrain, CorpusBase *pValid = NULL, CorpusBase *pTest = NULL, int nMinibatch = 100);
151  void PrintInfo();
 // Size queries; each forwards to the model held in m_pModel (declared
 // outside this listing), so they must not be called before a model is set.
 // get the ngram feature number
153  int GetNgramFeatNum() const { return m_pModel->m_pFeat->GetNum(); }
 // get the VH mat number
155  int GetVHmatSize() const { return m_pModel->m_m3dVH.GetSize(); }
 // get the CH mat number
157  int GetCHmatSize() const { return m_pModel->m_m3dCH.GetSize(); }
 // get the HH mat number
159  int GetHHmatSize() const { return m_pModel->m_m3dHH.GetSize(); }
161 // int GetBiasSize() const { return m_pModel->m_matBias.GetSize(); }
 // Returns m_pModel->GetParamNum().
163  int GetWeightNum() const { return m_pModel->GetParamNum(); }
 // get the zeta parameter number: the model's maximum length plus one.
165  int GetZetaNum() const { return m_pModel->GetMaxLen() + 1; }
 // Declared only. nLen defaults to -1; the meaning of -1 is defined in the
 // implementation file - TODO confirm.
167  void RandSeq(Seq &seq, int nLen = -1);
169  void GetParam(double *pdParams);
171  //void GetFeatEmpVar(CorpusBase *pCorpus, Vec<double> &vVar);
172 
174  void GetEmpiricalFeatExp(Vec<double> &vExp);
176  void GetEmpiricalFeatVar(Vec<double> &vVar);
177 
 // Two overloads; the first additionally takes an index array (aRandIdx).
179  int GetEmpiricalExp(VecShell<double> &vExp, VecShell<double> &vExp2, Array<int> &aRandIdx);
181  int GetEmpiricalExp(VecShell<double> &vExp, VecShell<double> &vExp2);
183  int GetSampleExp(VecShell<double> &vExp, VecShell<double> &vLen);
 // "Perfrom" (sic) is the declared spelling; callers depend on it.
185  void PerfromCD(VecShell<double> &vEmpExp, VecShell<double> &vSamExp, VecShell<double> &vEmpExp2, VecShell<double> &vLen);
187  void PerfromSA(VecShell<double> &vEmpExp, VecShell<double> &vSamExp, VecShell<double> &vEmpExp2, VecShell<double> &vLen);
189 // void PerfromSAMS(VecShell<double> &vEmpExp, VecShell<double> &vSamExp, VecShell<double> &vEmpExp2, VecShell<double> &vLen);
191  /* method = 0 : AIS, =1 : Chib */
 // nCalNum defaults to -1 and method to 0 (AIS).
192  double GetSampleLL(CorpusBase *pCorpus, int nCalNum = -1, int method = 0);
194  void IterEnd(double *pFinalParams);
196  void WriteModel(int nEpoch);
197 
 // Virtual interface; only GetValue() has a body here.
198  virtual void SetParam(double *pdParams);
199  virtual void GetGradient(double *pdGradient);
 // Always returns 0 (per the page index this slot is "calculate the function value f(x)").
200  virtual double GetValue() { return 0; }
201  virtual int GetExtraValues(int t, double *pdValues);
202  };
203 
204  /*
205  * \class
206  * \brief Learning rate
 *
 * NOTE(review): the class-head line (original line 208) is absent from this
 * generated listing; the constructor name below shows the class is
 * LearningRate.
207  */
209  {
210  public:
 // Three public coefficients. The default constructor sets beta = 1,
 // tc = 0 and t0 = 0; how Get() combines them is in the implementation file.
211  double beta;
212  double tc;
213  double t0;
214  public:
215  LearningRate() :beta(1), tc(0), t0(0) {}
 // Declared only. Takes a configuration string and an integer p_t0;
 // presumably re-initialises beta / tc / t0 from them - confirm.
216  void Reset(const char *pstr, int p_t0);
 // Declared only. Presumably returns the learning rate for iteration t - confirm.
218  double Get(int t);
219  };
220 
221  /*
222  * \class
223  * \brief SAtraining
 *
 * SAtrain derives from Solve and is constructed from an SAfunc pointer,
 * which it forwards to the Solve constructor.
 *
 * NOTE(review): this is a generated source listing; several members that the
 * constructor assigns are declared on lines absent here. The index at the
 * end of the page lists them.
224  */
225  class SAtrain : public Solve
226  {
227  protected:
 // NOTE(review): original lines 228-229 are absent; the page index attributes
 // m_gamma_lambda (228) and m_gamma_hidden (229) to them.
230  double m_gamma_zeta;
231 
232  public:
 // NOTE(review): original lines 234-236 and 238-239 are absent; the page index
 // attributes the LearningRate members m_gain_lambda, m_gain_hidden,
 // m_gain_zeta (234-236) and the flags m_bUpdate_lambda, m_bUpdate_zeta
 // (238-239) to them.
233 
237 
240 
 // Both gaps start at 0 in the constructor.
241  double m_dir_gap;
242  double m_zeta_gap;
243 
 // the momentum (starts at 0)
244  float m_fMomentum;
 // if >0, then calculate the average (starts at 0)
245  int m_nAvgBeg;
246 
 // the current epoch number (starts at 0)
247  double m_fEpochNum;
248 
 // NOTE(review): original lines 249-250 are absent; the page index attributes
 // m_nPrintPerIter (249) and m_aWriteAtIter (250) to them.
251 
 // Only compiled when _Var is defined.
 // NOTE(review): m_gamma_var is not assigned in the constructor below.
252 #ifdef _Var
253  //double m_var_threshold;
254  double m_gamma_var;
255  LearningRate m_gain_var;
256 #endif
257  public:
 // Passes pfunc (may be NULL, the default) to Solve and sets the defaults:
 // algorithm name "[SA]" (or "[CD]" when _CD is defined), all three gamma
 // values 1, lambda and zeta updates enabled, gaps / momentum / averaging
 // start / epoch counter 0, and m_nPrintPerIter 1.
258  SAtrain(SAfunc *pfunc = NULL) : Solve(pfunc)
259  {
260 #ifndef _CD
261  m_pAlgorithmName = "[SA]";
262 #else
263  m_pAlgorithmName = "[CD]";
264 #endif
265 
266  m_gamma_lambda = 1;
267  m_gamma_hidden = 1;
268  m_gamma_zeta = 1;
269 
270 
271  m_bUpdate_lambda = true;
272  m_bUpdate_zeta = true;
273 
274  m_dir_gap = 0;
275  m_zeta_gap = 0;
276 
277  m_fMomentum = 0;
278  m_nAvgBeg = 0;
279 
280  m_fEpochNum = 0;
281  m_nPrintPerIter = 1;
282 // #ifdef _Var
283 // m_var_threshold = 1e-4;
284 // #endif
285  }
 // The remaining members are declared only; definitions are outside this file.
 // Run and Update are virtual.
287  virtual bool Run(const double *pInitParams = NULL);
289  void UpdateGamma(int nIterNum);
291  void UpdateDir(double *pDir, double *pGradient, const double *pParam);
293  virtual void Update(double *pdParam, const double *pdDir, double dStep);
295  void PrintInfo();
 // Takes a buffer p of num values and a gap; presumably limits each value to
 // the gap and returns how many were cut - TODO confirm in the implementation.
297  int CutValue(double *p, int num, double gap);
298  };
299 
300 
301 }
double m_gamma_zeta
Definition: hrf-sa-train.h:230
#define SAFE_DELETE(p)
memory release
Definition: wb-vector.h:49
int nChain
chain number
Definition: hrf-sa-train.h:16
a dynamic string class
Definition: wb-string.h:53
trf::CorpusRandSelect m_TrainSelect
random select the sequence from corpus
Definition: hrf-sa-train.h:48
the base class of all the solve classes; it provides a gradient descent algorithm.
Definition: wb-solve.h:77
wb::Array< int > m_aWriteAtIter
output temp model at some iteration
Definition: hrf-sa-train.h:250
Array< Seq * > aSeqs
Definition: hrf-sa-train.h:32
bool m_bPrintTrain
output the LL on training set
Definition: hrf-sa-train.h:107
bool m_bSAMSSample
if using the sams sampling method
Definition: hrf-sa-train.h:94
LearningRate m_gain_zeta
Definition: hrf-sa-train.h:236
bool m_bPrintValie
output the LL on valid set
Definition: hrf-sa-train.h:108
File m_fdbg
output the sample pi/zeta information
Definition: hrf-sa-train.h:97
virtual double GetValue()
calculate the function value f(x)
Definition: hrf-sa-train.h:200
Vec< Prob > m_samplePi
the length distribution used for sample
Definition: hrf-sa-train.h:51
int m_nMiniBatchTraining
mini-batch for training set
Definition: hrf-sa-train.h:47
hidden-random-field model
Definition: hrf-model.h:98
SAtrain(SAfunc *pfunc=NULL)
Definition: hrf-sa-train.h:258
int nInter
intermediate distribution number
Definition: hrf-sa-train.h:17
AISConfig m_AISConfigForZ
the AIS configuration for normalization
Definition: hrf-sa-train.h:88
int m_nAvgBeg
if >0, then calculate the average
Definition: hrf-sa-train.h:245
LearningRate m_gain_lambda
Definition: hrf-sa-train.h:234
int GetWeightNum() const
get the total number of model parameters (returns m_pModel->GetParamNum())
Definition: hrf-sa-train.h:163
void Parse(const char *str)
Definition: hrf-sa-train.h:19
bool m_bUpdate_zeta
Definition: hrf-sa-train.h:239
double m_zeta_gap
Definition: hrf-sa-train.h:242
double m_gamma_lambda
Definition: hrf-sa-train.h:228
File m_fgrad
output the gradient of each iteration
Definition: hrf-sa-train.h:99
int GetZetaNum() const
get the zeta parameter number
Definition: hrf-sa-train.h:165
int m_nTrainHiddenSampleTimes
the sample times for training sequence
Definition: hrf-sa-train.h:90
File m_fexp
output the expectation of each iteration
Definition: hrf-sa-train.h:101
bool m_bPrintTest
output the LL on test set
Definition: hrf-sa-train.h:109
int m_nCDSampleTimes
the CD-n: the sample number.
Definition: hrf-sa-train.h:92
bool m_bUpdate_lambda
Definition: hrf-sa-train.h:238
SAfunc(Model *pModel, CorpusBase *pTrain, CorpusBase *pValid=NULL, CorpusBase *pTest=NULL, int nMinibatch=100)
Definition: hrf-sa-train.h:122
file class.
Definition: wb-file.h:94
int GetVHmatSize() const
get the VH mat number
Definition: hrf-sa-train.h:155
File m_fvar
output the variance at each iteration
Definition: hrf-sa-train.h:100
AISConfig m_AISConfigForP
the AIS configuration for calculating the LL.
Definition: hrf-sa-train.h:89
File m_feat_mean
output the empirical mean
Definition: hrf-sa-train.h:104
int GetCHmatSize() const
get the CH mat number
Definition: hrf-sa-train.h:157
double m_gamma_hidden
Definition: hrf-sa-train.h:229
float m_fMomentum
the momentum
Definition: hrf-sa-train.h:244
pFunc Run(bInit)
int m_nSampleHiddenSampleTimes
the sample times for the hidden of samples
Definition: hrf-sa-train.h:91
double m_dir_gap
Definition: hrf-sa-train.h:241
int GetHHmatSize() const
get the HH mat number
Definition: hrf-sa-train.h:159
File m_feat_var
output the empirical variance
Definition: hrf-sa-train.h:105
File m_fsamp
output all the samples
Definition: hrf-sa-train.h:102
int m_nPrintPerIter
output the LL per iteration; if == -1, then disabled
Definition: hrf-sa-train.h:249
int GetNum() const
Get Array number.
Definition: wb-vector.h:240
LearningRate m_gain_hidden
Definition: hrf-sa-train.h:235
File m_fparm
output the parameters of each iteration
Definition: hrf-sa-train.h:98
double m_fEpochNum
the current epoch number
Definition: hrf-sa-train.h:247
char * GetBuffer() const
get buffer
Definition: wb-string.h:74
CorpusCache m_TrainCache
cache all the h of training sequences.
Definition: hrf-sa-train.h:49
int m_nMiniBatchSample
mini-batch for samples
Definition: hrf-sa-train.h:46
File m_ftrain
output all the training sequences
Definition: hrf-sa-train.h:103
int m_nSASampleTimes
the SA sample times
Definition: hrf-sa-train.h:93
int GetNgramFeatNum() const
get the ngram feature number
Definition: hrf-sa-train.h:153
Dynamic array.
Definition: wb-vector.h:205
AISConfig(int p_nChain=16, int p_nInter=1000)
Definition: hrf-sa-train.h:18