TRF Language Model
trf-sa-train.h
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2014-2015 Tsinghua University
// Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
//
// All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
// license declaration. Different coding languages may use different comment styles.


#pragma once
#include "trf-ml-train.h"
#include <omp.h>


namespace trf
{
    class SAfunc;
    class SAtrain;

    /* the AIS configurations */
    class AISConfig
    {
    public:
        int nChain; ///< chain number
        int nInter; ///< intermediate distribution number
        AISConfig(int p_nChain = 16, int p_nInter = 1000) : nChain(p_nChain), nInter(p_nInter) {}
        void Parse(const char* str) {
            String tempStr(str);
            char *p = strtok(tempStr.GetBuffer(), ":,");
            nChain = atoi(p);
            p = strtok(NULL, ":,");
            nInter = atoi(p);
        }
    };

    /* Save the last sequence of each length in each thread */
    class ThreadData
    {
    public:
        Array<Seq*> aSeqs;
    public:
        ~ThreadData();
        void Create(int maxlen, Model *pModel);
    };


    /*
     * \class
     * \brief augmented SA training algorithms
     */
    class SAfunc : public MLfunc
    {
        friend class SAtrain;
    protected:
        int m_nMiniBatchSample; ///< mini-batch for samples

        Vec<Prob> m_samplePi; ///< the length distribution used for sampling

        Vec<double> m_vAllSampleLenCount; ///< the count of each length in all samples
        Vec<double> m_vCurSampleLenCount; ///< the count of each length in the samples of the current iteration
        int m_nTotalSample; ///< the total sample number

//      Vec<double> m_vEmpiricalExp;  ///< the empirical expectation
//      Vec<double> m_vEmpiricalExp2; ///< the empirical expectation E[f^2]
        Vec<double> m_vSampleExp;  ///< the sample expectation
        Vec<double> m_vSampleExp2; ///< the sample expectation^2
        Vec<double> m_vSampleLen;  ///< the sample length expectation

        //Array<ThreadData*> m_threadData; ///< save the last sequence of each thread
        Array<Seq*> m_threadSeq; ///< save the last sequence of each thread

//      Mat<double> m_matEmpiricalExp;  ///< the empirical expectation of each thread
//      Mat<double> m_matEmpiricalExp2; ///< empirical E[f^2] of each thread
        Mat<double> m_matSampleExp;  ///< the sample expectation of each thread
        Mat<double> m_matSampleExp2; ///< the sample expectation^2 of each thread
        Mat<double> m_matSampleLen;  ///< the length count of the samples of each thread

        Vec<double> m_vEmpiricalVar; ///< empirical variance

    public:
        double m_fRegL2;  ///< L2 regularization
        double m_var_gap; ///< a variance gap used in gradient scaling

        AISConfig m_AISConfigForZ; ///< the AIS configuration for normalization
        AISConfig m_AISConfigForP; ///< the AIS configuration for calculating the LL
        int m_nCDSampleTimes; ///< the CD-n: the sample number
        int m_nSASampleTimes; ///< the SA sample times

    public:
        File m_fdbg;   ///< output the sample pi/zeta information
        File m_fparm;  ///< output the parameters of each iteration
        File m_fgrad;  ///< output the gradient of each iteration
        File m_fmean;  ///< output the p[f] on the training set
        File m_fvar;   ///< output the variance at each iteration
        File m_fexp;   ///< output the expectation of each iteration
        File m_fsamp;  ///< output all the samples
        File m_ftrain; ///< output all the training sequences
        File m_ftrainLL;  ///< output the log-likelihood on the training set
        File m_fvallidLL; ///< output the log-likelihood on the valid set
        File m_ftestLL;   ///< output the log-likelihood on the test set


    public:
        SAfunc() : m_nMiniBatchSample(100) {
            m_var_gap = 1e-15;
            m_fRegL2 = 0;
        };
        SAfunc(Model *pModel, CorpusBase *pTrain, CorpusBase *pValid = NULL, CorpusBase *pTest = NULL, int nMinibatch = 100)
        {
            m_var_gap = 1e-15;
            m_fRegL2 = 0;
            Reset(pModel, pTrain, pValid, pTest, nMinibatch);
        }
        ~SAfunc()
        {
#ifndef _CD
            for (int i = 0; i < m_threadSeq.GetNum(); i++) {
                SAFE_DELETE(m_threadSeq[i]);
            }
#endif
        }
        virtual void Reset(Model *pModel, CorpusBase *pTrain, CorpusBase *pValid = NULL, CorpusBase *pTest = NULL, int nMinibatch = 100);
        void PrintInfo();
        /// get the ngram feature number
        int GetFeatNum() const { return m_pModel->GetParamNum(); }
        /// get the zeta parameter number
        int GetZetaNum() const { return m_pModel->GetMaxLen() + 1; }
        void RandSeq(Seq &seq, int nLen = -1);
        void GetParam(double *pdParams);
        void GetEmpVar(CorpusBase *pCorpus, Vec<double> &vVar);

        virtual void GetSampleExp(VecShell<double> &vExp, VecShell<double> &vExp2, VecShell<double> &vLen);

        void IterEnd(double *pFinalParams);
        void WriteModel(int nEpoch);

        virtual void SetParam(double *pdParams);
        virtual void GetGradient(double *pdGradient);
        virtual double GetValue() { return 0; } ///< calculate the function value f(x)
        virtual int GetExtraValues(int t, double *pdValues);
    };

    /*
     * \class
     * \brief Learning rate
     */
    class LearningRate
    {
    public:
        double beta;
        double tc;
        double t0;
    public:
        LearningRate() : beta(1), tc(0), t0(0) {}
        void Reset(const char *pstr, int p_t0);
        double Get(int t);
    };
183 
184  /*
185  * \class
186  * \brief SAtraining
187  */
188  class SAtrain : public Solve
189  {
190  protected:
192  double m_gamma_zeta;
193 
194  public:
197  //double m_zeta_upgap; ///< the gap for zeta update
198 
201  double m_dir_gap;
202 
203  int m_nAvgBeg;
204  float m_fEpochNun;
207 
208 #ifdef _Adam
209  double adam_beta1;
210  double adam_beta2;
211  double adam_sigma;
212  double adam_alpha;
213  Vec<double> adam_m;
214  Vec<double> adam_v;
215 #endif
216 
217 #ifdef _Hession
218  Vec<double> m_avgHes;
219 #endif
220 
    public:
        SAtrain(SAfunc *pfunc = NULL) : Solve(pfunc)
        {
            m_pAlgorithmName = "[SAMS]";

            m_gamma_lambda = 1;
            m_gamma_zeta = 1;
            //m_zeta_upgap = 10;


            m_bUpdate_lambda = true;
            m_bUpdate_zeta = true;
            m_dir_gap = 1.0;

            m_nAvgBeg = 0;

            m_fEpochNun = 0;
            m_nPrintPerIter = 1;
#ifdef _Hession
            m_avgHes.Reset(pfunc->GetFeatNum());
            m_avgHes.Fill(0);
#endif
#ifdef _Adam
            adam_beta1 = 0.9;
            adam_beta2 = 0.999;
            adam_alpha = 1e-3;
            adam_sigma = 1e-8;
            adam_m.Reset(pfunc->GetFeatNum());
            adam_v.Reset(pfunc->GetFeatNum());
            adam_m.Fill(0);
            adam_v.Fill(0);
#endif
        }
        virtual bool Run(const double *pInitParams = NULL);
        void UpdateGamma(int nIterNum);
        void UpdateDir(double *pDir, double *pGradient, const double *pParam);
        virtual void Update(double *pdParam, const double *pdDir, double dStep);
        void PrintInfo();
    };

}
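A note on usage, not part of the header above: AISConfig::Parse reads a "nChain:nInter" string (colon or comma separated), so the AIS settings can come either from the constructor defaults or from a command-line style option. A minimal sketch with illustrative values:

// Illustrative only; the string "10:2000" is an example value, not from the source.
trf::AISConfig ais;    // defaults: nChain = 16, nInter = 1000
ais.Parse("10:2000");  // 10 AIS chains, 2000 intermediate distributions
// afterwards ais.nChain == 10 and ais.nInter == 2000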
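LearningRate only declares Get(t) here; the body lives in the corresponding .cpp file. The fields beta, tc and t0 suggest the usual stochastic-approximation gain sequence, roughly flat over the first t0 iterations and then decaying polynomially. The following is a sketch under that assumption, not the toolkit's actual implementation:

#include <cmath>

// Sketch of a typical SA gain sequence consistent with the fields beta, tc, t0.
// This is an assumption for illustration; the real LearningRate::Get() may differ.
static double SampleGain(int t, double beta, double tc, double t0)
{
    double tt = (t <= t0) ? t0 : (double)t;  // hold the gain flat until iteration t0
    return 1.0 / (tc + std::pow(tt, beta));  // then decay as O(1 / t^beta)
}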
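When _Adam is defined, SAtrain carries the standard Adam state (first/second moment vectors plus beta1, beta2, alpha, sigma). The Update() body is not shown in this header, but the conventional Adam step these fields support looks like the sketch below (gradient-descent form; an illustration, not SAtrain::Update itself):

#include <cmath>
#include <vector>

// Conventional Adam update; t is the 1-based iteration count, g the gradient.
static void AdamStep(std::vector<double> &param,
                     std::vector<double> &m, std::vector<double> &v,
                     const std::vector<double> &g, int t,
                     double alpha = 1e-3, double beta1 = 0.9,
                     double beta2 = 0.999, double sigma = 1e-8)
{
    for (size_t i = 0; i < param.size(); i++) {
        m[i] = beta1 * m[i] + (1 - beta1) * g[i];             // first-moment estimate
        v[i] = beta2 * v[i] + (1 - beta2) * g[i] * g[i];      // second-moment estimate
        double mhat = m[i] / (1 - std::pow(beta1, t));        // bias correction
        double vhat = v[i] / (1 - std::pow(beta2, t));
        param[i] -= alpha * mhat / (std::sqrt(vhat) + sigma); // descent step
    }
}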
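Putting the pieces together, a training run builds an SAfunc from a model and corpora and hands it to SAtrain. The driver below is a minimal sketch: it assumes Model and CorpusBase objects built elsewhere (their headers are not part of this file) and assumes that passing NULL to Run() starts from the parameters already stored in the model.

#include "trf-sa-train.h"

// Minimal driver sketch; pModel and the corpora are assumed to be constructed elsewhere.
static void TrainWithSA(trf::Model *pModel, trf::CorpusBase *pTrain,
                        trf::CorpusBase *pValid, trf::CorpusBase *pTest)
{
    trf::SAfunc func(pModel, pTrain, pValid, pTest, 100); // minibatch of 100 sampled sequences
    trf::SAtrain solver(&func);
    solver.Run(NULL); // NULL: start from the model's current parameters (assumption)
}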