TRF Language Model
main-SA-train.cpp
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2014-2015 Tsinghua University
// Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
//
// All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
// license declaration. Different coding languages may use different comment styles.


#ifndef _MLtrain
#include "trf-model.h"
#include "trf-sa-train.h"
#include <omp.h>
using namespace trf;

char *cfg_pathVocab = NULL;

int cfg_nFeatOrder = 2;
char *cfg_pathFeatStyle = NULL;
int cfg_nMaxLen = 0;

char *cfg_pathTrain = NULL;
char *cfg_pathValid = NULL;
char *cfg_pathTest = NULL;

char *cfg_pathModelRead = NULL;
char *cfg_pathModelWrite = NULL;

int cfg_nThread = 1;

int cfg_nIterTotalNum = 1000;
int cfg_nMiniBatch = 300;
int cfg_t0 = 500;
char *cfg_gamma_lambda = NULL;
char *cfg_gamma_zeta = NULL;
bool cfg_bUnupdateLambda = false;
bool cfg_bUnupdateZeta = false;
int cfg_nAvgBeg = 0;

float cfg_fRegL2 = 0;
float cfg_dGap = 1.0f;
float cfg_vGap = 1e-5;

bool cfg_bInitValue = false;
int cfg_nPrintPerIter = 1;
char *cfg_strWriteAtIter = NULL;

char *cfg_pathWriteMean = NULL;
char *cfg_pathWriteVar = NULL;
char *cfg_pathWriteLLtrain = NULL;
char *cfg_pathWriteLLvalid = NULL;
char *cfg_pathWriteLLtest = NULL;

Option opt;

_wbMain
{
    opt.Add(wbOPT_STRING, "vocab", &cfg_pathVocab, "The vocabulary");
    opt.Add(wbOPT_STRING, "feat", &cfg_pathFeatStyle, "a feature style file. Setting this value disables -order");
    opt.Add(wbOPT_INT, "order", &cfg_nFeatOrder, "the ngram feature order");
    opt.Add(wbOPT_INT, "len", &cfg_nMaxLen, "the maximum length of TRF");
    opt.Add(wbOPT_STRING, "train", &cfg_pathTrain, "Training corpus (TXT)");
    opt.Add(wbOPT_STRING, "valid", &cfg_pathValid, "valid corpus (TXT)");
    opt.Add(wbOPT_STRING, "test", &cfg_pathTest, "test corpus (TXT)");

    opt.Add(wbOPT_STRING, "read", &cfg_pathModelRead, "Read the init model to train");
    opt.Add(wbOPT_STRING, "write", &cfg_pathModelWrite, "Output model");

    opt.Add(wbOPT_INT, "iter", &cfg_nIterTotalNum, "total iteration number");
    opt.Add(wbOPT_INT, "thread", &cfg_nThread, "The thread number");
    opt.Add(wbOPT_INT, "mini-batch", &cfg_nMiniBatch, "mini-batch size");
    opt.Add(wbOPT_INT, "t0", &cfg_t0, "t0");
    opt.Add(wbOPT_STRING, "gamma-lambda", &cfg_gamma_lambda, "learning rate of lambda");
    opt.Add(wbOPT_STRING, "gamma-zeta", &cfg_gamma_zeta, "learning rate of zeta");
    opt.Add(wbOPT_TRUE, "unupdate-lambda", &cfg_bUnupdateLambda, "don't update lambda");
    opt.Add(wbOPT_TRUE, "unupdate-zeta", &cfg_bUnupdateZeta, "don't update zeta");
    opt.Add(wbOPT_INT, "tavg", &cfg_nAvgBeg, ">0 then apply averaging");
    opt.Add(wbOPT_FLOAT, "L2", &cfg_fRegL2, "L2 regularization");
    opt.Add(wbOPT_FLOAT, "dgap", &cfg_dGap, "the gap for the update value at each iteration");
    opt.Add(wbOPT_FLOAT, "vgap", &cfg_vGap, "the gap for the empirical variance");

    opt.Add(wbOPT_TRUE, "init", &cfg_bInitValue, "Re-init the parameters");
    opt.Add(wbOPT_INT, "print-per-iter", &cfg_nPrintPerIter, "print the LL per iteration");
    opt.Add(wbOPT_STRING, "write-at-iter", &cfg_strWriteAtIter, "write models at iterations, such as [1:100:1000]");

    opt.Add(wbOPT_STRING, "write-mean", &cfg_pathWriteMean, "write the expectation on the training set");
    opt.Add(wbOPT_STRING, "write-var", &cfg_pathWriteVar, "write the variance on the training set");
    opt.Add(wbOPT_STRING, "write-train-ll", &cfg_pathWriteLLtrain, "write LL on train");
    opt.Add(wbOPT_STRING, "write-valid-ll", &cfg_pathWriteLLvalid, "write LL on valid");
    opt.Add(wbOPT_STRING, "write-test-ll", &cfg_pathWriteLLtest, "write LL on test");

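    // Example invocation (illustrative only; the executable name, file paths and option
    // values below are placeholders, not taken from this file):
    //   TRF_SAtrain -vocab vocab.list -train train.txt -valid valid.txt -test test.txt
    //               -order 4 -len 100 -thread 8 -iter 1000 -write trf.model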
    opt.Parse(_argc, _argv);

    lout << "*********************************************" << endl;
    lout << " TRF_SAtrain.exe { by Bin Wang } " << endl;
    lout << "\t" << __DATE__ << "\t" << __TIME__ << "\t" << endl;
    lout << "**********************************************" << endl;


    omp_set_num_threads(cfg_nThread);
    lout << "[OMP] omp_thread = " << omp_get_max_threads() << endl;
    trf::omp_rand(cfg_nThread);
    Title::SetGlobalTitle(String(cfg_pathModelWrite).FileName());

    Vocab *pv = new Vocab(cfg_pathVocab);
    Model m(pv, cfg_nMaxLen);
    if (cfg_pathModelRead) {
        m.ReadT(cfg_pathModelRead);
    }
    else {
        m.LoadFromCorpus(cfg_pathTrain, cfg_pathFeatStyle, cfg_nFeatOrder);
    }

    CorpusTxt *pTrain = (cfg_pathTrain) ? new CorpusTxt(cfg_pathTrain) : NULL;
    CorpusTxt *pValid = (cfg_pathValid) ? new CorpusTxt(cfg_pathValid) : NULL;
    CorpusTxt *pTest = (cfg_pathTest) ? new CorpusTxt(cfg_pathTest) : NULL;

    SAfunc func;
    func.m_fdbg.Open(String(cfg_pathModelWrite).FileName() + ".sadbg", "wt");
    func.m_fmean.Open(cfg_pathWriteMean, "wt");
    func.m_fvar.Open(cfg_pathWriteVar, "wt");
    func.m_ftrainLL.Open(cfg_pathWriteLLtrain, "wt");
    func.m_fvallidLL.Open(cfg_pathWriteLLvalid, "wt");
    func.m_ftestLL.Open(cfg_pathWriteLLtest, "wt");
//  func.m_fparm.Open(String(cfg_pathModelWrite).FileName() + ".parm", "wt");
//  func.m_fgrad.Open(String(cfg_pathModelWrite).FileName() + ".grad", "wt");
//  func.m_fexp.Open(String(cfg_pathModelWrite).FileName() + ".expt", "wt");
//  func.m_fsamp.Open(String(cfg_pathModelWrite).FileName() + ".samp", "wt");
//  func.m_ftrain.Open(String(cfg_pathModelWrite).FileName() + ".train", "wt");
    func.m_pathOutputModel = cfg_pathModelWrite;
    func.Reset(&m, pTrain, pValid, pTest, cfg_nMiniBatch);
    func.m_fRegL2 = cfg_fRegL2;
    func.m_var_gap = cfg_vGap;
    func.PrintInfo();


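    // SAfunc wraps the stochastic-approximation objective: roughly, the gradient compares
    // feature expectations on a training mini-batch with expectations estimated from sampled
    // sequences (plus the L2 penalty), and the SAtrain solver below applies the updates.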
    SAtrain solve(&func);
    solve.m_nIterMax = cfg_nIterTotalNum; // fix the iteration number
    solve.m_gain_lambda.Reset(cfg_gamma_lambda ? cfg_gamma_lambda : "0,0", cfg_t0);
    solve.m_gain_zeta.Reset(cfg_gamma_zeta ? cfg_gamma_zeta : "0,0.6", cfg_t0);
    solve.m_bUpdate_lambda = !cfg_bUnupdateLambda;
    solve.m_bUpdate_zeta = !cfg_bUnupdateZeta;
    solve.m_nAvgBeg = cfg_nAvgBeg;
    solve.m_nPrintPerIter = cfg_nPrintPerIter;
    solve.m_dir_gap = cfg_dGap;
    VecUnfold(cfg_strWriteAtIter, solve.m_aWriteAtIter);
    solve.PrintInfo();

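    // The gamma-lambda / gamma-zeta strings above configure the learning-rate schedules of
    // the two SA recursions (both schedules share t0); setting -tavg > 0 additionally turns
    // on parameter averaging from that iteration onward.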
    /* set initial values */
    bool bInitWeight = (!cfg_pathModelRead) || (cfg_bInitValue && !cfg_bUnupdateLambda);
    bool bInitZeta = (!cfg_pathModelRead) || (cfg_bInitValue && !cfg_bUnupdateZeta);

    Vec<double> vInitParams(func.GetParamNum());
    if (cfg_pathModelRead) {
        func.GetParam(vInitParams.GetBuf());
    }
    if (bInitWeight) {
        lout << "[Init Parameters] Zero" << endl;
        vInitParams.Fill(0);
    }
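    // zeta_l plays the role of log(Z_l / Z_1), the log-ratio of normalization constants for
    // sentence length l. The loop below initializes it as if the model were uniform over a
    // vocabulary of size |V|, i.e. Z_l ~ |V|^l, giving zeta_l = (l-1) * log|V|.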
    if (bInitZeta) {
        for (int i = 0; i <= m.GetMaxLen(); i++) {
            vInitParams[m.GetParamNum() + i] = max(0, i - 1) * log(m.m_pVocab->GetSize()); // set zeta
        }
    }


    solve.Run(vInitParams.GetBuf());

    // Finish
    m.WriteT(cfg_pathModelWrite);

    SAFE_DELETE(pTrain);
    SAFE_DELETE(pValid);
    SAFE_DELETE(pTest);

    SAFE_DELETE(pv);

    return 1;


}

#endif