TRF Language Model
main-ML-train.cpp
Go to the documentation of this file.
1 // You may obtain a copy of the License at
2 //
3 // http://www.apache.org/licenses/LICENSE-2.0
4 //
5 // Unless required by applicable law or agreed to in writing, software
6 // distributed under the License is distributed on an "AS IS" BASIS,
7 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 // See the License for the specific language governing permissions and
9 // limitations under the License.
10 //
11 // Copyright 2014-2015 Tsinghua University
12 // Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
13 //
14 // All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
15 // license declaration. Different coding language may use different comment styles.
16 
17 
18 #ifdef _MLtrain
19 
20 #include "trf-model.h"
21 #include "trf-ml-train.h"
22 #include <omp.h>
23 using namespace trf;
24 
// ---------------------------------------------------------------------------
// Run-time configuration. The values below are the defaults; _wbMain binds
// most of them to command-line options through opt.Add(...).
// ---------------------------------------------------------------------------

// Path of the vocabulary file (option "vocab").
char *cfg_pathVocab = NULL;

// N-gram feature order (option "order").
int cfg_nFeatOrder = 2;
// NOTE(review): cfg_nHNode and cfg_nMaxLen are not bound to any option and are
// not read anywhere in the visible part of this file; they are kept because
// they have external linkage and may be referenced from other files.
int cfg_nHNode = 2;
int cfg_nMaxLen = 0;

// Corpus paths (options "train", "valid", "test"); NULL means "not given".
char *cfg_pathTrain = NULL;
char *cfg_pathValid = NULL;
char *cfg_pathTest = NULL;

// Initial model to read (option "read"); NULL means "not given".
char *cfg_pathModelRead = NULL;

// Output model path (option "write").
// The default lives in a writable buffer: initialising a non-const char* from
// a string literal is ill-formed since C++11, and writing through such a
// pointer is undefined behaviour. The variable itself stays a plain char* so
// that its address can still be handed to opt.Add(wbOPT_STRING, ...).
static char cfg_defaultModelWrite[] = "test.model";
char *cfg_pathModelWrite = cfg_defaultModelWrite;

