TRF Language Model
main-sa-train.cpp
Go to the documentation of this file.
1 // You may obtain a copy of the License at
2 //
3 // http://www.apache.org/licenses/LICENSE-2.0
4 //
5 // Unless required by applicable law or agreed to in writing, software
6 // distributed under the License is distributed on an "AS IS" BASIS,
7 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 // See the License for the specific language governing permissions and
9 // limitations under the License.
10 //
11 // Copyright 2014-2015 Tsinghua University
12 // Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
13 //
14 // All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
15 // license declaration. Different coding language may use different comment styles.
16 
17 #ifndef _MLtrain
18 
19 //#include "hrf-sa-train.h"
20 #include "hrf-sams.h"
21 using namespace hrf;
22 
// Command-line configuration. Each variable holds the default value and is
// bound to a named option in _wbMain via opt.Add(); the option name and the
// help text registered there are repeated in the trailing comments.
// cfg_nFeatOrder, cfg_nPrintPerIter and cfg_AIS_for_LL are also bound in
// _wbMain, but their declarations are not present in this listing.
23 char *cfg_pathVocab = NULL;  // "vocab": the vocabulary
24 int cfg_nHLayer = 1;  // "layer": the hidden layer of HRF
25 int cfg_nHNode = 2;  // "node": the hidden node of each hidden layer of HRF
27 char *cfg_pathFeatStyle = NULL;  // "feat": a feature style file
28 int cfg_nMaxLen = 0;  // "len": the maximum length of HRF
29 
30 char *cfg_pathTrain = NULL;  // "train": training corpus (TXT)
31 char *cfg_pathValid = NULL;  // "valid": valid corpus (TXT)
32 char *cfg_pathTest = NULL;  // "test": test corpus (TXT)
33 
34 char *cfg_pathModelRead = NULL;  // "read": the init model to train
35 char *cfg_pathModelWrite = NULL;  // "write": output model
36 
37 int cfg_nThread = 1;  // "thread": passed to omp_set_num_threads() in _wbMain
38 
39 int cfg_nIterTotalNum = 1000;  // "iter": iter total number
40 int cfg_nMiniBatch = 300;  // "mini-batch"
41 int cfg_t0 = 500;  // "t0": second argument of every *_rate.Reset() call in _wbMain
// NOTE(review): the four defaults below bind string literals to non-const
// char*, which standard C++11 and later does not permit; whether they can
// become const char* depends on the parameter type of Option::Add (not
// visible here).
42 char *cfg_gamma_lambda = "0,0.8";  // "gamma-lambda": learning rate of lambda
43 char *cfg_gamma_hidden = "100,0.8";  // "gamma-hidden": learning rate of VHmatrix
44 char *cfg_gamma_zeta = "0,0.6";  // "gamma-zeta": learning rate of zeta
45 char *cfg_gamma_var = "0,0.8";  // "gamma-var": learning rate of variance
46 float cfg_fMomentum = 0;  // "momentum": the momentum
47 float cfg_var_gap = 1e-4;  // "vgap": the threshold of variance
48 float cfg_dir_gap = 1;  // "dgap": the threshold for parameter update
49 float cfg_zeta_gap = 10;  // "zgap": the threshold for zeta update
50 bool cfg_bUpdateLambda = false;  // "update-lambda": selects the SALambda trainer
51 bool cfg_bUpdateZeta = false;  // "update-zeta": selects the SAMSZeta trainer (checked first)
52 int cfg_nAvgBeg = 0;  // "tavg": >0 then apply averaging
53 
54 float cfg_fRegL2 = 0;  // "L2": regularization L2
55 
56 bool cfg_bInitValue = false;  // "init": re-init the parameters
57 bool cfg_bZeroInit = false;  // "zero-init": zero initial parameters, otherwise random
59 bool cfg_bUnprintTrain = false;  // "not-print-train": do not print LL on training set
60 bool cfg_bUnprintValid = false;  // "not-print-valid": do not print LL on valid set
61 bool cfg_bUnprintTest = false;  // "not-print-test": do not print LL on test set
62 char *cfg_strWriteAtIter = NULL;  // "write-at-iter": iteration spec such as [1:100:1000]
63 
64 char *cfg_pathWriteMean = NULL;  // "write-mean": expectation on the training set
65 char *cfg_pathWriteVar = NULL;  // "write-var": variance on the training set
66 
68 
// Option table used by _wbMain (opt.Add registers options, opt.Parse reads them).
69 Option opt;
70 
// Entry point of the SA training tool. It registers the command-line options,
// parses them, prints a banner, sets the OpenMP thread count, loads the
// vocabulary and corpora, creates and configures a trainer, runs it, and
// releases the loaded data.
// _wbMain is a macro defined outside this file; it presumably supplies the
// _argc/_argv used below -- confirm in the wb headers.
// NOTE(review): this listing is missing several original lines, so some
// statements (model construction, the declaration of bInit, one `if` header)
// are not visible here.
71 _wbMain
72 {
 // Model structure and data paths.
73  opt.Add(wbOPT_STRING, "vocab", &cfg_pathVocab, "The vocabulary");
74  opt.Add(wbOPT_STRING, "feat", &cfg_pathFeatStyle, "a feature style file. Set this value will disable -order");
75  opt.Add(wbOPT_INT, "order", &cfg_nFeatOrder, "the ngram feature order");
76  opt.Add(wbOPT_INT, "len", &cfg_nMaxLen, "the maximum length of HRF");
77  opt.Add(wbOPT_INT, "layer", &cfg_nHLayer, "the hidden layer of HRF");
78  opt.Add(wbOPT_INT, "node", &cfg_nHNode, "the hidden node of each hidden layer of HRF");
79  opt.Add(wbOPT_STRING, "train", &cfg_pathTrain, "Training corpus (TXT)");
80  opt.Add(wbOPT_STRING, "valid", &cfg_pathValid, "valid corpus (TXT)");
81  opt.Add(wbOPT_STRING, "test", &cfg_pathTest, "test corpus (TXT)");
82 
 // Model input/output files.
83  opt.Add(wbOPT_STRING, "read", &cfg_pathModelRead, "Read the init model to train");
84  opt.Add(wbOPT_STRING, "write", &cfg_pathModelWrite, "Output model");
85 
 // Training schedule: iteration count, threads, learning rates, thresholds.
86  opt.Add(wbOPT_INT, "iter", &cfg_nIterTotalNum, "iter total number");
87  opt.Add(wbOPT_INT, "thread", &cfg_nThread, "The thread number");
88  opt.Add(wbOPT_INT, "mini-batch", &cfg_nMiniBatch, "mini-batch");
89  opt.Add(wbOPT_INT, "t0", &cfg_t0, "t0");
90  opt.Add(wbOPT_STRING, "gamma-lambda", &cfg_gamma_lambda, "learning rate of lambda");
91  opt.Add(wbOPT_STRING, "gamma-hidden", &cfg_gamma_hidden, "learning rate of VHmatrix");
92  opt.Add(wbOPT_STRING, "gamma-zeta", &cfg_gamma_zeta, "learning rate of zeta");
93  opt.Add(wbOPT_STRING, "gamma-var", &cfg_gamma_var, "learning rate of variance");
94  opt.Add(wbOPT_FLOAT, "momentum", &cfg_fMomentum, "the momentum");
95  opt.Add(wbOPT_TRUE, "update-lambda", &cfg_bUpdateLambda, "update lambda");
96  opt.Add(wbOPT_TRUE, "update-zeta", &cfg_bUpdateZeta, "update zeta");
97  opt.Add(wbOPT_INT, "tavg", &cfg_nAvgBeg, ">0 then apply averaging");
98  opt.Add(wbOPT_FLOAT, "vgap", &cfg_var_gap, "the threshold of variance");
99  opt.Add(wbOPT_FLOAT, "dgap", &cfg_dir_gap, "the threshold for parameter update");
100  opt.Add(wbOPT_FLOAT, "zgap", &cfg_zeta_gap, "the threshold for zeta update");
101  opt.Add(wbOPT_FLOAT, "L2", &cfg_fRegL2, "regularization L2");
102 
 // Initialisation and reporting switches.
103  opt.Add(wbOPT_TRUE, "init", &cfg_bInitValue, "Re-init the parameters");
104  opt.Add(wbOPT_TRUE, "zero-init", &cfg_bZeroInit, "Set the init parameters Zero. Otherwise random init the parameters");
105  opt.Add(wbOPT_INT, "print-per-iter", &cfg_nPrintPerIter, "print the LL per iterations");
106  opt.Add(wbOPT_TRUE, "not-print-train", &cfg_bUnprintTrain, "donot print LL on training set");
107  opt.Add(wbOPT_TRUE, "not-print-valid", &cfg_bUnprintValid, "donot print LL on valid set");
108  opt.Add(wbOPT_TRUE, "not-print-test", &cfg_bUnprintTest, "donot print LL on test set");
109  opt.Add(wbOPT_STRING, "write-at-iter", &cfg_strWriteAtIter, "write the LL per iteration, such as [1:100:1000]");
110 
 // NOTE(review): the help text of "write-mean" misspells "expectation".
111  opt.Add(wbOPT_STRING, "write-mean", &cfg_pathWriteMean, "write the expecataion on training set");
112  opt.Add(wbOPT_STRING, "write-var", &cfg_pathWriteVar, "write the variance on training set");
113 
 // AIS settings are stored directly into the fields of cfg_AIS_for_LL.
114  opt.Add(wbOPT_INT, "AIS-chain", &cfg_AIS_for_LL.nChain, "AIS chain number");
115  opt.Add(wbOPT_INT, "AIS-inter", &cfg_AIS_for_LL.nInter, "AIS intermediate distribution number");
116 
 // All options must be registered before this call.
117  opt.Parse(_argc, _argv);
118 
 // Banner with the build date and time of this translation unit.
119  lout << "*********************************************" << endl;
120  lout << " TRF_SAtrain.exe { by Bin Wang } " << endl;
121  lout << "\t" << __DATE__ << "\t" << __TIME__ << "\t" << endl;
122  lout << "**********************************************" << endl;
123 
 // Apply the requested thread count, then log the value OpenMP reports back.
124  omp_set_num_threads(cfg_nThread);
125  lout << "[OMP] omp_thread = " << omp_get_max_threads() << endl;
127 
128  /* Load Model and Vocab */
129  Vocab *pv = new Vocab(cfg_pathVocab);
 // NOTE(review): both branches are empty in this listing. The statements that
 // create or read the model (the `m` passed to pFunc->Reset below) are not
 // shown here.
131  if (cfg_pathModelRead) {
133  }
134  else {
136  }
140 
141  /* Load corpus */
 // A corpus is only loaded when its path option was given; otherwise the
 // pointer stays NULL and is passed on as NULL to Reset().
142  trf::CorpusTxt *pTrain = (cfg_pathTrain) ? new trf::CorpusTxt(cfg_pathTrain) : NULL;
143  trf::CorpusTxt *pValid = (cfg_pathValid) ? new trf::CorpusTxt(cfg_pathValid) : NULL;
144  trf::CorpusTxt *pTest = (cfg_pathTest) ? new trf::CorpusTxt(cfg_pathTest) : NULL;
145 
146 
 // Select the trainer: update-zeta takes precedence over update-lambda.
 // NOTE(review): pFunc has no initializer and there is no final `else`, so it
 // is left uninitialized when neither flag is given (both default to false),
 // yet it is dereferenced unconditionally below.
147  Train *pFunc;
148  if (cfg_bUpdateZeta) {
149  pFunc = new SAMSZeta;
150  }
151  else if (cfg_bUpdateLambda) {
152  pFunc = new SALambda;
153  }
154 
 // Settings shared by both trainer types.
155  pFunc->OpenTempFile(cfg_pathModelWrite);
156  pFunc->Reset(&m, pTrain, pValid, pTest);
157  pFunc->m_nMinibatch = cfg_nMiniBatch;
158  pFunc->m_nAvgBeg = cfg_nAvgBeg;
159  pFunc->m_fRegL2 = cfg_fRegL2;
 // m_aPrint is filled in the order train, valid, test; each "not-print-*"
 // flag is inverted before being stored.
160  pFunc->m_aPrint[0] = !cfg_bUnprintTrain;
161  pFunc->m_aPrint[1] = !cfg_bUnprintValid;
162  pFunc->m_aPrint[2] = !cfg_bUnprintTest;
163  pFunc->m_nPrintPerIter = cfg_nPrintPerIter;
164  VecUnfold(cfg_strWriteAtIter, pFunc->m_aWriteAtIter);
165  pFunc->m_nIterMax = cfg_nIterTotalNum; // fix the iteration number
 // Both AIS configurations receive the same command-line settings.
166  pFunc->m_AISConfigForP = cfg_AIS_for_LL;
167  pFunc->m_AISConfigForZ = cfg_AIS_for_LL;
168 
 // Trainer-specific settings; pFunc is cast to the concrete trainer type.
 // NOTE(review): the line that opens this first branch is missing from this
 // listing; judging by the matching `else if` it presumably tests
 // cfg_bUpdateZeta -- confirm against the original file.
170  SAMSZeta *p = (SAMSZeta*)pFunc;
171  p->m_zeta_rate.Reset(cfg_gamma_zeta, cfg_t0);
172  p->m_zeta_gap = cfg_zeta_gap;
173  }
174  else if (cfg_bUpdateLambda) {
175  SALambda *p = (SALambda*)pFunc;
176  p->m_feat_rate.Reset(cfg_gamma_lambda, cfg_t0);
177  p->m_hidden_rate.Reset(cfg_gamma_hidden, cfg_t0);
178  p->m_dir_gap = cfg_dir_gap;
 // The variance rate and threshold are only set when _Var is defined.
179 #ifdef _Var
180  p->m_var_rate.Reset(cfg_gamma_var, cfg_t0);
181  p->m_var_gap = cfg_var_gap;
182 #endif
183  }
184 
185  /* set initial values */
 // NOTE(review): the declaration of bInit is not visible in this listing.
