TRF Language Model
main-sa-train.cpp
Go to the documentation of this file.
1
// You may obtain a copy of the License at
2
//
3
// http://www.apache.org/licenses/LICENSE-2.0
4
//
5
// Unless required by applicable law or agreed to in writing, software
6
// distributed under the License is distributed on an "AS IS" BASIS,
7
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8
// See the License for the specific language governing permissions and
9
// limitations under the License.
10
//
11
// Copyright 2014-2015 Tsinghua University
12
// Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
13
//
14
// All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
15
// license declaration. Different coding language may use different comment styles.
16
17
#ifndef _MLtrain
18
19
//#include "hrf-sa-train.h"
20
#include "hrf-sams.h"
21
using namespace
hrf
;
22
23
/*
 * Command-line configuration of the trainer.
 *
 * _wbMain registers the address of every cfg_* variable below with `opt`
 * and then calls opt.Parse(); the initialisers here are therefore the
 * values used when the corresponding option is not given.
 */

/* Writable storage for the defaults of the string-valued learning-rate
 * options. Binding a string literal directly to `char *` is ill-formed since
 * C++11 (a string literal has type `const char[N]`); pointing the `char *`
 * option variables at these arrays keeps their type unchanged without that
 * conversion. */
static char s_defGammaLambda[] = "0,0.8";
static char s_defGammaHidden[] = "100,0.8";
static char s_defGammaZeta[] = "0,0.6";
static char s_defGammaVar[] = "0,0.8";

/* ---- model structure ---- */
char *cfg_pathVocab = NULL;      // -vocab : vocabulary file
int cfg_nHLayer = 1;             // -layer : number of hidden layers
int cfg_nHNode = 2;              // -node  : hidden nodes per hidden layer
int cfg_nFeatOrder = 2;          // -order : n-gram feature order (ignored when -feat is set)
char *cfg_pathFeatStyle = NULL;  // -feat  : feature style file
int cfg_nMaxLen = 0;             // -len   : maximum sentence length of the model

/* ---- corpora ---- */
char *cfg_pathTrain = NULL;      // -train : training corpus (TXT)
char *cfg_pathValid = NULL;      // -valid : validation corpus (TXT)
char *cfg_pathTest = NULL;       // -test  : test corpus (TXT)

/* ---- model input / output ---- */
char *cfg_pathModelRead = NULL;  // -read  : initial model; if NULL, features are loaded from -train
char *cfg_pathModelWrite = NULL; // -write : output model

/* ---- parallelism ---- */
int cfg_nThread = 1;             // -thread : OpenMP thread number

/* ---- stochastic-approximation schedule ---- */
int cfg_nIterTotalNum = 1000;    // -iter       : total iteration number
int cfg_nMiniBatch = 300;        // -mini-batch : mini-batch size
int cfg_t0 = 500;                // -t0         : passed to every learning-rate Reset()
char *cfg_gamma_lambda = s_defGammaLambda; // -gamma-lambda : learning rate of lambda
char *cfg_gamma_hidden = s_defGammaHidden; // -gamma-hidden : learning rate of the VH matrix
char *cfg_gamma_zeta = s_defGammaZeta;     // -gamma-zeta   : learning rate of zeta
char *cfg_gamma_var = s_defGammaVar;       // -gamma-var    : learning rate of the variance (used only with _Var)
float cfg_fMomentum = 0;         // -momentum : parsed but not used by _wbMain
float cfg_var_gap = 1e-4;        // -vgap : variance threshold (used only with _Var)
float cfg_dir_gap = 1;           // -dgap : threshold for the parameter update
float cfg_zeta_gap = 10;         // -zgap : threshold for the zeta update
bool cfg_bUpdateLambda = false;  // -update-lambda : train with SALambda
bool cfg_bUpdateZeta = false;    // -update-zeta   : train with SAMSZeta (takes precedence)
int cfg_nAvgBeg = 0;             // -tavg : >0 enables averaging

/* ---- regularisation ---- */
float cfg_fRegL2 = 0;            // -L2 : L2 regularisation weight

/* ---- initialisation / reporting ---- */
bool cfg_bInitValue = false;     // -init      : re-initialise parameters even after -read
bool cfg_bZeroInit = false;      // -zero-init : parsed but not used by _wbMain
int cfg_nPrintPerIter = 100;     // -print-per-iter : print interval of the LL
bool cfg_bUnprintTrain = false;  // -not-print-train : do not print LL on the training set
bool cfg_bUnprintValid = false;  // -not-print-valid : do not print LL on the valid set
bool cfg_bUnprintTest = false;   // -not-print-test  : do not print LL on the test set
char *cfg_strWriteAtIter = NULL; // -write-at-iter : iteration list such as [1:100:1000]

