TRF Language Model
main-TRF.cpp
Go to the documentation of this file.
1 // You may obtain a copy of the License at
2 //
3 // http://www.apache.org/licenses/LICENSE-2.0
4 //
5 // Unless required by applicable law or agreed to in writing, software
6 // distributed under the License is distributed on an "AS IS" BASIS,
7 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 // See the License for the specific language governing permissions and
9 // limitations under the License.
10 //
11 // Copyright 2014-2015 Tsinghua University
12 // Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
13 //
14 // All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
15 // license declaration. Different coding language may use different comment styles.
16 
17 
18 
19 #include "trf-sa-train.h"
20 //#include "cutrf-model.cuh"
21 using namespace trf;
22 
23 char *cfg_pathVocab = NULL;
24 char *cfg_pathModelRead = NULL;
25 char *cfg_pathModelWrite = NULL;
26 
27 int cfg_nThread = 1;
28 
29 char *cfg_pathTest = NULL;
30 
31 /* lmscore */
32 char *cfg_pathNbest = NULL;
33 char *cfg_writeLmscore = NULL;
34 char *cfg_writeLmscoreDebug = NULL;
35 char *cfg_writeTestID = NULL;
36 
37 /* normalization */
38 char *cfg_norm_method = NULL;
42 int cfg_norm_lenmax = -1;
43 
44 char *cfg_pathLenFile = NULL;
45 
47 /* help */
48 const char *cfg_strHelp = "[Usage] : \n"
49 "Normalizing: \n"
50 " trf -vocab [vocab] -read [model] -write [output model] -norm-method [Exact/AIS]\n"
51 "Calculate log-likelihood:\n"
52 " trf -vocab [vocab] -read [model] -test [txt-id-file]\n"
53 "language model rescoring:\n"
54 " trf -vocab [vocab] -read [model] -nbest [nbest list] -lmscore [output lmscore]\n"
55 "Revise the length distribution pi:\n"
56 " trf -vocab [vocab] -read [model] -write [output moddel] -len-file [a txt-id-file used to summary pi]\n"
57 ;
58 
// Log prefix used by all messages from this executable.
#define lout_exe lout<<"[TRF] "

