TRF Language Model
main-ml-train.cpp
Go to the documentation of this file.
1 // You may obtain a copy of the License at
2 //
3 // http://www.apache.org/licenses/LICENSE-2.0
4 //
5 // Unless required by applicable law or agreed to in writing, software
6 // distributed under the License is distributed on an "AS IS" BASIS,
7 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 // See the License for the specific language governing permissions and
9 // limitations under the License.
10 //
11 // Copyright 2014-2015 Tsinghua University
12 // Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
13 //
14 // All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
15 // license declaration. Different coding languages may use different comment styles.
16 
17 #ifdef _MLtrain
18 
19 
20 #include "hrf-code-exam.h"
21 #include "hrf-ml-train.h"
22 #include <omp.h>
23 using namespace hrf;
24 
// ---------------------------------------------------------------------------
// Command-line configuration.
// Each variable holds the default used when the matching option is absent;
// the entry point registers them with the option parser and overwrites them
// when the option is given.
// ---------------------------------------------------------------------------

char *cfg_pathVocab = NULL;      // -vocab : the vocabulary

int cfg_nFeatOrder = 2;          // -order : the ngram feature order
int cfg_nHLayer = 1;             // -layer : the hidden layer of HRF
int cfg_nHNode = 2;              // -node  : the hidden node of each hidden layer

char *cfg_pathTrain = NULL;      // -train : training corpus (TXT)
char *cfg_pathValid = NULL;      // -valid : validation corpus (TXT)
char *cfg_pathTest = NULL;       // -test  : test corpus (TXT)

char *cfg_pathModelRead = NULL;  // -read  : initial model to continue training from

// Default output path, kept in a writable array: binding a string literal
// directly to a non-const char* is ill-formed since C++11, and writing through
// such a pointer is undefined behaviour. The variable keeps its char* type so
// every existing use (option registration, model writing) is unaffected.
static char cfg_defaultModelWrite[] = "test.model";
char *cfg_pathModelWrite = cfg_defaultModelWrite;  // -write : output model

