|
TRF Language Model
|
#include "hrf-sams.h"
Go to the source code of this file.
Functions | |
| opt | Add (wbOPT_STRING, "feat", &cfg_pathFeatStyle, "a feature style file. Set this value will disable -order") |
| opt | Add (wbOPT_INT, "order", &cfg_nFeatOrder, "the ngram feature order") |
| opt | Add (wbOPT_INT, "len", &cfg_nMaxLen, "the maximum length of HRF") |
| opt | Add (wbOPT_INT, "layer", &cfg_nHLayer, "the hidden layer of HRF") |
| opt | Add (wbOPT_INT, "node", &cfg_nHNode, "the hidden node of each hidden layer of HRF") |
| opt | Add (wbOPT_STRING, "train", &cfg_pathTrain, "Training corpus (TXT)") |
| opt | Add (wbOPT_STRING, "valid", &cfg_pathValid, "valid corpus (TXT)") |
| opt | Add (wbOPT_STRING, "test", &cfg_pathTest, "test corpus (TXT)") |
| opt | Add (wbOPT_STRING, "read", &cfg_pathModelRead, "Read the init model to train") |
| opt | Add (wbOPT_STRING, "write", &cfg_pathModelWrite, "Output model") |
| opt | Add (wbOPT_INT, "iter", &cfg_nIterTotalNum, "iter total number") |
| opt | Add (wbOPT_INT, "thread", &cfg_nThread, "The thread number") |
| opt | Add (wbOPT_INT, "mini-batch", &cfg_nMiniBatch, "mini-batch") |
| opt | Add (wbOPT_INT, "t0", &cfg_t0, "t0") |
| opt | Add (wbOPT_STRING, "gamma-lambda", &cfg_gamma_lambda, "learning rate of lambda") |
| opt | Add (wbOPT_STRING, "gamma-hidden", &cfg_gamma_hidden, "learning rate of VHmatrix") |
| opt | Add (wbOPT_STRING, "gamma-zeta", &cfg_gamma_zeta, "learning rate of zeta") |
| opt | Add (wbOPT_STRING, "gamma-var", &cfg_gamma_var, "learning rate of variance") |
| opt | Add (wbOPT_FLOAT, "momentum", &cfg_fMomentum, "the momentum") |
| opt | Add (wbOPT_TRUE, "update-lambda", &cfg_bUpdateLambda, "update lambda") |
| opt | Add (wbOPT_TRUE, "update-zeta", &cfg_bUpdateZeta, "update zeta") |
| opt | Add (wbOPT_INT, "tavg", &cfg_nAvgBeg, ">0 then apply averaging") |
| opt | Add (wbOPT_FLOAT, "vgap", &cfg_var_gap, "the threshold of variance") |
| opt | Add (wbOPT_FLOAT, "dgap", &cfg_dir_gap, "the threshold for parameter update") |
| opt | Add (wbOPT_FLOAT, "zgap", &cfg_zeta_gap, "the threshold for zeta update") |
| opt | Add (wbOPT_FLOAT, "L2", &cfg_fRegL2, "regularization L2") |
| opt | Add (wbOPT_TRUE, "init", &cfg_bInitValue, "Re-init the parameters") |
| opt | Add (wbOPT_TRUE, "zero-init", &cfg_bZeroInit, "Set the init parameters Zero. Otherwise random init the parameters") |
| opt | Add (wbOPT_INT, "print-per-iter", &cfg_nPrintPerIter, "print the LL per iterations") |
| opt | Add (wbOPT_TRUE, "not-print-train", &cfg_bUnprintTrain, "donot print LL on training set") |
| opt | Add (wbOPT_TRUE, "not-print-valid", &cfg_bUnprintValid, "donot print LL on valid set") |
| opt | Add (wbOPT_TRUE, "not-print-test", &cfg_bUnprintTest, "donot print LL on test set") |
| opt | Add (wbOPT_STRING, "write-at-iter", &cfg_strWriteAtIter, "write the LL per iteration, such as [1:100:1000]") |
| opt | Add (wbOPT_STRING, "write-mean", &cfg_pathWriteMean, "write the expecataion on training set") |
| opt | Add (wbOPT_STRING, "write-var", &cfg_pathWriteVar, "write the variance on training set") |
| opt | Add (wbOPT_INT, "AIS-chain", &cfg_AIS_for_LL.nChain, "AIS chain number") |
| opt | Add (wbOPT_INT, "AIS-inter", &cfg_AIS_for_LL.nInter, "AIS intermediate distribution number") |
| opt | Parse (_argc, _argv) |
| lout<< "*********************************************"<< endl;lout<< " TRF_SAtrain.exe { by Bin Wang } "<< endl;lout<< "\t"<< __DATE__<< "\t"<< __TIME__<< "\t"<< endl;lout<< "**********************************************"<< endl;omp_set_num_threads(cfg_nThread);lout<< "[OMP] omp_thread = "<< omp_get_max_threads()<< endl;trf::omp_rand(cfg_nThread);Vocab *pv=new Vocab(cfg_pathVocab);Model m(pv, cfg_nHLayer, cfg_nHNode, cfg_nMaxLen);if(cfg_pathModelRead) { m.ReadT(cfg_pathModelRead);} else { m.LoadFromCorpus(cfg_pathTrain, cfg_pathFeatStyle, cfg_nFeatOrder);} lout_variable(m.m_hlayer);lout_variable(m.m_hnode);lout_variable(m.GetParamNum());trf::CorpusTxt *pTrain=(cfg_pathTrain) ? new trf::CorpusTxt(cfg_pathTrain) :NULL;trf::CorpusTxt *pValid=(cfg_pathValid) ? new trf::CorpusTxt(cfg_pathValid) :NULL;trf::CorpusTxt *pTest=(cfg_pathTest) ? new trf::CorpusTxt(cfg_pathTest) :NULL;Train *pFunc;if(cfg_bUpdateZeta) { pFunc=new SAMSZeta;} else if(cfg_bUpdateLambda) { pFunc=new SALambda;} pFunc-> | OpenTempFile (cfg_pathModelWrite) |
| VecUnfold (cfg_strWriteAtIter, pFunc->m_aWriteAtIter) | |
| if (cfg_bUpdateZeta) | |
| else | if (cfg_bUpdateLambda) |
| pFunc | Run (bInit) |
| m | WriteT (cfg_pathModelWrite) |
| SAFE_DELETE (pTrain) | |
| SAFE_DELETE (pValid) | |
| SAFE_DELETE (pTest) | |
| SAFE_DELETE (pv) | |
| opt Add | ( | wbOPT_STRING | , |
| "feat" | , | ||
| & | cfg_pathFeatStyle, | ||
| "a feature style file. Set this value will disable -order" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "order" | , | ||
| & | cfg_nFeatOrder, | ||
| "the ngram feature order" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "len" | , | ||
| & | cfg_nMaxLen, | ||
| "the maximum length of HRF" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "layer" | , | ||
| & | cfg_nHLayer, | ||
| "the hidden layer of HRF" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "node" | , | ||
| & | cfg_nHNode, | ||
| "the hidden node of each hidden layer of HRF" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "train" | , | ||
| & | cfg_pathTrain, | ||
| "Training corpus (TXT)" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "valid" | , | ||
| & | cfg_pathValid, | ||
| "valid corpus (TXT)" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "test" | , | ||
| & | cfg_pathTest, | ||
| "test corpus (TXT)" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "read" | , | ||
| & | cfg_pathModelRead, | ||
| "Read the init model to train" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "write" | , | ||
| & | cfg_pathModelWrite, | ||
| "Output model" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "iter" | , | ||
| & | cfg_nIterTotalNum, | ||
| "iter total number" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "thread" | , | ||
| & | cfg_nThread, | ||
| "The thread number" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "mini-batch" | , | ||
| & | cfg_nMiniBatch, | ||
| "mini-batch" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "t0" | , | ||
| & | cfg_t0, | ||
| "t0" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "gamma-lambda" | , | ||
| & | cfg_gamma_lambda, | ||
| "learning rate of lambda" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "gamma-hidden" | , | ||
| & | cfg_gamma_hidden, | ||
| "learning rate of VHmatrix" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "gamma-zeta" | , | ||
| & | cfg_gamma_zeta, | ||
| "learning rate of zeta" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "gamma-var" | , | ||
| & | cfg_gamma_var, | ||
| "learning rate of variance" | |||
| ) |
| opt Add | ( | wbOPT_FLOAT | , |
| "momentum" | , | ||
| & | cfg_fMomentum, | ||
| "the momentum" | |||
| ) |
| opt Add | ( | wbOPT_TRUE | , |
| "update-lambda" | , | ||
| & | cfg_bUpdateLambda, | ||
| "update lambda" | |||
| ) |
| opt Add | ( | wbOPT_TRUE | , |
| "update-zeta" | , | ||
| & | cfg_bUpdateZeta, | ||
| "update zeta" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "tavg" | , | ||
| & | cfg_nAvgBeg, | ||
| ">0 then apply averaging" | |||
| ) |
| opt Add | ( | wbOPT_FLOAT | , |
| "vgap" | , | ||
| & | cfg_var_gap, | ||
| "the threshold of variance" | |||
| ) |
| opt Add | ( | wbOPT_FLOAT | , |
| "L2" | , | ||
| & | cfg_fRegL2, | ||
| "regularization L2" | |||
| ) |
| opt Add | ( | wbOPT_TRUE | , |
| "init" | , | ||
| & | cfg_bInitValue, | ||
| "Re-init the parameters" | |||
| ) |
| opt Add | ( | wbOPT_TRUE | , |
| "zero-init" | , | ||
| & | cfg_bZeroInit, | ||
| "Set the init parameters Zero. Otherwise random init the parameters" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "print-per-iter" | , | ||
| & | cfg_nPrintPerIter, | ||
| "print the LL per iterations" | |||
| ) |
| opt Add | ( | wbOPT_TRUE | , |
| "not-print-train" | , | ||
| & | cfg_bUnprintTrain, | ||
| "donot print LL on training set" | |||
| ) |
| opt Add | ( | wbOPT_TRUE | , |
| "not-print-valid" | , | ||
| & | cfg_bUnprintValid, | ||
| "donot print LL on valid set" | |||
| ) |
| opt Add | ( | wbOPT_TRUE | , |
| "not-print-test" | , | ||
| & | cfg_bUnprintTest, | ||
| "donot print LL on test set" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "write-at-iter" | , | ||
| & | cfg_strWriteAtIter, | ||
| "write the LL per iteration, such as [1:100:1000]" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "write-mean" | , | ||
| & | cfg_pathWriteMean, | ||
| "write the expecataion on training set" | |||
| ) |
| opt Add | ( | wbOPT_STRING | , |
| "write-var" | , | ||
| & | cfg_pathWriteVar, | ||
| "write the variance on training set" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "AIS-chain" | , | ||
| &cfg_AIS_for_LL. | nChain, | ||
| "AIS chain number" | |||
| ) |
| opt Add | ( | wbOPT_INT | , |
| "AIS-inter" | , | ||
| &cfg_AIS_for_LL. | nInter, | ||
| "AIS intermediate distribution number" | |||
| ) |
| if | ( | cfg_bUpdateZeta | ) |
Definition at line 169 of file main-sa-train.cpp.
| else if | ( | cfg_bUpdateLambda | ) |
Definition at line 174 of file main-sa-train.cpp.
| lout<< "*********************************************" << endl; lout << " TRF_SAtrain.exe { by Bin Wang } " << endl; lout << "\t" << __DATE__ << "\t" << __TIME__ << "\t" << endl; lout << "**********************************************" << endl; omp_set_num_threads(cfg_nThread); lout << "[OMP] omp_thread = " << omp_get_max_threads() << endl; trf::omp_rand(cfg_nThread); Vocab *pv = new Vocab(cfg_pathVocab); Model m(pv, cfg_nHLayer, cfg_nHNode, cfg_nMaxLen); if (cfg_pathModelRead) { m.ReadT(cfg_pathModelRead); } else { m.LoadFromCorpus(cfg_pathTrain, cfg_pathFeatStyle, cfg_nFeatOrder); } lout_variable(m.m_hlayer); lout_variable(m.m_hnode); lout_variable(m.GetParamNum()); trf::CorpusTxt *pTrain = (cfg_pathTrain) ? new trf::CorpusTxt(cfg_pathTrain) : NULL; trf::CorpusTxt *pValid = (cfg_pathValid) ? new trf::CorpusTxt(cfg_pathValid) : NULL; trf::CorpusTxt *pTest = (cfg_pathTest) ? new trf::CorpusTxt(cfg_pathTest) : NULL; Train *pFunc; if (cfg_bUpdateZeta) { pFunc = new SAMSZeta; } else if (cfg_bUpdateLambda) { pFunc = new SALambda; } pFunc-> OpenTempFile | ( | cfg_pathModelWrite | ) |
| opt Parse | ( | _argc | , |
| _argv | |||
| ) |
| pFunc Run | ( | bInit | ) |
| SAFE_DELETE | ( | pTrain | ) |
| SAFE_DELETE | ( | pValid | ) |
| SAFE_DELETE | ( | pTest | ) |
| SAFE_DELETE | ( | pv | ) |
| VecUnfold | ( | cfg_strWriteAtIter | , |
| pFunc-> | m_aWriteAtIter | ||
| ) |
| m WriteT | ( | cfg_pathModelWrite | ) |
| _wbMain |
Definition at line 72 of file main-sa-train.cpp.
| bool bInit = (!cfg_pathModelRead) || cfg_bInitValue |
Definition at line 186 of file main-sa-train.cpp.
| AISConfig cfg_AIS_for_LL |
Definition at line 67 of file main-sa-train.cpp.
| bool cfg_bInitValue = false |
Definition at line 56 of file main-sa-train.cpp.
| bool cfg_bUnprintTest = false |
Definition at line 61 of file main-sa-train.cpp.
| bool cfg_bUnprintTrain = false |
Definition at line 59 of file main-sa-train.cpp.
| bool cfg_bUnprintValid = false |
Definition at line 60 of file main-sa-train.cpp.
| bool cfg_bUpdateLambda = false |
Definition at line 50 of file main-sa-train.cpp.
| bool cfg_bUpdateZeta = false |
Definition at line 51 of file main-sa-train.cpp.
| bool cfg_bZeroInit = false |
Definition at line 57 of file main-sa-train.cpp.
| float cfg_dir_gap = 1 |
Definition at line 48 of file main-sa-train.cpp.
| float cfg_fMomentum = 0 |
Definition at line 46 of file main-sa-train.cpp.
| float cfg_fRegL2 = 0 |
Definition at line 54 of file main-sa-train.cpp.
| char* cfg_gamma_hidden = "100,0.8" |
Definition at line 43 of file main-sa-train.cpp.
| char* cfg_gamma_lambda = "0,0.8" |
Definition at line 42 of file main-sa-train.cpp.
| char* cfg_gamma_var = "0,0.8" |
Definition at line 45 of file main-sa-train.cpp.
| char* cfg_gamma_zeta = "0,0.6" |
Definition at line 44 of file main-sa-train.cpp.
| int cfg_nAvgBeg = 0 |
Definition at line 52 of file main-sa-train.cpp.
| int cfg_nFeatOrder = 2 |
Definition at line 26 of file main-sa-train.cpp.
| int cfg_nHLayer = 1 |
Definition at line 24 of file main-sa-train.cpp.
| int cfg_nHNode = 2 |
Definition at line 25 of file main-sa-train.cpp.
| int cfg_nIterTotalNum = 1000 |
Definition at line 39 of file main-sa-train.cpp.
| int cfg_nMaxLen = 0 |
Definition at line 28 of file main-sa-train.cpp.
| int cfg_nMiniBatch = 300 |
Definition at line 40 of file main-sa-train.cpp.
| int cfg_nPrintPerIter = 100 |
Definition at line 58 of file main-sa-train.cpp.
| int cfg_nThread = 1 |
Definition at line 37 of file main-sa-train.cpp.
| char* cfg_pathFeatStyle = NULL |
Definition at line 27 of file main-sa-train.cpp.
| char* cfg_pathModelRead = NULL |
Definition at line 34 of file main-sa-train.cpp.
| char* cfg_pathModelWrite = NULL |
Definition at line 35 of file main-sa-train.cpp.
| char* cfg_pathTest = NULL |
Definition at line 32 of file main-sa-train.cpp.
| char* cfg_pathTrain = NULL |
Definition at line 30 of file main-sa-train.cpp.
| char* cfg_pathValid = NULL |
Definition at line 31 of file main-sa-train.cpp.
| char* cfg_pathVocab = NULL |
Definition at line 23 of file main-sa-train.cpp.
| char* cfg_pathWriteMean = NULL |
Definition at line 64 of file main-sa-train.cpp.
| char* cfg_pathWriteVar = NULL |
Definition at line 65 of file main-sa-train.cpp.
| char* cfg_strWriteAtIter = NULL |
Definition at line 62 of file main-sa-train.cpp.
| int cfg_t0 = 500 |
Definition at line 41 of file main-sa-train.cpp.
| float cfg_var_gap = 1e-4 |
Definition at line 47 of file main-sa-train.cpp.
| float cfg_zeta_gap = 10 |
Definition at line 49 of file main-sa-train.cpp.
| pFunc Reset& m |
Definition at line 156 of file main-sa-train.cpp.
| pFunc m_AISConfigForP = cfg_AIS_for_LL |
Definition at line 166 of file main-sa-train.cpp.
| pFunc m_AISConfigForZ = cfg_AIS_for_LL |
Definition at line 167 of file main-sa-train.cpp.
| pFunc m_aPrint = !cfg_bUnprintTrain |
Definition at line 160 of file main-sa-train.cpp.
| pFunc m_fRegL2 = cfg_fRegL2 |
Definition at line 159 of file main-sa-train.cpp.
| pFunc m_nAvgBeg = cfg_nAvgBeg |
Definition at line 158 of file main-sa-train.cpp.
| pFunc m_nIterMax = cfg_nIterTotalNum |
Definition at line 165 of file main-sa-train.cpp.
| pFunc m_nMinibatch = cfg_nMiniBatch |
Definition at line 157 of file main-sa-train.cpp.
| pFunc m_nPrintPerIter = cfg_nPrintPerIter |
Definition at line 163 of file main-sa-train.cpp.
| Option opt |
Definition at line 69 of file main-sa-train.cpp.
| return |
Definition at line 198 of file main-sa-train.cpp.