// Iteration limit handed to the solver (option "iter").
int cfg_nIterTotalNum = 1000;
// Number of OpenMP threads (option "thread").
int cfg_nThread = 1;
40 
41 Option opt;
42 
43 _wbMain
44 {
45  opt.Add(wbOPT_STRING, "vocab", &cfg_pathVocab, "The vocabulary");
46  opt.Add(wbOPT_INT, "order", &cfg_nFeatOrder, "the ngram feature order (default=2)");
47  opt.Add(wbOPT_STRING, "train", &cfg_pathTrain, "Training corpus (TXT)");
48  opt.Add(wbOPT_STRING, "valid", &cfg_pathValid, "valid corpus (TXT)");
49  opt.Add(wbOPT_STRING, "test", &cfg_pathTest, "test corpus (TXT)");
50 
51  opt.Add(wbOPT_STRING, "read", &cfg_pathModelRead, "Read the init model to train");
52  opt.Add(wbOPT_STRING, "write", &cfg_pathModelWrite, "Output model");
53 
54  opt.Add(wbOPT_INT, "iter", &cfg_nIterTotalNum, "iter total number");
55  opt.Add(wbOPT_INT, "thread", &cfg_nThread, "The thread number");
56 
57  opt.Parse(_argc, _argv);
58 
59  lout << "*********************************************" << endl;
60  lout << " TRF_train.exe { by Bin Wang } " << endl;
61  lout << "\t" << __DATE__ << "\t" << __TIME__ << "\t" << endl;
62  lout << "**********************************************" << endl;
63 
64  omp_set_num_threads(cfg_nThread);
65  lout << "[OMP] omp_thread = " << omp_get_max_threads() << endl;
66  Title::SetGlobalTitle(String(cfg_pathModelWrite).FileName());
67 
68  Vocab *pv = new Vocab(cfg_pathVocab);
69  Model m(pv, 0);
70  if (cfg_pathModelRead) {
71  m.ReadT(cfg_pathModelRead);
72  }
73  else {
74  m.LoadFromCorpus(cfg_pathTrain, NULL, cfg_nFeatOrder);
75  }
76  m.WriteT(cfg_pathModelWrite);
77 
78  CorpusTxt *pTrain = (cfg_pathTrain) ? new CorpusTxt(cfg_pathTrain) : NULL;
79  CorpusTxt *pValid = (cfg_pathValid) ? new CorpusTxt(cfg_pathValid) : NULL;
80  CorpusTxt *pTest = (cfg_pathTest) ? new CorpusTxt(cfg_pathTest) : NULL;
81 
82  MLfunc func(&m, pTrain, pValid, pTest);
83  func.m_pathOutputModel = cfg_pathModelWrite;
84 
85 
86 // Vec<double> vParams(func.GetParamNum());
87 // vParams.Fill(0.1);
88 //
89 // Vec<double> vG1(func.GetParamNum());
90 // Vec<double> vG2(func.GetParamNum());
91 //
92 // func.SetParam(vParams.GetBuf());
93 // func.GetGradient(vG1.GetBuf());
94 //
95 // Vec<double> vP2(func.GetParamNum());
96 // double delta = 1e-3;
97 // for (int i = 2053; i < func.GetParamNum(); i++) {
98 // vP2.Copy(vParams);
99 // func.SetParam(vP2.GetBuf());
100 // double f1 = func.GetValue();
101 // vP2[i] += delta;
102 // func.SetParam(vP2.GetBuf());
103 // double f2 = func.GetValue();
104 // double g = (f2 - f1) / delta;
105 // lout << g << " " << vG1[i] << endl;
106 // }
107 // return 1;
108 
109  wb::LBFGS solve(&func);
110  solve.m_nIterMax = cfg_nIterTotalNum; // fix the iteration number
111  //solve.m_dGain = 1; // fixed the gain
112 
113  Vec<double> vInitParams(func.GetParamNum());
114  vInitParams.Fill(0);
115 
116  if (cfg_pathModelRead) {
117  func.GetParam(vInitParams.GetBuf());
118  }
119  else {
120  vInitParams.Fill(0);
121  }
122 
123 
124  solve.Run(vInitParams.GetBuf());
125 
126  // Finish
127  func.SetParam(solve.m_pdRoot);
128  m.WriteT(cfg_pathModelWrite);
129 
130  SAFE_DELETE(pTrain);
131  SAFE_DELETE(pValid);
132  SAFE_DELETE(pTest);
133 
134  SAFE_DELETE(pv);
135 
136  return 1;
137 
138 
139 
140 }
141 
142 
143 #endif
trf::Vocab Vocab
Definition: hrf-model.h:28
#define SAFE_DELETE(p)
memory release
Definition: wb-vector.h:49
Option opt
char * cfg_pathTrain
char * cfg_pathModelRead
#define _wbMain
define the main function
Definition: wb-system.h:47
int cfg_nFeatOrder
lout<< "*********************************************"<< endl;lout<< " TRF_SAtrain.exe { by Bin Wang } "<< endl;lout<< "\"<< __DATE__<< "\"<< __TIME__<< "\"<< endl;lout<< "**********************************************"<< endl;omp_set_num_threads(cfg_nThread);lout<< "[OMP] omp_thread = "<< omp_get_max_threads()<< endl;trf::omp_rand(cfg_nThread);Title::SetGlobalTitle(String(cfg_pathModelWrite).FileName());Vocab *pv=new Vocab(cfg_pathVocab);Model m(pv, cfg_nMaxLen);if(cfg_pathModelRead) { m.ReadT(cfg_pathModelRead);} else { m.LoadFromCorpus(cfg_pathTrain, cfg_pathFeatStyle, cfg_nFeatOrder);} CorpusTxt *pTrain=(cfg_pathTrain) ? new CorpusTxt(cfg_pathTrain) :NULL;CorpusTxt *pValid=(cfg_pathValid) ? new CorpusTxt(cfg_pathValid) :NULL;CorpusTxt *pTest=(cfg_pathTest) ? new CorpusTxt(cfg_pathTest) :NULL;SAfunc func;func.m_fdbg.Open(String(cfg_pathModelWrite).FileName()+".sadbg", "wt");func.m_fmean.Open(cfg_pathWriteMean, "wt");func.m_fvar.Open(cfg_pathWriteVar, "wt");func.m_ftrainLL.Open(cfg_pathWriteLLtrain, "wt");func.m_fvallidLL.Open(cfg_pathWriteLLvalid, "wt");func.m_ftestLL.Open(cfg_pathWriteLLtest, "wt");func.m_pathOutputModel=cfg_pathModelWrite;func.Reset(&m, pTrain, pValid, pTest, cfg_nMiniBatch);func.m_fRegL2=cfg_fRegL2;func.m_var_gap=cfg_vGap;func.PrintInfo();SAtrain solve(&func);solve.m_nIterMax=cfg_nIterTotalNum;solve.m_gain_lambda.Reset(cfg_gamma_lambda ? cfg_gamma_lambda :"0,0", cfg_t0);solve.m_gain_zeta.Reset(cfg_gamma_zeta ? cfg_gamma_zeta :"0,0.6", cfg_t0);solve.m_bUpdate_lambda=!cfg_bUnupdateLambda;solve.m_bUpdate_zeta=!cfg_bUnupdateZeta;solve.m_nAvgBeg=cfg_nAvgBeg;solve.m_nPrintPerIter=cfg_nPrintPerIter;solve.m_dir_gap=cfg_dGap;VecUnfold(cfg_strWriteAtIter, solve.m_aWriteAtIter);solve.PrintInfo();bool bInitWeight=(!cfg_pathModelRead)||(cfg_bInitValue &&!cfg_bUnupdateLambda);bool bInitZeta=(!cfg_pathModelRead)||(cfg_bInitValue &&!cfg_bUnupdateZeta);Vec< double > vInitParams(func.GetParamNum())
integer
Definition: wb-option.h:35
char * cfg_pathTest
TRF model.
Definition: trf-model.h:51
int cfg_nThread
int cfg_nIterTotalNum
char * cfg_pathValid
int cfg_nMaxLen
pFunc Reset & m
Log lout
the definition is in wb-log.cpp
Definition: wb-log.cpp:22
char * cfg_pathModelWrite
char * cfg_pathVocab
Definition: trf-alg.cpp:20
int cfg_nHNode