int cfg_nIterTotalNum = 100;     // -iter   : total number of optimiser iterations
int cfg_nThread = 1;             // -thread : number of OpenMP threads
40 
41 Option opt;
42 
43 _wbMain
44 {
45  opt.Add(wbOPT_STRING, "vocab", &cfg_pathVocab, "The vocabulary");
46  opt.Add(wbOPT_INT, "order", &cfg_nFeatOrder, "the ngram feature order (default=2)");
47  opt.Add(wbOPT_INT, "layer", &cfg_nHLayer, "the hidden layer of HRF");
48  opt.Add(wbOPT_INT, "node", &cfg_nHNode, "the hidden node of each hidden layer of HRF");
49  opt.Add(wbOPT_STRING, "train", &cfg_pathTrain, "Training corpus (TXT)");
50  opt.Add(wbOPT_STRING, "valid", &cfg_pathValid, "valid corpus (TXT)");
51  opt.Add(wbOPT_STRING, "test", &cfg_pathTest, "test corpus (TXT)");
52 
53  opt.Add(wbOPT_STRING, "read", &cfg_pathModelRead, "Read the init model to train");
54  opt.Add(wbOPT_STRING, "write", &cfg_pathModelWrite, "Output model");
55 
56  opt.Add(wbOPT_INT, "iter", &cfg_nIterTotalNum, "iter total number");
57  opt.Add(wbOPT_INT, "thread", &cfg_nThread, "The thread number");
58 
59  opt.Parse(_argc, _argv);
60 
61  lout << "*********************************************" << endl;
62  lout << " TRF_train.exe { by Bin Wang } " << endl;
63  lout << "\t" << __DATE__ << "\t" << __TIME__ << "\t" << endl;
64  lout << "**********************************************" << endl;
65 
66  omp_set_num_threads(cfg_nThread);
67  lout << "[OMP] omp_thread = " << omp_get_max_threads() << endl;
68 
69 
70  Vocab *pv = new Vocab(cfg_pathVocab);
71  Model m(pv, cfg_nHLayer, cfg_nHNode, 4);
72  if (cfg_pathModelRead) {
73  m.ReadT(cfg_pathModelRead);
74  }
75  else {
76  m.LoadFromCorpus(cfg_pathTrain, NULL, cfg_nFeatOrder);
77  }
78  lout_variable(m.GetParamNum());
79 
80 
81  /* Exam Model */
82  ModelExam exam(&m);
83  exam.SetValueRand();
84  //exam.SetValueAll(0);
85  exam.TestNormalization(2);
86  exam.TestExpectation(2);
87  exam.TestHiddenExp(2);
88  exam.TestSample();
89  return 1;
90 
91 
92  trf::CorpusTxt *pTrain = (cfg_pathTrain) ? new trf::CorpusTxt(cfg_pathTrain) : NULL;
93  trf::CorpusTxt *pValid = (cfg_pathValid) ? new trf::CorpusTxt(cfg_pathValid) : NULL;
94  trf::CorpusTxt *pTest = (cfg_pathTest) ? new trf::CorpusTxt(cfg_pathTest) : NULL;
95 
96  MLfunc func(&m, pTrain, pValid, pTest);
97  func.m_pathOutputModel = cfg_pathModelWrite;
98 
99  wb::LBFGS solve(&func);
100  solve.m_nIterMax = cfg_nIterTotalNum; // fix the iteration number
101  //solve.m_dGain = 1; // fixed the gain
102 
103  Vec<double> vInitParams(func.GetParamNum());
104  vInitParams.Fill(0);
105 
106  if (cfg_pathModelRead) {
107  func.GetParam(vInitParams.GetBuf());
108  }
109  else {
110  lout << "Random Init parameters" << endl;
111  for (int i = 0; i < m.GetParamNum(); i++)
112  vInitParams[i] = 1.0 * rand() / RAND_MAX - 0.5; // [-0.5, 0.5]
113  }
114 
115 
116  solve.Run(vInitParams.GetBuf());
117 
118  // Finish
119  func.SetParam(solve.m_pdRoot);
120  m.WriteT(cfg_pathModelWrite);
121 
122  SAFE_DELETE(pTrain);
123  SAFE_DELETE(pValid);
124  SAFE_DELETE(pTest);
125 
126  SAFE_DELETE(pv);
127 
128  return 1;
129 }
130 
131 
132 #endif
trf::Vocab Vocab
Definition: hrf-model.h:28
#define SAFE_DELETE(p)
memory release
Definition: wb-vector.h:49
Option opt
char * cfg_pathTrain
char * cfg_pathModelRead
int cfg_nHLayer
hidden-random-field model
Definition: hrf-model.h:98
#define _wbMain
define the main function
Definition: wb-system.h:47
int cfg_nFeatOrder
#define lout_variable(x)
Definition: wb-log.h:179
lout<< "*********************************************"<< endl;lout<< " TRF_SAtrain.exe { by Bin Wang } "<< endl;lout<< "\"<< __DATE__<< "\"<< __TIME__<< "\"<< endl;lout<< "**********************************************"<< endl;omp_set_num_threads(cfg_nThread);lout<< "[OMP] omp_thread = "<< omp_get_max_threads()<< endl;trf::omp_rand(cfg_nThread);Title::SetGlobalTitle(String(cfg_pathModelWrite).FileName());Vocab *pv=new Vocab(cfg_pathVocab);Model m(pv, cfg_nMaxLen);if(cfg_pathModelRead) { m.ReadT(cfg_pathModelRead);} else { m.LoadFromCorpus(cfg_pathTrain, cfg_pathFeatStyle, cfg_nFeatOrder);} CorpusTxt *pTrain=(cfg_pathTrain) ? new CorpusTxt(cfg_pathTrain) :NULL;CorpusTxt *pValid=(cfg_pathValid) ? new CorpusTxt(cfg_pathValid) :NULL;CorpusTxt *pTest=(cfg_pathTest) ? new CorpusTxt(cfg_pathTest) :NULL;SAfunc func;func.m_fdbg.Open(String(cfg_pathModelWrite).FileName()+".sadbg", "wt");func.m_fmean.Open(cfg_pathWriteMean, "wt");func.m_fvar.Open(cfg_pathWriteVar, "wt");func.m_ftrainLL.Open(cfg_pathWriteLLtrain, "wt");func.m_fvallidLL.Open(cfg_pathWriteLLvalid, "wt");func.m_ftestLL.Open(cfg_pathWriteLLtest, "wt");func.m_pathOutputModel=cfg_pathModelWrite;func.Reset(&m, pTrain, pValid, pTest, cfg_nMiniBatch);func.m_fRegL2=cfg_fRegL2;func.m_var_gap=cfg_vGap;func.PrintInfo();SAtrain solve(&func);solve.m_nIterMax=cfg_nIterTotalNum;solve.m_gain_lambda.Reset(cfg_gamma_lambda ? cfg_gamma_lambda :"0,0", cfg_t0);solve.m_gain_zeta.Reset(cfg_gamma_zeta ? cfg_gamma_zeta :"0,0.6", cfg_t0);solve.m_bUpdate_lambda=!cfg_bUnupdateLambda;solve.m_bUpdate_zeta=!cfg_bUnupdateZeta;solve.m_nAvgBeg=cfg_nAvgBeg;solve.m_nPrintPerIter=cfg_nPrintPerIter;solve.m_dir_gap=cfg_dGap;VecUnfold(cfg_strWriteAtIter, solve.m_aWriteAtIter);solve.PrintInfo();bool bInitWeight=(!cfg_pathModelRead)||(cfg_bInitValue &&!cfg_bUnupdateLambda);bool bInitZeta=(!cfg_pathModelRead)||(cfg_bInitValue &&!cfg_bUnupdateZeta);Vec< double > vInitParams(func.GetParamNum())
integer
Definition: wb-option.h:35
char * cfg_pathTest
int cfg_nThread
int cfg_nIterTotalNum
char * cfg_pathValid
pFunc Reset & m
Log lout
the definition is in wb-log.cpp
Definition: wb-log.cpp:22
char * cfg_pathModelWrite
char * cfg_pathVocab
int cfg_nHNode