TRF Language Model
trf-sa-train.cpp
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2014-2015 Tsinghua University
// Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
//
// All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
// license declaration. Different coding languages may use different comment styles.


#include "trf-sa-train.h"
#include "wb-log.h"

namespace trf
{
    ThreadData::~ThreadData()
    {
        for (int i = 0; i < aSeqs.GetNum(); i++) {
            SAFE_DELETE(aSeqs[i]);
        }
    }
    void ThreadData::Create(int maxlen, Model *pModel)
    {
        aSeqs.SetNum(maxlen + 1);
        aSeqs.Fill(NULL);
        for (int i = 1; i < aSeqs.GetNum(); i++) {
            aSeqs[i] = new Seq(i);
            aSeqs[i]->Random(pModel->m_pVocab);
        }
    }

    void SAfunc::Reset(Model *pModel, CorpusBase *pTrain, CorpusBase *pValid /* = NULL */, CorpusBase *pTest /* = NULL */, int nMinibatch /* = 100 */)
    {
        MLfunc::Reset(pModel, pTrain, pValid, pTest);
        GetEmpVar(pTrain, m_vEmpiricalVar);

        m_nMiniBatchSample = nMinibatch;

        /*
        sampling pi
        */
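        // Note: m_samplePi flattens the empirical length distribution m_trainPi:
        // every length below the mode is raised to the modal probability and all
        // lengths are floored at 1e-5, so no length is starved during sampling.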
        lout << "Smoothing the pi" << endl;
        double dMax = 0;
        int iMax = 0;
        for (int i = 1; i < m_trainPi.GetSize(); i++) {
            if (m_trainPi[i] > dMax) {
                dMax = m_trainPi[i];
                iMax = i;
            }
        }
        m_samplePi.Copy(m_trainPi);
        for (int i = 1; i < iMax; i++) {
            m_samplePi[i] = dMax;
        }
        for (int i = 1; i < m_samplePi.GetSize(); i++) {
            m_samplePi[i] = max((double)m_samplePi[i], 1e-5);
        }
        LineNormalize(m_samplePi.GetBuf() + 1, m_samplePi.GetSize() - 1);

        lout << "sample-pi = [ "; lout.output(m_samplePi.GetBuf() + 1, m_samplePi.GetSize() - 1); lout << "]" << endl;
        m_pModel->SetPi(m_samplePi.GetBuf());

        /* save the sample counts */
        m_vAllSampleLenCount.Reset(m_pModel->GetMaxLen() + 1);
        m_vCurSampleLenCount.Reset(m_pModel->GetMaxLen() + 1);
        m_vAllSampleLenCount.Fill(0);
        m_nTotalSample = 0;

        /* For SA estimation there are two sets of parameters,
           i.e. the feature weights \lambda and the normalization constants \zeta.
        */
        m_nParamNum = m_pModel->GetParamNum() + m_pModel->GetMaxLen() + 1;

        m_nCDSampleTimes = 1;
        m_nSASampleTimes = 1;

    }
    void SAfunc::PrintInfo()
    {
        lout << "[SAfunc] *** Info: *** " << endl;
        lout << " "; lout_variable(m_nMiniBatchSample);
        lout << " "; lout_variable(m_var_gap);
        lout << " "; lout_variable(m_fRegL2);
        lout << "[SAfunc] *** [End] ***" << endl;
    }
    void SAfunc::RandSeq(Seq &seq, int nLen /* = -1 */)
    {
        if (nLen == -1) {
            nLen = rand() % m_pModel->GetMaxLen() + 1;
        }

        seq.Reset(nLen);
        seq.Random(m_pModel->m_pVocab);
    }
    void SAfunc::SetParam(double *pdParams)
    {
        if (pdParams == NULL)
            return;

        m_value.Reset(m_pModel->GetParamNum());
        for (int i = 0; i < m_value.GetSize(); i++)
            m_value[i] = (PValue)pdParams[i];
        m_pModel->SetParam(m_value.GetBuf());
        m_pModel->ExactNormalize(1); // only calculate Z_1

        /* set zeta */
        m_pModel->SetZeta(pdParams + m_pModel->GetParamNum());

        if (m_fparm.Good()) {
            m_fparm.PrintArray("%f ", pdParams, m_nParamNum);
        }
    }
    void SAfunc::GetParam(double *pdParams)
    {
        if (pdParams == NULL)
            return;

        /* get lambda */
        m_value.Reset(m_pModel->GetParamNum());
        m_pModel->GetParam(m_value.GetBuf());
        for (int i = 0; i < m_value.GetSize(); i++)
            pdParams[i] = m_value[i];
        /* get zeta */
        pdParams += m_pModel->GetParamNum();
        for (int i = 0; i <= m_pModel->GetMaxLen(); i++) {
            pdParams[i] = m_pModel->m_zeta[i];
        }
    }

    int qsort_compare_double(const void * a, const void * b)
    {
        if (*(double*)a < *(double*)b) return -1;
        if (*(double*)a > *(double*)b) return 1;
        return 0; // equal; also guarantees a return on every path
    }

    void SAfunc::GetEmpVar(CorpusBase *pCorpus, Vec<double> &vVar)
    {
        int nThread = omp_get_max_threads();

        // the true length distribution
        Prob *pi = m_trainPi.GetBuf();

        vVar.Fill(0);

        Vec<double> vExpf2(m_pModel->GetParamNum());
        Vec<double> vExp_l(m_pModel->GetParamNum());

        Mat<double> matExpf2(nThread, vExpf2.GetSize());
        Mat<double> matExp_l(nThread, vExp_l.GetSize());

        vExpf2.Fill(0);
        vExp_l.Fill(0);
        matExpf2.Fill(0);
        matExp_l.Fill(0);

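        // Empirical variance of each feature f_i:
        //   Var[f_i] = E[f_i^2] - sum_l pi_l * (E_l[f_i])^2,
        // estimated below in two passes: first E[f^2] over the whole corpus,
        // then the per-length means E_l[f].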
        Array<VocabID> aSeq;
        lout.Progress(0, true, pCorpus->GetNum() - 1, "[SAfunc] E[f^2]:");
#pragma omp parallel for firstprivate(aSeq)
        for (int l = 0; l < pCorpus->GetNum(); l++) {

            double *pExpf2 = matExpf2[omp_get_thread_num()].GetBuf();

            pCorpus->GetSeq(l, aSeq);
            Seq seq;
            seq.Set(aSeq, m_pModel->m_pVocab);

            int nLen = min(m_pModel->GetMaxLen(), seq.GetLen());

            LHash<int, int> aFeatNum;
            bool bFound;
            Array<int> afeat;
            m_pModel->m_pFeat->Find(afeat, seq);
            for (int i = 0; i < afeat.GetNum(); i++) {
                int *p = aFeatNum.Insert(afeat[i], bFound);
                if (!bFound) *p = 0;
                (*p) += 1;
            }
            LHashIter<int, int> iter(&aFeatNum);
            int *pCount;
            int nFeat;
            while ((pCount = iter.Next(nFeat)) != NULL) {
                pExpf2[nFeat] += pow((double)(*pCount), 2);
            }
#pragma omp critical
            {
                lout.Progress();
            }
        }

        vExpf2.Fill(0);
        for (int t = 0; t < nThread; t++) {
            vExpf2 += matExpf2[t];
        }
        vExpf2 /= pCorpus->GetNum();


        //lout_variable(aExpFeatSqu[38272]);

        /* calculate the length-conditional expectations E_l[f] */
        lout.Progress(0, true, m_pModel->GetMaxLen(), "[SAfunc] E_l[f]:");
        for (int nLen = 1; nLen <= m_pModel->GetMaxLen(); nLen++)
        {
            matExp_l.Fill(0);

            // collect the indices of the sequences of length nLen
            Array<int> aSeqId;
            for (int i = 0; i < pCorpus->GetNum(); i++) {
                pCorpus->GetSeq(i, aSeq);
                int nSeqLen = aSeq.GetNum();
                if (nLen == m_pModel->GetMaxLen()) {
                    if (nSeqLen < nLen)
                        continue;
                }
                else {
                    if (nSeqLen != nLen)
                        continue;
                }
                aSeqId.Add(i);
            }

#pragma omp parallel for firstprivate(aSeq)
            for (int k = 0; k < aSeqId.GetNum(); k++)
            {
                pCorpus->GetSeq(aSeqId[k], aSeq);

                Seq seq;
                seq.Set(aSeq, m_pModel->m_pVocab);
                m_pModel->FeatCount(seq, matExp_l[omp_get_thread_num()].GetBuf());
            }

            if (aSeqId.GetNum() > 0) {
                vExp_l.Fill(0);
                for (int t = 0; t < nThread; t++) {
                    vExp_l += matExp_l[t];
                }
                vExp_l /= aSeqId.GetNum();
            }
            else {
                vExp_l.Fill(0);
            }

            for (int i = 0; i < m_pModel->GetParamNum(); i++)
                vExpf2[i] -= pi[nLen] * pow(vExp_l[i], 2);

            lout.Progress(nLen);
        }

        int nZero = 0;
        int nDownGap = 0;
        double dMinVarOverZero = 100;
        for (int i = 0; i < m_nParamNum; i++) {
            if (vExpf2[i] == 0)
                nZero++;
            else
                dMinVarOverZero = min(vExpf2[i], dMinVarOverZero);

            if (vExpf2[i] < m_var_gap) {
                nDownGap++;
                vExpf2[i] = m_var_gap;
            }
        }
        if (nZero > 0) {
            lout_warning("[EmpiricalVar] Exist zero expectation (zero-num=" << nZero << ")");
        }
        lout << "[EmpiricalVar] the number of ( var < gap=" << m_var_gap << " ) is " << nDownGap << endl;
        lout << "[EmpiricalVar] min variance value (over 0) is " << dMinVarOverZero << endl;


        vVar.Copy(vExpf2);

        // Write
        if (m_fmean.Good()) {
            lout << "Write Empirical Mean ..." << endl;
            Vec<PValue> aLogExp(m_vEmpiricalExp.GetSize());
            for (int i = 0; i < aLogExp.GetSize(); i++) aLogExp[i] = log(m_vEmpiricalExp[i]);
            m_pModel->m_pFeat->WriteT(m_fmean, aLogExp.GetBuf());
            //m_fmean.PrintArray("%f\n", m_vEmpiricalExp.GetBuf(), m_vEmpiricalExp.GetSize());
        }
        if (m_fvar.Good()) {
            lout << "Write Empirical Var ..." << endl;
            Vec<PValue> aLogVar(vVar.GetSize());
            for (int i = 0; i < vVar.GetSize(); i++) aLogVar[i] = log(vVar[i]);
            m_pModel->m_pFeat->WriteT(m_fvar, aLogVar.GetBuf());
            //m_fvar.PrintArray("%f\n", vVar.GetBuf(), vVar.GetSize());
        }
    }

    void SAfunc::GetSampleExp(VecShell<double> &vExp, VecShell<double> &vExp2, VecShell<double> &vLen)
    {
        int nThread = omp_get_max_threads();
        m_matSampleExp.Reset(nThread, m_pModel->GetParamNum());
        m_matSampleExp2.Reset(nThread, m_pModel->GetParamNum());
        m_matSampleLen.Reset(nThread, m_pModel->GetMaxLen() + 1);
        //Vec<int> vNum(nThread); // record the sample number of each thread

        m_matSampleExp.Fill(0);
        m_matSampleExp2.Fill(0);
        m_matSampleLen.Fill(0);
        //vNum.Fill(0);


        // init the sequence
        if (m_threadSeq.GetNum() != nThread) {
            for (int i = 0; i < nThread; i++) {
                m_threadSeq[i] = new Seq;
                RandSeq(*m_threadSeq[i]);
            }
        }

        /* sampling */
        //lout.Progress(0, true, m_nMiniBatchSample-1, "[SA] sample:");
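        // Each thread advances its own persistent chain (Model::Sample) and
        // accumulates feature counts weighted by m_trainPi[l] / m_pi[l], so the
        // averages below estimate expectations under the training length
        // distribution although sampling uses the smoothed m_samplePi.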
#pragma omp parallel for
        for (int sample = 0; sample < m_nMiniBatchSample; sample++)
        {
            int tid = omp_get_thread_num();
            Vec<double> aCurCount(m_pModel->GetParamNum());
            m_pModel->Sample(*m_threadSeq[tid]);

            int nLen = min(m_pModel->GetMaxLen(), m_threadSeq[tid]->GetLen());

            m_pModel->FeatCount(*m_threadSeq[tid], aCurCount.GetBuf(), m_trainPi[nLen] / m_pModel->m_pi[nLen]);
            //m_pModel->FeatCount(*m_threadSeq[tid], m_matSampleExp[tid].GetBuf(), m_trainPi[nLen] / m_pModel->m_pi[nLen]);
            for (int i = 0; i < aCurCount.GetSize(); i++) {
                m_matSampleExp[tid][i] += aCurCount[i];
                m_matSampleExp2[tid][i] += pow(aCurCount[i], 2);
            }
            m_matSampleLen[tid][nLen]++;

#pragma omp critical
            {
                if (m_fsamp.Good()) {
                    m_threadSeq[tid]->Print(m_fsamp);
                }
                //lout.Progress();
            }

        }
        lout << " len-jump acc-rate=";
        lout_variable_rate(m_pModel->m_nLenJumpAccTimes, m_pModel->m_nLenJumpTotalTime);
        m_pModel->m_nLenJumpAccTimes = 0;
        m_pModel->m_nLenJumpTotalTime = 0;
        lout << " class-propose acc-rate=";
        lout_variable_rate(m_pModel->m_nSampleHAccTimes, m_pModel->m_nSampleHTotalTimes);
        m_pModel->m_nSampleHAccTimes = 0;
        m_pModel->m_nSampleHTotalTimes = 0;
        lout << endl;



        // summarization
        vExp.Fill(0);
        vExp2.Fill(0);
        vLen.Fill(0);
        for (int t = 0; t < nThread; t++) {
            vExp += m_matSampleExp[t];
            vExp2 += m_matSampleExp2[t];
            vLen += m_matSampleLen[t];
        }
        m_vAllSampleLenCount += vLen;
        m_vCurSampleLenCount.Copy(vLen);
        m_nTotalSample += m_nMiniBatchSample;

        vExp /= m_nMiniBatchSample;
        vExp2 /= m_nMiniBatchSample;
        vLen /= m_nMiniBatchSample;
    }

    void SAfunc::IterEnd(double *pFinalParams)
    {
        SetParam(pFinalParams);
        // set the pi as the len-prob in the training set.
        m_pModel->SetPi(m_trainPi.GetBuf());
    }
    void SAfunc::WriteModel(int nEpoch)
    {
        String strTempModel;
        String strName = String(m_pathOutputModel).FileName();
#ifdef __linux
        strTempModel.Format("%s.n%d.model", strName.GetBuffer(), nEpoch);
#else
        strTempModel.Format("%s.n%d.model", strName.GetBuffer(), nEpoch);
#endif
        // set the pi as the pi of the training set
        m_pModel->SetPi(m_trainPi.GetBuf());
        m_pModel->WriteT(strTempModel);
        m_pModel->SetPi(m_samplePi.GetBuf());
    }
    void SAfunc::GetGradient(double *pdGradient)
    {
        int nWeightNum = m_pModel->GetParamNum();
        m_vSampleExp.Reset(nWeightNum);
        m_vSampleExp2.Reset(nWeightNum);
        m_vSampleLen.Reset(m_pModel->GetMaxLen() + 1);


        /* get the sample expectation */
        GetSampleExp(m_vSampleExp, m_vSampleExp2, m_vSampleLen);

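        // Gradient w.r.t. lambda: empirical expectation minus model (sample)
        // expectation, with L2 regularization. In the default branch each
        // component is additionally divided by its empirical variance, which
        // acts as a diagonal preconditioner (a per-feature learning rate).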
#if defined(_Adam)
        for (int i = 0; i < nWeightNum; i++) {
            pdGradient[i] = m_vEmpiricalExp[i] - m_vSampleExp[i]
                - m_fRegL2 * m_pModel->m_value[i]; // the L2 regularization
        }

#elif defined(_Hession)
        for (int i = 0; i < nWeightNum; i++) {
            pdGradient[i] = m_vEmpiricalExp[i] - m_vSampleExp[i]
                - m_fRegL2 * m_pModel->m_value[i]; // the L2 regularization
        }
#else
        /* Calculate the gradient */
        for (int i = 0; i < nWeightNum; i++) {
            pdGradient[i] = (
                m_vEmpiricalExp[i] - m_vSampleExp[i]
                - m_fRegL2 * m_pModel->m_value[i] // the L2 regularization
                ) / (m_vEmpiricalVar[i] + m_fRegL2); // rescaled by variance
        }
#endif



        /*
        Zeta update
        */
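        // Gradient w.r.t. zeta_l: the relative frequency of sampled length l
        // divided by its design probability pi_l. When the zeta estimates are
        // exact this ratio is the same for every l, so the update drives the
        // sampled length distribution towards pi.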
        for (int l = 0; l <= m_pModel->GetMaxLen(); l++) {
            if (m_pModel->m_pi[l] > 0) {
                pdGradient[nWeightNum + l] = m_vSampleLen[l] / m_pModel->m_pi[l];
            }
            else {
                pdGradient[nWeightNum + l] = 0;
            }
        }


        if (m_fgrad.Good()) {
            m_fgrad.PrintArray("%f ", pdGradient, m_nParamNum);
            m_fgrad.Print("\n");
        }
        if (m_fexp.Good()) {
            m_fexp.PrintArray("%f ", m_vSampleExp.GetBuf(), m_vSampleExp.GetSize());
            m_fexp.Print("\n");
        }



    }
    int SAfunc::GetExtraValues(int t, double *pdValues)
    {
        int nValue = 0;

        // set the training pi
        m_pModel->SetPi(m_trainPi.GetBuf());

        Vec<Prob> samsZeta(m_pModel->m_zeta.GetSize());
        Vec<Prob> trueZeta(m_pModel->m_zeta.GetSize());
        //Vec<double> trueLogZ(m_pModel->m_logz.GetSize());
        samsZeta.Fill(0);
        trueZeta.Fill(0);
        samsZeta = m_pModel->m_zeta;

        /* calculate the p(v) */
        Vec<double> vLL;
        if (m_pCorpusTrain) {
            pdValues[nValue++] = -GetLL(m_pCorpusTrain, -1, &vLL);
            if (m_ftrainLL.Good()) {
                m_ftrainLL.Reopen("wt");
                m_ftrainLL.PrintArray("%f\n", vLL.GetBuf(), vLL.GetSize());
            }
        }
        if (m_pCorpusValid) {
            pdValues[nValue++] = -GetLL(m_pCorpusValid, -1, &vLL);
            if (m_fvallidLL.Good()) {
                m_fvallidLL.Reopen("wt");
                m_fvallidLL.PrintArray("%f\n", vLL.GetBuf(), vLL.GetSize());
            }
        }
        if (m_pCorpusTest) {
            pdValues[nValue++] = -GetLL(m_pCorpusTest, -1, &vLL);
            if (m_ftestLL.Good()) {
                m_ftestLL.Reopen("wt");
                m_ftestLL.PrintArray("%f\n", vLL.GetBuf(), vLL.GetSize());
            }
        }

        /* true Z_L to get the LL */
        if (m_pModel->m_pVocab->GetSize() < 100 && m_pModel->GetMaxOrder() < 4) {

            m_pModel->ExactNormalize(); // normalization
            trueZeta.Copy(m_pModel->m_zeta);
            if (m_pCorpusTrain) pdValues[nValue++] = -GetLL(m_pCorpusTrain);
            if (m_pCorpusValid) pdValues[nValue++] = -GetLL(m_pCorpusValid);
            if (m_pCorpusTest) pdValues[nValue++] = -GetLL(m_pCorpusTest);

            m_pModel->SetZeta(samsZeta.GetBuf());
        }


        /* output debug */
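        // The debug file gets one row per call: the sampled length frequencies
        // (current minibatch, then accumulated), the sampling pi, and the exact
        // vs. estimated zeta.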
        if (!m_fdbg.Good()) {
            m_fdbg.Open("SAfunc.dbg", "wt");
        }
        m_vAllSampleLenCount *= 1.0 / m_nTotalSample;
        m_vCurSampleLenCount *= 1.0 / m_nMiniBatchSample;
        m_fdbg.PrintArray("%f ", m_vCurSampleLenCount.GetBuf() + 1, m_vCurSampleLenCount.GetSize() - 1);
        m_fdbg.PrintArray("%f ", m_vAllSampleLenCount.GetBuf() + 1, m_vAllSampleLenCount.GetSize() - 1);
        m_fdbg.PrintArray("%f ", m_samplePi.GetBuf() + 1, m_samplePi.GetSize() - 1);
        m_fdbg.PrintArray("%f ", trueZeta.GetBuf() + 1, trueZeta.GetSize() - 1);
        m_fdbg.PrintArray("%f ", samsZeta.GetBuf() + 1, samsZeta.GetSize() - 1);
        m_fdbg.Print("\n");
        m_vAllSampleLenCount *= m_nTotalSample;
        m_vCurSampleLenCount *= m_nMiniBatchSample;

        m_pModel->SetPi(m_samplePi.GetBuf());

        return nValue;
    }

    void LearningRate::Reset(const char *pstr, int p_t0)
    {
        sscanf(pstr, "%lf,%lf", &tc, &beta);
        t0 = p_t0;
        //lout << "[Learning Rate] tc=" << tc << " beta=" << beta << " t0=" << t0 << endl;
    }
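    // Learning-rate schedule: gamma(t) = 1/(tc + t^beta) while t <= t0;
    // afterwards the decay switches to 1/(tc + t0^beta + (t - t0)), i.e. a
    // 1/t tail, the standard choice for stochastic-approximation step sizes.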
    double LearningRate::Get(int t)
    {
        double gamma;
        if (t <= t0) {
            gamma = 1.0 / (tc + pow(t, beta));
        }
        else {
            gamma = 1.0 / (tc + pow(t0, beta) + t - t0);
        }
        return gamma;
    }


    bool SAtrain::Run(const double *pInitParams /* = NULL */)
    {
        if (!m_pfunc) {
            lout_Solve << "m_pFunc == NULL" << endl;
            return false;
        }
        Clock ck;
        m_dSpendMinute = 0;

        SAfunc *pSA = (SAfunc*)m_pfunc;
        //int nIterPerEpoch = pSA->m_pCorpusTrain->GetNum() / pSA->m_nMiniBatchSample + 1;
        //lout_variable(nIterPerEpoch);


        double *pdCurParams = new double[m_pfunc->GetParamNum()];
        double *pdCurGradient = new double[m_pfunc->GetParamNum()];
        double *pdCurDir = new double[m_pfunc->GetParamNum()]; // current update direction
        double dCurValue = 0;
        double dExValues[Func::cn_exvalue_max_num];
        int nExValueNum;

        // if average
        bool bAvg = (m_nAvgBeg > 0);
        double *pdAvgParams = NULL;
        if (bAvg) {
            pdAvgParams = new double[m_pfunc->GetParamNum()];
        }



        for (int i = 0; i < m_pfunc->GetParamNum(); i++) {
            pdCurParams[i] = (pInitParams) ? pInitParams[i] : 1;
        }
        memset(pdCurGradient, 0, sizeof(double)*m_pfunc->GetParamNum());
        memset(pdCurDir, 0, sizeof(double)*m_pfunc->GetParamNum());

        IterInit();
        m_pfunc->SetParam(pdCurParams);
        //pSA->WriteModel(0);

        // iteration begin
        lout_Solve << "************* Training Begin *****************" << endl;
        lout_Solve << "print-per-iter=" << m_nPrintPerIter << endl;
        lout.bOutputCmd() = false;
        ck.Begin();
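        // Main SA loop: each iteration sets the current parameters, draws a
        // minibatch of MCMC samples to estimate the gradient, then updates
        // (lambda, zeta) along the preconditioned direction with the learning
        // rates gamma_lambda and gamma_zeta.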
        for (m_nIterNum = m_nIterMin; m_nIterNum <= m_nIterMax; m_nIterNum++)
        {
            // epoch number
            m_fEpochNun = 1.0 * m_nIterNum * pSA->m_nMiniBatchSample / pSA->m_pCorpusTrain->GetNum();

            // set the parameters
            m_pfunc->SetParam(pdCurParams);
            // get the gradient
            m_pfunc->GetGradient(pdCurGradient);
            // get the function value
            dCurValue = m_pfunc->GetValue();
            // get the averaged parameters
            if (bAvg) {
                if (m_nIterNum <= m_nAvgBeg) {
                    memcpy(pdAvgParams, pdCurParams, sizeof(pdCurParams[0])*m_pfunc->GetParamNum());
                }
                else {
                    for (int i = 0; i < m_pfunc->GetParamNum(); i++) {
                        pdAvgParams[i] += (pdCurParams[i] - pdAvgParams[i]) / (m_nIterNum - m_nAvgBeg);
                    }
                }
            }

            // print
            if (m_nIterNum % m_nPrintPerIter == 0 || m_nIterNum == m_nIterMax)
            {
                m_dSpendMinute = ck.ToSecond(ck.Get()) / 60;
                bool bPrintCmd;

                bPrintCmd = lout.bOutputCmd();
                lout.bOutputCmd() = true;
                lout_Solve << "t=" << m_nIterNum;
                cout << setprecision(4) << setiosflags(ios::fixed);
                lout << " epoch=" << m_fEpochNun;
                cout << setprecision(2) << setiosflags(ios::fixed);
                lout << " time=" << m_dSpendMinute << "m";
                lout << (bAvg ? " [Avg]" : " ");
                lout.bOutputCmd() = bPrintCmd;


                // get the ex-values
                if (bAvg) pSA->SetParam(pdAvgParams);
                // This function will use AIS to normalize the model
                nExValueNum = pSA->GetExtraValues(m_nIterNum, dExValues);

                bPrintCmd = lout.bOutputCmd();
                lout.bOutputCmd() = true;
                lout << "ExValues={ ";
                cout << setprecision(2) << setiosflags(ios::fixed);
                for (int i = 0; i < nExValueNum; i++)
                    lout << dExValues[i] << " ";
                lout << "}" << endl;

                // write model
                if (m_aWriteAtIter.Find(m_nIterNum) != -1)
                    pSA->WriteModel(m_nIterNum);

                lout.bOutputCmd() = bPrintCmd;

                if (bAvg) pSA->SetParam(pdCurParams);
            }
            //lout.Progress(m_nIterNum % m_nPrintPerIter, true, m_nPrintPerIter - 1, "Train:");




            /* Stop Decision */
            if (StopDecision(m_nIterNum, dCurValue, pdCurGradient)) {
                break;
            }


            // update the learning rate gamma
            UpdateGamma(m_nIterNum);

            // update the direction
            UpdateDir(pdCurDir, pdCurGradient, pdCurParams);

            // update the parameters
            Update(pdCurParams, pdCurDir, 0);
        }

        lout_Solve << "************* Training End *****************" << endl;
        lout_Solve << "iter=" << m_nIterNum << " time=" << m_dSpendMinute << "m" << endl;
        lout_Solve << "********************************************" << endl;

        // do something at the end of the iteration
        if (bAvg) pSA->IterEnd(pdAvgParams);
        else pSA->IterEnd(pdCurParams);

        SAFE_DELETE_ARRAY(pdCurGradient);
        SAFE_DELETE_ARRAY(pdCurDir);
        SAFE_DELETE_ARRAY(pdCurParams);
        SAFE_DELETE_ARRAY(pdAvgParams);
        return true;
    }

    void SAtrain::UpdateGamma(int nIterNum)
    {
        m_gamma_lambda = m_gain_lambda.Get(nIterNum);
        m_gamma_zeta = m_gain_zeta.Get(nIterNum);

        lout_Solve << "g_lambda=" << m_gamma_lambda
            << " g_zeta=" << m_gamma_zeta
            << endl;
    }
    void SAtrain::UpdateDir(double *pDir, double *pGradient, const double *pdParam)
    {
        /* using the momentum */
        // pDir holds the final update step (learning rate times gradient)

        SAfunc* pSA = (SAfunc*)m_pfunc;
        int nWeightNum = pSA->GetFeatNum();
        int nZetaNum = pSA->GetZetaNum();

        lout_assert(nWeightNum + nZetaNum == m_pfunc->GetParamNum());



#if defined(_Adam)
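        // Adam update: exponential moving averages of the gradient (adam_m) and
        // squared gradient (adam_v), bias-corrected by 1/(1 - beta^t), with the
        // step scaled by adam_alpha / (sqrt(v_hat) + adam_sigma).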
        for (int i = 0; i < nWeightNum; i++) {
            double g = pGradient[i];
            adam_m[i] = adam_beta1 * adam_m[i] + (1 - adam_beta1) * g;
            adam_v[i] = adam_beta2 * adam_v[i] + (1 - adam_beta2) * g*g;
            double m_hat = adam_m[i] / (1 - pow(adam_beta1, m_nIterNum));
            double v_hat = adam_v[i] / (1 - pow(adam_beta2, m_nIterNum));
            pDir[i] = adam_alpha * m_hat / (sqrt(v_hat) + adam_sigma);
        }
#elif defined(_Hession)
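        // Diagonal Newton-like update: for an exponential-family model the
        // second derivative of log Z w.r.t. lambda_i is Var[f_i], estimated
        // here by E[f_i^2] - E[f_i]^2 from the current samples and smoothed
        // with a running average before dividing the gradient.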
        for (int i = 0; i < nWeightNum; i++) {
            double h = pSA->m_vSampleExp2[i] - pow(pSA->m_vSampleExp[i], 2) + pSA->m_fRegL2;
            m_avgHes[i] += m_gamma_lambda * (h - m_avgHes[i]);
            pDir[i] = m_gamma_lambda * pGradient[i] / max(1e-4, m_avgHes[i]);
        }
#else
        // update lambda
        for (int i = 0; i < nWeightNum; i++) {
            //m_avgGrad[i] = 0.9*m_avgGrad[i] + 0.1*pGradient[i];
            //pDir[i] = m_gamma_lambda * m_avgGrad[i];
            pDir[i] = m_gamma_lambda * pGradient[i];
        }

        if (m_dir_gap > 0) {
            int n_dgap_cutnum = 0;
            for (int i = 0; i < nWeightNum; i++) {
                if (pDir[i] > m_dir_gap) {
                    pDir[i] = m_dir_gap;
                    n_dgap_cutnum++;
                }
                else if (pDir[i] < -m_dir_gap) {
                    pDir[i] = -m_dir_gap;
                    n_dgap_cutnum++;
                }
            }
            lout_variable_precent(n_dgap_cutnum, nWeightNum);
        }
#endif


        // update zeta
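        // The zeta step size is capped: gamma_zeta is limited by
        // maxlen * pi_l, so lengths with small design probability pi_l (hence
        // few samples and a noisy gradient) receive proportionally smaller
        // updates.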
        for (int i = nWeightNum; i < nWeightNum + nZetaNum; i++) {
            // limit the update of zeta
            pDir[i] = min(m_gamma_zeta, 1.0*pSA->m_pModel->GetMaxLen()*pSA->m_pModel->m_pi[i - nWeightNum]) * pGradient[i];
        }

    }
    void SAtrain::Update(double *pdParam, const double *pdDir, double dStep)
    {
        // pdDir is the update step computed in UpdateDir

        SAfunc* pSA = (SAfunc*)m_pfunc;
        int nWeightNum = pSA->GetFeatNum();
        int nZetaNum = pSA->GetZetaNum();

        //lout_assert(nWeightNum == nNgramFeatNum + nVHsize + nCHsize + nHHsize);

        // update lambda
        if (m_bUpdate_lambda) {
            for (int i = 0; i < nWeightNum; i++) {
                pdParam[i] += pdDir[i];
            }
        }



        // update zeta
        if (m_bUpdate_zeta) {
            for (int i = nWeightNum; i < nWeightNum + nZetaNum; i++) {
                pdParam[i] += pdDir[i];
            }
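            // zeta_l is defined as log(Z_l / Z_1), so it is only meaningful
            // relative to this anchor; subtracting zeta[1] keeps zeta[1] at 0.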
            double zeta1 = pdParam[nWeightNum + 1];
            for (int i = nWeightNum + 1; i < nWeightNum + nZetaNum; i++) {
                pdParam[i] -= zeta1; // minus the zeta[1]
            }
        }


    }

#define GAIN_INFO(g) lout<<" "#g"\ttc="<<g.tc<<" beta="<<g.beta<<" t0="<<g.t0<<endl;
    void SAtrain::PrintInfo()
    {
        lout << "[SATrain] *** Info: ***" << endl;
        GAIN_INFO(m_gain_lambda);
        GAIN_INFO(m_gain_zeta);
        lout << " " << "m_dir_gap=" << m_dir_gap << endl;
        lout << "[SATrain] *** [End] ***" << endl;
    }
}
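
// A minimal usage sketch (hypothetical driver, not part of this file; the
// constructor signatures below are assumptions, only Reset()/Run() appear above):
//
//   trf::SAfunc func;
//   func.Reset(&model, pTrainCorpus, pValidCorpus, pTestCorpus, 100); // minibatch = 100
//   trf::SAtrain solver(&func);                  // assumed: solver bound to the SA function
//   solver.m_gain_lambda.Reset("10,0.6", 500);   // tc=10, beta=0.6, switch to 1/t decay at t0=500
//   solver.m_gain_zeta.Reset("10,0.6", 500);
//   solver.Run(NULL);                            // NULL initializes all parameters to 1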
Run iteration. input the init-parameters.