/* ---- statistics output (parsed but not used by _wbMain) ---- */
char *cfg_pathWriteMean = NULL;  // -write-mean
char *cfg_pathWriteVar = NULL;   // -write-var
AISConfig
cfg_AIS_for_LL
;
68
69
Option
opt
;
70
71
/*
 * Entry point of the HRF stochastic-approximation trainer.
 *
 *  1. Register and parse the command-line options.
 *  2. Require a training mode: -update-zeta (SAMSZeta) or -update-lambda
 *     (SALambda); -update-zeta wins when both are given. Without either,
 *     an error is printed and -1 is returned before anything is loaded
 *     (previously the trainer pointer was used uninitialised).
 *  3. Configure OpenMP and the per-thread random generators.
 *  4. Load the vocabulary, then read the model given by -read or build the
 *     features from the training corpus.
 *  5. Load whichever of the train/valid/test corpora were given.
 *  6. Create, configure and run the trainer, write the model (-write) and
 *     release every allocation, including the trainer itself.
 *
 * Returns 1 after a completed run (value kept from the original code).
 */
_wbMain
{
    opt.Add(wbOPT_STRING, "vocab", &cfg_pathVocab, "The vocabulary");
    opt.Add(wbOPT_STRING, "feat", &cfg_pathFeatStyle, "a feature style file. Set this value will disable -order");
    opt.Add(wbOPT_INT, "order", &cfg_nFeatOrder, "the ngram feature order");
    opt.Add(wbOPT_INT, "len", &cfg_nMaxLen, "the maximum length of HRF");
    opt.Add(wbOPT_INT, "layer", &cfg_nHLayer, "the hidden layer of HRF");
    opt.Add(wbOPT_INT, "node", &cfg_nHNode, "the hidden node of each hidden layer of HRF");
    opt.Add(wbOPT_STRING, "train", &cfg_pathTrain, "Training corpus (TXT)");
    opt.Add(wbOPT_STRING, "valid", &cfg_pathValid, "valid corpus (TXT)");
    opt.Add(wbOPT_STRING, "test", &cfg_pathTest, "test corpus (TXT)");

    opt.Add(wbOPT_STRING, "read", &cfg_pathModelRead, "Read the init model to train");
    opt.Add(wbOPT_STRING, "write", &cfg_pathModelWrite, "Output model");

    opt.Add(wbOPT_INT, "iter", &cfg_nIterTotalNum, "iter total number");
    opt.Add(wbOPT_INT, "thread", &cfg_nThread, "The thread number");
    opt.Add(wbOPT_INT, "mini-batch", &cfg_nMiniBatch, "mini-batch");
    opt.Add(wbOPT_INT, "t0", &cfg_t0, "t0");
    opt.Add(wbOPT_STRING, "gamma-lambda", &cfg_gamma_lambda, "learning rate of lambda");
    opt.Add(wbOPT_STRING, "gamma-hidden", &cfg_gamma_hidden, "learning rate of VHmatrix");
    opt.Add(wbOPT_STRING, "gamma-zeta", &cfg_gamma_zeta, "learning rate of zeta");
    opt.Add(wbOPT_STRING, "gamma-var", &cfg_gamma_var, "learning rate of variance");
    opt.Add(wbOPT_FLOAT, "momentum", &cfg_fMomentum, "the momentum");
    opt.Add(wbOPT_TRUE, "update-lambda", &cfg_bUpdateLambda, "update lambda");
    opt.Add(wbOPT_TRUE, "update-zeta", &cfg_bUpdateZeta, "update zeta");
    opt.Add(wbOPT_INT, "tavg", &cfg_nAvgBeg, ">0 then apply averaging");
    opt.Add(wbOPT_FLOAT, "vgap", &cfg_var_gap, "the threshold of variance");
    opt.Add(wbOPT_FLOAT, "dgap", &cfg_dir_gap, "the threshold for parameter update");
    opt.Add(wbOPT_FLOAT, "zgap", &cfg_zeta_gap, "the threshold for zeta update");
    opt.Add(wbOPT_FLOAT, "L2", &cfg_fRegL2, "regularization L2");

    opt.Add(wbOPT_TRUE, "init", &cfg_bInitValue, "Re-init the parameters");
    opt.Add(wbOPT_TRUE, "zero-init", &cfg_bZeroInit, "Set the init parameters Zero. Otherwise random init the parameters");
    opt.Add(wbOPT_INT, "print-per-iter", &cfg_nPrintPerIter, "print the LL per iterations");
    opt.Add(wbOPT_TRUE, "not-print-train", &cfg_bUnprintTrain, "donot print LL on training set");
    opt.Add(wbOPT_TRUE, "not-print-valid", &cfg_bUnprintValid, "donot print LL on valid set");
    opt.Add(wbOPT_TRUE, "not-print-test", &cfg_bUnprintTest, "donot print LL on test set");
    opt.Add(wbOPT_STRING, "write-at-iter", &cfg_strWriteAtIter, "write the LL per iteration, such as [1:100:1000]");

    opt.Add(wbOPT_STRING, "write-mean", &cfg_pathWriteMean, "write the expecataion on training set");
    opt.Add(wbOPT_STRING, "write-var", &cfg_pathWriteVar, "write the variance on training set");

    opt.Add(wbOPT_INT, "AIS-chain", &cfg_AIS_for_LL.nChain, "AIS chain number");
    opt.Add(wbOPT_INT, "AIS-inter", &cfg_AIS_for_LL.nInter, "AIS intermediate distribution number");

    opt.Parse(_argc, _argv);

    lout << "*********************************************" << endl;
    lout << " TRF_SAtrain.exe { by Bin Wang } " << endl;
    lout << "\t" << __DATE__ << "\t" << __TIME__ << "\t" << endl;
    lout << "**********************************************" << endl;

    // A trainer can only be created for one of the two modes; fail early,
    // before loading the vocabulary, model and corpora.
    if (!cfg_bUpdateZeta && !cfg_bUpdateLambda) {
        lout << "[ERROR] no training mode: set -update-zeta or -update-lambda" << endl;
        return -1;
    }

    omp_set_num_threads(cfg_nThread);
    lout << "[OMP] omp_thread = " << omp_get_max_threads() << endl;
    trf::omp_rand(cfg_nThread);

    /* Load Model and Vocab */
    Vocab *pv = new Vocab(cfg_pathVocab);
    // NOTE(review): m is destroyed when this function returns, i.e. after
    // pv is deleted below; presumably Model's destructor does not use the
    // vocabulary -- confirm.
    Model m(pv, cfg_nHLayer, cfg_nHNode, cfg_nMaxLen);
    if (cfg_pathModelRead) {
        m.ReadT(cfg_pathModelRead);
    }
    else {
        m.LoadFromCorpus(cfg_pathTrain, cfg_pathFeatStyle, cfg_nFeatOrder);
    }
    lout_variable(m.m_hlayer);
    lout_variable(m.m_hnode);
    lout_variable(m.GetParamNum());

    /* Load corpus (only the ones that were given) */
    trf::CorpusTxt *pTrain = (cfg_pathTrain) ? new trf::CorpusTxt(cfg_pathTrain) : NULL;
    trf::CorpusTxt *pValid = (cfg_pathValid) ? new trf::CorpusTxt(cfg_pathValid) : NULL;
    trf::CorpusTxt *pTest = (cfg_pathTest) ? new trf::CorpusTxt(cfg_pathTest) : NULL;

    /* Create the trainer. It is kept under its concrete type so that the
     * mode-specific settings need no downcast and the object is deleted
     * through its own type; pFunc is the shared Train view. Exactly one of
     * pZeta / pLambda is non-NULL. */
    SAMSZeta *pZeta = NULL;
    SALambda *pLambda = NULL;
    Train *pFunc = NULL;
    if (cfg_bUpdateZeta) {
        pZeta = new SAMSZeta;
        pFunc = pZeta;
    }
    else { // cfg_bUpdateLambda is set (checked above)
        pLambda = new SALambda;
        pFunc = pLambda;
    }

    pFunc->OpenTempFile(cfg_pathModelWrite);
    pFunc->Reset(&m, pTrain, pValid, pTest);
    pFunc->m_nMinibatch = cfg_nMiniBatch;
    pFunc->m_nAvgBeg = cfg_nAvgBeg;
    pFunc->m_fRegL2 = cfg_fRegL2;
    pFunc->m_aPrint[0] = !cfg_bUnprintTrain;
    pFunc->m_aPrint[1] = !cfg_bUnprintValid;
    pFunc->m_aPrint[2] = !cfg_bUnprintTest;
    pFunc->m_nPrintPerIter = cfg_nPrintPerIter;
    VecUnfold(cfg_strWriteAtIter, pFunc->m_aWriteAtIter);
    pFunc->m_nIterMax = cfg_nIterTotalNum; // fix the iteration number
    pFunc->m_AISConfigForP = cfg_AIS_for_LL;
    pFunc->m_AISConfigForZ = cfg_AIS_for_LL;

    /* Mode-specific learning rates and thresholds */
    if (pZeta) {
        pZeta->m_zeta_rate.Reset(cfg_gamma_zeta, cfg_t0);
        pZeta->m_zeta_gap = cfg_zeta_gap;
    }
    else {
        pLambda->m_feat_rate.Reset(cfg_gamma_lambda, cfg_t0);
        pLambda->m_hidden_rate.Reset(cfg_gamma_hidden, cfg_t0);
        pLambda->m_dir_gap = cfg_dir_gap;
#ifdef _Var
        pLambda->m_var_rate.Reset(cfg_gamma_var, cfg_t0);
        pLambda->m_var_gap = cfg_var_gap;
#endif
    }

    /* set initial values: always for a fresh model, or when -init is given */
    bool bInit = (!cfg_pathModelRead) || cfg_bInitValue;
    pFunc->Run(bInit);

    // Finish
    m.WriteT(cfg_pathModelWrite);

    // Release the trainer first: it was handed m and the corpora via Reset().
    pFunc = NULL;
    SAFE_DELETE(pZeta);
    SAFE_DELETE(pLambda);

    SAFE_DELETE(pTrain);
    SAFE_DELETE(pValid);
    SAFE_DELETE(pTest);

    SAFE_DELETE(pv);

    // Kept from the original code. NOTE(review): a non-zero exit status
    // conventionally signals failure to the calling shell.
    return 1;
}
#endif
hrf::Model::GetParamNum
int GetParamNum() const
Get the total parameter number.
Definition:
hrf-model.h:130
hrf::Vocab
trf::Vocab Vocab
Definition:
hrf-model.h:28
cfg_t0
int cfg_t0
Definition:
main-sa-train.cpp:41
hrf::AISConfig::nChain
int nChain
chain number
Definition:
hrf-sa-train.h:16
opt
Option opt
Definition:
main-sa-train.cpp:69
trf::CorpusTxt
Definition:
trf-corpus.h:60
cfg_pathTrain
char * cfg_pathTrain
Definition:
main-sa-train.cpp:30
hrf::Model::m_hnode
int m_hnode
the number of hidden nodes
Definition:
hrf-model.h:102
cfg_bUpdateZeta
bool cfg_bUpdateZeta
Definition:
main-sa-train.cpp:51
cfg_pathModelRead
char * cfg_pathModelRead
Definition:
main-sa-train.cpp:34
wb::wbOPT_TRUE
true if the option is present
Definition:
wb-option.h:33
cfg_nHLayer
int cfg_nHLayer
Definition:
main-sa-train.cpp:24
cfg_var_gap
float cfg_var_gap
Definition:
main-sa-train.cpp:47
trf::CorpusTxt::Reset
virtual void Reset(const char *pfilename)
Open file and Load the file.
Definition:
trf-corpus.cpp:29
hrf::Model
hidden-random-field model
Definition:
hrf-model.h:98
cfg_nMiniBatch
int cfg_nMiniBatch
Definition:
main-sa-train.cpp:40
hrf::AISConfig::nInter
int nInter
intermediate distribution number
Definition:
hrf-sa-train.h:17
trf::Model::LoadFromCorpus
void LoadFromCorpus(const char *pcorpus, const char *pfeatstyle, int nOrder)
load ngram features from corpus
Definition:
trf-model.cpp:95
SAFE_DELETE
SAFE_DELETE(pTrain)
cfg_zeta_gap
float cfg_zeta_gap
Definition:
main-sa-train.cpp:49
cfg_bUnprintTrain
bool cfg_bUnprintTrain
Definition:
main-sa-train.cpp:59
wb::wbOPT_STRING
string
Definition:
wb-option.h:36
cfg_bZeroInit
bool cfg_bZeroInit
Definition:
main-sa-train.cpp:57
cfg_bUpdateLambda
bool cfg_bUpdateLambda
Definition:
main-sa-train.cpp:50
cfg_nFeatOrder
int cfg_nFeatOrder
Definition:
main-sa-train.cpp:26
lout_variable
#define lout_variable(x)
Definition:
wb-log.h:179
cfg_AIS_for_LL
AISConfig cfg_AIS_for_LL
Definition:
main-sa-train.cpp:67
cfg_bUnprintValid
bool cfg_bUnprintValid
Definition:
main-sa-train.cpp:60
cfg_nAvgBeg
int cfg_nAvgBeg
Definition:
main-sa-train.cpp:52
cfg_nPrintPerIter
int cfg_nPrintPerIter
Definition:
main-sa-train.cpp:58
bInit
bool bInit
Definition:
main-sa-train.cpp:186
_wbMain
_wbMain
Definition:
main-sa-train.cpp:72
hrf::Model::ReadT
void ReadT(const char *pfilename)
Read Model.
Definition:
hrf-model.cpp:149
wb::wbOPT_INT
integer
Definition:
wb-option.h:35
cfg_pathTest
char * cfg_pathTest
Definition:
main-sa-train.cpp:32
cfg_dir_gap
float cfg_dir_gap
Definition:
main-sa-train.cpp:48
cfg_gamma_hidden
char * cfg_gamma_hidden
Definition:
main-sa-train.cpp:43
cfg_gamma_lambda
char * cfg_gamma_lambda
Definition:
main-sa-train.cpp:42
cfg_fRegL2
float cfg_fRegL2
Definition:
main-sa-train.cpp:54
trf::omp_rand
int omp_rand(int thread_num)
Definition:
trf-def.cpp:23
cfg_nThread
int cfg_nThread
Definition:
main-sa-train.cpp:37
trf::Vocab
Definition:
trf-vocab.h:34
cfg_fMomentum
float cfg_fMomentum
Definition:
main-sa-train.cpp:46
cfg_strWriteAtIter
char * cfg_strWriteAtIter
Definition:
main-sa-train.cpp:62
VecUnfold
VecUnfold(cfg_strWriteAtIter, pFunc->m_aWriteAtIter)
wb::wbOPT_FLOAT
float
Definition:
wb-option.h:37
cfg_nIterTotalNum
int cfg_nIterTotalNum
Definition:
main-sa-train.cpp:39
cfg_pathValid
char * cfg_pathValid
Definition:
main-sa-train.cpp:31
hrf::Model::m_hlayer
int m_hlayer
the number of hidden layer
Definition:
hrf-model.h:101
cfg_bUnprintTest
bool cfg_bUnprintTest
Definition:
main-sa-train.cpp:61
cfg_nMaxLen
int cfg_nMaxLen
Definition:
main-sa-train.cpp:28
m
pFunc Reset & m
Definition:
main-sa-train.cpp:156
wb::lout
Log lout
the definition is in wb-log.cpp
Definition:
wb-log.cpp:22
cfg_pathFeatStyle
char * cfg_pathFeatStyle
Definition:
main-sa-train.cpp:27
cfg_pathModelWrite
char * cfg_pathModelWrite
Definition:
main-sa-train.cpp:35
hrf
Definition:
hrf-code-exam.cpp:3
cfg_pathWriteMean
char * cfg_pathWriteMean
Definition:
main-sa-train.cpp:64
cfg_bInitValue
bool cfg_bInitValue
Definition:
main-sa-train.cpp:56
cfg_pathVocab
char * cfg_pathVocab
Definition:
main-sa-train.cpp:23
cfg_gamma_zeta
char * cfg_gamma_zeta
Definition:
main-sa-train.cpp:44
cfg_nHNode
int cfg_nHNode
Definition:
main-sa-train.cpp:25
cfg_pathWriteVar
char * cfg_pathWriteVar
Definition:
main-sa-train.cpp:65
cfg_gamma_var
char * cfg_gamma_var
Definition:
main-sa-train.cpp:45
hrf::Model::WriteT
void WriteT(const char *pfilename)
Write Model.
Definition:
hrf-model.cpp:233
hrf::AISConfig
Definition:
hrf-sa-train.h:13
Users
zhang
spmi
SPMILM
tools
trf
src
HRF
main-sa-train.cpp
Generated by
1.8.12