// Forward declarations of the operations implemented below.
// NOTE(review): the doxygen listing omits original line 62 here; presumably
// the forward declaration of WordStr2ID - confirm against the repository.
double CalculateLL(Model &m, CorpusTxt *pCorpus, int nCorpusNum, double *pPPL = NULL); // mean per-sentence LL; optionally outputs PPL
void LMRescore(Model &m, const char* pathTest);        // rescore an nbest list
void ModelNorm(Model &m, const char *type);            // normalize the model ("Exact" or "AIS")
void ModelRevisePi(Model &m, const char *pathLenFile); // re-estimate the length distribution pi
66 
67 _wbMain
68 {
70  opt.Add(wbOPT_STRING, "vocab", &cfg_pathVocab, "The vocabulary");
71  opt.Add(wbOPT_STRING, "read", &cfg_pathModelRead, "Read the init model to train");
72  opt.Add(wbOPT_STRING, "write", &cfg_pathModelWrite, "output the normalizaed model");
73  opt.Add(wbOPT_INT, "thread", &cfg_nThread, "The thread number");
74 
75  opt.Add(wbOPT_STRING, "test", &cfg_pathTest, "test corpus (TXT)");
76 
77  opt.Add(wbOPT_STRING, "nbest", &cfg_pathNbest, "nbest list (kaldi output)");
78  opt.Add(wbOPT_STRING, "lmscore", &cfg_writeLmscore, "[LMrescore] output the lmsocre");
79  opt.Add(wbOPT_STRING, "lmscore-debug", &cfg_writeLmscoreDebug, "[LMrescore] output the lmscore of each word for word-level combination");
80  opt.Add(wbOPT_STRING, "lmscore-test-id", &cfg_writeTestID, "[LMrescore] output the vocab-id of test file");
81 
82  opt.Add(wbOPT_STRING, "norm-method", &cfg_norm_method, "[Norm] method: Exact or AIS");
83  opt.Add(wbOPT_INT, "AIS-chain", &cfg_nAIS_chain_num, "[AIS] the chain number");
84  opt.Add(wbOPT_INT, "AIS-inter", &cfg_nAIS_inter_num, "[AIS] the intermediate distribution number");
85  opt.Add(wbOPT_INT, "norm-len-min", &cfg_norm_lenmin, "[Norm] min-length");
86  opt.Add(wbOPT_INT, "norm-len-max", &cfg_norm_lenmax, "[Norm] max-length");
87 
88  opt.Add(wbOPT_STRING, "len-file", &cfg_pathLenFile, "[Revise pi] a txt-id-file used to summary pi");
89 
90  opt.Parse(_argc, _argv);
91 
92  lout << "*********************************************" << endl;
93  lout << " TRF.exe " << endl;
94  lout << "\t" << __DATE__ << "\t" << __TIME__ << "\t" << endl;
95  lout << "**********************************************" << endl;
96 
97  omp_set_num_threads(cfg_nThread);
98  lout << "[OMP] omp_thread = " << omp_get_max_threads() << endl;
100 
102  Vocab v(cfg_pathVocab);
103  Model m(&v);
104  lout_exe << "Read model: " << cfg_pathModelRead << endl;
106 
107  /* Operation 0: normalization */
108  if (cfg_norm_method) {
110  }
111 
112  /* Operation 3: revise pi*/
113  if (cfg_pathLenFile) {
115  }
116 
117  /* Operation 1: calculate LL */
118  if (cfg_pathTest) {
119  CorpusTxt *p = new CorpusTxt(cfg_pathTest);
120  double dPPL;
121  double dLL = CalculateLL(m, p, p->GetNum(), &dPPL);
122  lout_exe << "calculate LL of : " << cfg_pathTest << endl;
123  lout_exe << "-LL = " << -dLL << endl;
124  lout_exe << "PPL = " << dPPL << endl;
125  SAFE_DELETE(p);
126  }
127 
128  /* Operation 2: lmscore */
129  if (cfg_pathNbest) {
131  }
132 
133 
134  /* write model */
135  if (cfg_pathModelWrite) {
136  lout_exe << "Write model: " << cfg_pathModelWrite << endl;
138  }
139 
140  return 1;
141 }
142 
143 double CalculateLL(Model &m, CorpusTxt *pCorpus, int nCorpusNum, double *pPPL /*= NULL*/)
144 {
145  Array<double> aLL(omp_get_max_threads());
146  aLL.Fill(0);
147 
148  Array<int> aWords(omp_get_max_threads());
149  aWords.Fill(0);
150  Array<int> aSents(omp_get_max_threads());
151  aSents.Fill(0);
152 
154  lout.Progress(0, true, nCorpusNum - 1, "omp GetLL");
155 #pragma omp parallel for firstprivate(aSeq)
156  for (int i = 0; i < nCorpusNum; i++) {
157  pCorpus->GetSeq(i, aSeq);
158 
159  Seq seq;
160  seq.Set(aSeq, m.m_pVocab);
161  LogP logprob = m.GetLogProb(seq);
162 
163  aLL[omp_get_thread_num()] += logprob;
164  aWords[omp_get_thread_num()] += aSeq.GetNum();
165  aSents[omp_get_thread_num()] += 1;
166 
167 #pragma omp critical
168  lout.Progress();
169  }
170 
171  double dLL = aLL.Sum() / nCorpusNum;
172  int nSent = aSents.Sum();
173  int nWord = aWords.Sum();
174  lout_variable(nSent);
175  lout_variable(nWord);
176  if (pPPL) *pPPL = exp(-dLL * nSent / (nSent + nWord));
177  return dLL;
178 }
179 
181 {
182  for (int i = 0; i < aStrs.GetNum(); i++) {
183  String wstr = aStrs[i];
184  VocabID *pvid = vocabhash.Find(wstr.Toupper());
185  if (pvid == NULL) { // cannot find the word, then find <UNK>
186  // as word has been saved into hash with uppor style,
187  // then we need to find <UNK>, not <unk>.
188  pvid = vocabhash.Find("<UNK>");
189  if (!pvid) {
190  lout_error("Can't find a vocab-id of " << wstr.GetBuffer());
191  }
192  }
193  aIDs[i] = *pvid;
194  }
195 }
196 
/// Rescore an nbest list (format: "<utterance-label> <word1> <word2> ...").
/// Writes -log p(sentence) per entry to cfg_writeLmscore and, if requested,
/// each sentence's vocab-id sequence to cfg_writeTestID.
void LMRescore(Model &m, const char* pathTest)
{
	Vocab *pV = m.m_pVocab;

	// Build a hash from UPPER-CASE word string to vocab id, making the
	// lookups in WordStr2ID case-insensitive.
	// NOTE(review): the doxygen listing omits original line 201 here
	// (presumably a comment) - confirm against the repository.
	LHash<const char*, VocabID> vocabhash;
	bool bFound;
	for (int i = 0; i < pV->GetSize(); i++) {
		int *pVID = vocabhash.Insert(String(pV->GetWordStr(i)).Toupper(), bFound);
		if (bFound) {
			// Two distinct ids mapped to the same (upper-cased) string:
			// the rescoring would be ambiguous, so abort.
			lout_exe << "Find words with same name but different id! (str="
				<< pV->GetWordStr(i) << " id=" << i << ")" << endl;
			exit(1);
		}
		*pVID = i;
	}

	// NOTE(review): original line 214 is omitted by this listing.
	lout_exe << "Rescoring: " << pathTest << " ..." << endl;

	File fLmscore(cfg_writeLmscore, "wt");
	File fTestid(cfg_writeTestID, "wt");
	File file(pathTest, "rt");
	char *pLine;
	// GetLine(true): the 'true' flag presumably enables a progress display
	// (parameter is named bPrecent) - TODO confirm.
	while (pLine = file.GetLine(true)) {
		// First token is the utterance label; the remainder of the line is
		// the sentence itself.
		String curLabel = strtok(pLine, " \t\n");
		String curSent = strtok(NULL, "\n");

		Array<String> aWordStrs;
		curSent.Split(aWordStrs, " \t\n");

		Array<VocabID> aWordIDs;
		WordStr2ID(aWordIDs, aWordStrs, vocabhash);

		Seq seq;
		seq.Set(aWordIDs, pV);
		// lmscore is the NEGATIVE log-probability, as expected by the
		// downstream rescoring tools.
		LogP curLmscore = -m.GetLogProb(seq);

		/* output lmscore */
		fLmscore.Print("%s %lf\n", curLabel.GetBuffer(), curLmscore);
		/* output test-id */
		if (fTestid.Good()) {
			fTestid.Print("%s\t", curLabel.GetBuffer());
			fTestid.PrintArray("%d", aWordIDs.GetBuffer(), aWordIDs.GetNum());
		}
	}
}
244 
/// Normalize the model according to `type`: "Exact" or "AIS"
/// (case-insensitive); any other value is a fatal error.
void ModelNorm(Model &m, const char *type)
{
	String strType = type;
	strType.Tolower();
	if (strType == "exact") {
		lout_exe << "Exact Normalization..." << endl;
		m.ExactNormalize();
	}
	else if (strType == "ais") {
		// If the AIS parameters were not supplied on the command line,
		// prompt for them interactively.
		if (cfg_nAIS_chain_num <= 0) {
			lout_exe << "[Input] AIS chain number = ";
			cin >> cfg_nAIS_chain_num;
		}
		if (cfg_nAIS_inter_num <= 0) {
			lout_exe << "[Input] AIS intermediate distribution number = ";
			cin >> cfg_nAIS_inter_num;
		}
		lout_exe << "AIS normalization..." << endl;

		srand(time(NULL)); // AIS sampling is randomized

		// -1 means "normalize up to the model's maximum length".
		cfg_norm_lenmax = (cfg_norm_lenmax == -1) ? m.GetMaxLen() : cfg_norm_lenmax;
		// NOTE(review): the doxygen listing omits several lines here
		// (original lines 264-265 and 278) - presumably the active call to
		// m.AISNormalize(...); the per-length variant survives only in the
		// commented-out loop below. Confirm against the repository.
// for (int nLen = cfg_norm_lenmin; nLen <= cfg_norm_lenmax; nLen++)
// {
// lout_exe << "nLen = " << nLen << "/" << m.GetMaxLen() << ": ";
// m.AISNormalize(nLen, cfg_nAIS_chain_num, cfg_nAIS_inter_num);
// //cutrf::cudaModelAIS(m, nLen, cfg_nAIS_chain_num, cfg_nAIS_inter_num);
// lout << endl;
// }
	}
	else {
		lout_error("Unknown method: " << type);
	}
}
284 
285 void ModelRevisePi(Model &m, const char *pathLenFile)
286 {
287  lout << "Revise the length distribution pi..." << endl;
288  int nMaxLen = m.GetMaxLen();
289  Vec<Prob> vLen(nMaxLen+1);
290  vLen.Fill(0);
291 
292  File file(pathLenFile, "rt");
293  int nLine = 0;
294  char *pLine;
295  while (pLine = file.GetLine()) {
296  nLine++;
297  int nLen = 0;
298  char *p = strtok(pLine, " \t\n");
299  while (p) {
300  nLen++;
301  p = strtok(NULL, " \t\n");
302  }
303  nLen = min(nLen, nMaxLen);
304  vLen[nLen] += 1;
305  }
306  vLen /= nLine;
307 
308  m.SetPi(vLen.GetBuf());
309 }
double dLL
Definition: main-TRF.cpp:171
#define SAFE_DELETE(p)
memory release
Definition: wb-vector.h:49
a dynamic string class
Definition: wb-string.h:53
void WordStr2ID(Array< VocabID > &aIDs, Array< String > &aStrs, LHash< const char *, VocabID > &vocabhash)
Definition: main-TRF.cpp:180
void ReadT(const char *pfilename)
Read Model.
Definition: trf-model.cpp:114
const char * cfg_strHelp
Definition: main-TRF.cpp:48
void ModelRevisePi(Model &m, const char *pathLenFile)
Definition: main-TRF.cpp:285
char * cfg_writeLmscore
Definition: main-TRF.cpp:33
_wbMain
Definition: main-TRF.cpp:68
void LMRescore(Model &m, const char *pathTest)
Definition: main-TRF.cpp:197
int VocabID
Definition: trf-vocab.h:23
#define lout_error(x)
Definition: wb-log.h:183
void SetPi(Prob *pPi)
Set the pi.
Definition: trf-model.cpp:70
void Parse(const char *plabel, const char *pvalue)
parse a single option, "pvalue" can be NULL
Definition: wb-option.cpp:80
char * Tolower()
to lower
Definition: wb-string.cpp:177
void Split(Array< String > &aStrs, const char *delimiter)
split to string array. Using strtok().
Definition: wb-string.cpp:149
LogP GetLogProb(Seq &seq, bool bNorm=true)
calculate the probability
Definition: trf-model.cpp:74
double CalculateLL(Model &m, CorpusTxt *pCorpus, int nCorpusNum, double *pPPL=NULL)
lout_variable(nSent)
lout<< "*********************************************"<< endl;lout<< " TRF.exe "<< endl;lout<< "\"<< __DATE__<< "\"<< __TIME__<< "\"<< endl;lout<< "**********************************************"<< endl;omp_set_num_threads(cfg_nThread);lout<< "[OMP] omp_thread = "<< omp_get_max_threads()<< endl;omp_rand(cfg_nThread);Vocab v(cfg_pathVocab);Model m(&v);lout_exe<< "Read model: "<< cfg_pathModelRead<< endl;m.ReadT(cfg_pathModelRead);if(cfg_norm_method) { ModelNorm(m, cfg_norm_method);} if(cfg_pathLenFile) { ModelRevisePi(m, cfg_pathLenFile);} if(cfg_pathTest) { CorpusTxt *p=new CorpusTxt(cfg_pathTest);double dPPL;double dLL=CalculateLL(m, p, p->GetNum(), &dPPL);lout_exe<< "calculate LL of : "<< cfg_pathTest<< endl;lout_exe<< "-LL = "<< -dLL<< endl;lout_exe<< "PPL = "<< dPPL<< endl;SAFE_DELETE(p);} if(cfg_pathNbest) { LMRescore(m, cfg_pathNbest);} if(cfg_pathModelWrite) { lout_exe<< "Write model: "<< cfg_pathModelWrite<< endl;m.WriteT(cfg_pathModelWrite);} return 1;}double CalculateLL(Model &m, CorpusTxt *pCorpus, int nCorpusNum, double *pPPL){ Array< double > aLL(omp_get_max_threads())
Definition: main-TRF.cpp:145
Array< int > aWords(omp_get_max_threads())
int cfg_nThread
Definition: main-TRF.cpp:27
double LogP
Definition: trf-def.h:27
void ModelNorm(Model &m, const char *type)
Definition: main-TRF.cpp:245
T * GetBuffer(int i=0) const
get the buffer pointer
Definition: wb-vector.h:97
int cfg_nAIS_inter_num
Definition: main-TRF.cpp:40
virtual void Print(const char *p_pMessage,...)
print
Definition: wb-file.cpp:115
T Sum()
summate all the values in the array
Definition: wb-vector.h:308
define a sequence including the word sequence and class sequence
Definition: trf-feature.h:41
char * cfg_pathLenFile
Definition: main-TRF.cpp:44
int cfg_nAIS_chain_num
Definition: main-TRF.cpp:39
char * cfg_pathModelWrite
Definition: main-TRF.cpp:25
integer
Definition: wb-option.h:35
TRF model.
Definition: trf-model.h:51
string m_strOtherHelp
extra help information, which will be output in PrintUsage
Definition: wb-option.h:58
void Fill(T v)
Definition: wb-mat.h:279
file class.
Definition: wb-file.h:94
#define lout_exe
Definition: main-TRF.cpp:59
void Set(Array< int > &aInt, Vocab *pv)
transform the word sequence (form file) to Seq
Definition: trf-feature.cpp:22
char * cfg_pathTest
Definition: main-TRF.cpp:29
T * GetBuf() const
Definition: wb-mat.h:68
virtual char * GetLine(bool bPrecent=false)
Read a line into the buffer.
Definition: wb-file.cpp:47
void Progress(long long n=-1, bool bInit=false, long long total=100, const char *head="")
progress bar
Definition: wb-log.cpp:146
int omp_rand(int thread_num)
Definition: trf-def.cpp:23
Array< VocabID > aSeq
Definition: main-TRF.cpp:153
Vocab * m_pVocab
Definition: trf-model.h:62
Array< int > aSents(omp_get_max_threads())
DataT * Insert(KeyT key, bool &bFound)
Insert a value.
Definition: wb-lhash.h:408
char * cfg_pathVocab
Definition: main-TRF.cpp:23
char * cfg_writeLmscoreDebug
Definition: main-TRF.cpp:34
int GetNum() const
Get Array number.
Definition: wb-vector.h:240
void Add(ValueType t, const char *pLabel, void *pAddress, const char *pDocMsg=NULL)
Add a option.
Definition: wb-option.cpp:35
int cfg_norm_lenmin
Definition: main-TRF.cpp:41
int GetSize()
get the vocab size, i.e. the word number
Definition: trf-vocab.h:47
Option opt
Definition: main-TRF.cpp:46
pFunc Reset & m
Log lout
the defination is in wb-log.cpp
Definition: wb-log.cpp:22
virtual bool GetSeq(int nLine, Array< VocabID > &aSeq)
get the sequence in nLine
Definition: trf-corpus.cpp:57
void Fill(T m)
set all the values to m
Definition: wb-vector.h:139
bool Good() const
return if the file is accessible.
Definition: wb-file.h:145
const char * GetWordStr(int id)
get word string
Definition: trf-vocab.h:51
void PrintArray(const char *pformat, TYPE *pbuf, int num)
print a array into file
Definition: wb-file.h:148
virtual double ExactNormalize(int nLen)
[exact] Exact Normalization, return the logz of given length
Definition: trf-model.cpp:213
char * GetBuffer() const
get buffer
Definition: wb-string.h:74
int cfg_norm_lenmax
Definition: main-TRF.cpp:42
void WriteT(const char *pfilename)
Write Model.
Definition: trf-model.cpp:158
char * cfg_pathModelRead
Definition: main-TRF.cpp:24
int nSent
Definition: main-TRF.cpp:172
Definition: trf-alg.cpp:20
int nWord
Definition: main-TRF.cpp:173
int GetMaxLen() const
Get max-len.
Definition: trf-model.h:100
char * Toupper()
to upper
Definition: wb-string.cpp:170
a linear hash table
Definition: wb-lhash.h:41
Get the option from command line or command files.
Definition: wb-option.h:54
char * cfg_norm_method
Definition: main-TRF.cpp:38
virtual int GetNum() const
get the seq number
Definition: trf-corpus.h:47
char * cfg_writeTestID
Definition: main-TRF.cpp:35
LogP AISNormalize(int nLen, int nChain, int nInter)
perform AIS to calculate the normalization constants, return the logz of given length ...
Definition: trf-model.cpp:550
DataT * Find(KeyT key, bool &bFound)
Find a value.
Definition: wb-lhash.h:392
char * cfg_pathNbest
Definition: main-TRF.cpp:32