187  pFunc->Run(bInit);
188 
189  // Finish
191 
 // Release the corpora and the vocabulary allocated above.
 // NOTE(review): no deletion of pFunc is visible in this listing.
192  SAFE_DELETE(pTrain);
193  SAFE_DELETE(pValid);
194  SAFE_DELETE(pTest);
195 
196  SAFE_DELETE(pv);
197 
 // NOTE(review): returns 1; if this value becomes the process exit status it
 // would signal failure by convention -- confirm how _wbMain uses it.
198  return 1;
199 
200 }
201 
202 
203 #endif
int GetParamNum() const
Get the total parameter number.
Definition: hrf-model.h:130
trf::Vocab Vocab
Definition: hrf-model.h:28
int cfg_t0
int nChain
chain number
Definition: hrf-sa-train.h:16
Option opt
char * cfg_pathTrain
int m_hnode
the number of hidden nodes
Definition: hrf-model.h:102
bool cfg_bUpdateZeta
char * cfg_pathModelRead
is true if the option exists
Definition: wb-option.h:33
int cfg_nHLayer
float cfg_var_gap
virtual void Reset(const char *pfilename)
Open file and Load the file.
Definition: trf-corpus.cpp:29
hidden-random-field model
Definition: hrf-model.h:98
int cfg_nMiniBatch
int nInter
intermediate distribution number
Definition: hrf-sa-train.h:17
void LoadFromCorpus(const char *pcorpus, const char *pfeatstyle, int nOrder)
load ngram features from corpus
Definition: trf-model.cpp:95
SAFE_DELETE(pTrain)
float cfg_zeta_gap
bool cfg_bUnprintTrain
bool cfg_bZeroInit
bool cfg_bUpdateLambda
int cfg_nFeatOrder
#define lout_variable(x)
Definition: wb-log.h:179
AISConfig cfg_AIS_for_LL
bool cfg_bUnprintValid
int cfg_nAvgBeg
int cfg_nPrintPerIter
bool bInit
_wbMain
void ReadT(const char *pfilename)
Read Model.
Definition: hrf-model.cpp:149
integer
Definition: wb-option.h:35
char * cfg_pathTest
float cfg_dir_gap
char * cfg_gamma_hidden
char * cfg_gamma_lambda
float cfg_fRegL2
int omp_rand(int thread_num)
Definition: trf-def.cpp:23
int cfg_nThread
float cfg_fMomentum
char * cfg_strWriteAtIter
VecUnfold(cfg_strWriteAtIter, pFunc->m_aWriteAtIter)
int cfg_nIterTotalNum
char * cfg_pathValid
int m_hlayer
the number of hidden layers
Definition: hrf-model.h:101
bool cfg_bUnprintTest
int cfg_nMaxLen
pFunc->Reset(&m, pTrain, pValid, pTest)
Log lout
the definition is in wb-log.cpp
Definition: wb-log.cpp:22
char * cfg_pathFeatStyle
char * cfg_pathModelWrite
char * cfg_pathWriteMean
bool cfg_bInitValue
char * cfg_pathVocab
char * cfg_gamma_zeta
int cfg_nHNode
char * cfg_pathWriteVar
char * cfg_gamma_var
void WriteT(const char *pfilename)
Write Model.
Definition: hrf-model.cpp:233