TRF Language Model
main-sa-train.cpp File Reference
#include "hrf-sams.h"

Go to the source code of this file.

Functions

opt Add (wbOPT_STRING, "feat", &cfg_pathFeatStyle, "a feature style file. Set this value will disable -order")
 
opt Add (wbOPT_INT, "order", &cfg_nFeatOrder, "the ngram feature order")
 
opt Add (wbOPT_INT, "len", &cfg_nMaxLen, "the maximum length of HRF")
 
opt Add (wbOPT_INT, "layer", &cfg_nHLayer, "the hidden layer of HRF")
 
opt Add (wbOPT_INT, "node", &cfg_nHNode, "the hidden node of each hidden layer of HRF")
 
opt Add (wbOPT_STRING, "train", &cfg_pathTrain, "Training corpus (TXT)")
 
opt Add (wbOPT_STRING, "valid", &cfg_pathValid, "valid corpus (TXT)")
 
opt Add (wbOPT_STRING, "test", &cfg_pathTest, "test corpus (TXT)")
 
opt Add (wbOPT_STRING, "read", &cfg_pathModelRead, "Read the init model to train")
 
opt Add (wbOPT_STRING, "write", &cfg_pathModelWrite, "Output model")
 
opt Add (wbOPT_INT, "iter", &cfg_nIterTotalNum, "iter total number")
 
opt Add (wbOPT_INT, "thread", &cfg_nThread, "The thread number")
 
opt Add (wbOPT_INT, "mini-batch", &cfg_nMiniBatch, "mini-batch")
 
opt Add (wbOPT_INT, "t0", &cfg_t0, "t0")
 
opt Add (wbOPT_STRING, "gamma-lambda", &cfg_gamma_lambda, "learning rate of lambda")
 
opt Add (wbOPT_STRING, "gamma-hidden", &cfg_gamma_hidden, "learning rate of VHmatrix")
 
opt Add (wbOPT_STRING, "gamma-zeta", &cfg_gamma_zeta, "learning rate of zeta")
 
opt Add (wbOPT_STRING, "gamma-var", &cfg_gamma_var, "learning rate of variance")
 
opt Add (wbOPT_FLOAT, "momentum", &cfg_fMomentum, "the momentum")
 
opt Add (wbOPT_TRUE, "update-lambda", &cfg_bUpdateLambda, "update lambda")
 
opt Add (wbOPT_TRUE, "update-zeta", &cfg_bUpdateZeta, "update zeta")
 
opt Add (wbOPT_INT, "tavg", &cfg_nAvgBeg, ">0 then apply averaging")
 
opt Add (wbOPT_FLOAT, "vgap", &cfg_var_gap, "the threshold of variance")
 
opt Add (wbOPT_FLOAT, "dgap", &cfg_dir_gap, "the threshold for parameter update")
 
opt Add (wbOPT_FLOAT, "zgap", &cfg_zeta_gap, "the threshold for zeta update")
 
opt Add (wbOPT_FLOAT, "L2", &cfg_fRegL2, "regularization L2")
 
opt Add (wbOPT_TRUE, "init", &cfg_bInitValue, "Re-init the parameters")
 
opt Add (wbOPT_TRUE, "zero-init", &cfg_bZeroInit, "Set the init parameters Zero. Otherwise random init the parameters")
 
opt Add (wbOPT_INT, "print-per-iter", &cfg_nPrintPerIter, "print the LL per iterations")
 
opt Add (wbOPT_TRUE, "not-print-train", &cfg_bUnprintTrain, "donot print LL on training set")
 
opt Add (wbOPT_TRUE, "not-print-valid", &cfg_bUnprintValid, "donot print LL on valid set")
 
opt Add (wbOPT_TRUE, "not-print-test", &cfg_bUnprintTest, "donot print LL on test set")
 
opt Add (wbOPT_STRING, "write-at-iter", &cfg_strWriteAtIter, "write the LL per iteration, such as [1:100:1000]")
 
opt Add (wbOPT_STRING, "write-mean", &cfg_pathWriteMean, "write the expecataion on training set")
 
opt Add (wbOPT_STRING, "write-var", &cfg_pathWriteVar, "write the variance on training set")
 
opt Add (wbOPT_INT, "AIS-chain", &cfg_AIS_for_LL.nChain, "AIS chain number")
 
opt Add (wbOPT_INT, "AIS-inter", &cfg_AIS_for_LL.nInter, "AIS intermediate distribution number")
 
opt Parse (_argc, _argv)
 
lout<< "*********************************************"<< endl;lout<< " TRF_SAtrain.exe { by Bin Wang } "<< endl;lout<< "\t"<< __DATE__<< "\t"<< __TIME__<< "\t"<< endl;lout<< "**********************************************"<< endl;omp_set_num_threads(cfg_nThread);lout<< "[OMP] omp_thread = "<< omp_get_max_threads()<< endl;trf::omp_rand(cfg_nThread);Vocab *pv=new Vocab(cfg_pathVocab);Model m(pv, cfg_nHLayer, cfg_nHNode, cfg_nMaxLen);if(cfg_pathModelRead) { m.ReadT(cfg_pathModelRead);} else { m.LoadFromCorpus(cfg_pathTrain, cfg_pathFeatStyle, cfg_nFeatOrder);} lout_variable(m.m_hlayer);lout_variable(m.m_hnode);lout_variable(m.GetParamNum());trf::CorpusTxt *pTrain=(cfg_pathTrain) ? new trf::CorpusTxt(cfg_pathTrain) :NULL;trf::CorpusTxt *pValid=(cfg_pathValid) ? new trf::CorpusTxt(cfg_pathValid) :NULL;trf::CorpusTxt *pTest=(cfg_pathTest) ? new trf::CorpusTxt(cfg_pathTest) :NULL;Train *pFunc;if(cfg_bUpdateZeta) { pFunc=new SAMSZeta;} else if(cfg_bUpdateLambda) { pFunc=new SALambda;} pFunc-> OpenTempFile (cfg_pathModelWrite)
 
 VecUnfold (cfg_strWriteAtIter, pFunc->m_aWriteAtIter)
 
 if (cfg_bUpdateZeta)
 
else if (cfg_bUpdateLambda)
 
pFunc Run (bInit)
 
m WriteT (cfg_pathModelWrite)
 
 SAFE_DELETE (pTrain)
 
 SAFE_DELETE (pValid)
 
 SAFE_DELETE (pTest)
 
 SAFE_DELETE (pv)
 

Variables

char * cfg_pathVocab = NULL
 
int cfg_nHLayer = 1
 
int cfg_nHNode = 2
 
int cfg_nFeatOrder = 2
 
char * cfg_pathFeatStyle = NULL
 
int cfg_nMaxLen = 0
 
char * cfg_pathTrain = NULL
 
char * cfg_pathValid = NULL
 
char * cfg_pathTest = NULL
 
char * cfg_pathModelRead = NULL
 
char * cfg_pathModelWrite = NULL
 
int cfg_nThread = 1
 
int cfg_nIterTotalNum = 1000
 
int cfg_nMiniBatch = 300
 
int cfg_t0 = 500
 
char * cfg_gamma_lambda = "0,0.8"
 
char * cfg_gamma_hidden = "100,0.8"
 
char * cfg_gamma_zeta = "0,0.6"
 
char * cfg_gamma_var = "0,0.8"
 
float cfg_fMomentum = 0
 
float cfg_var_gap = 1e-4
 
float cfg_dir_gap = 1
 
float cfg_zeta_gap = 10
 
bool cfg_bUpdateLambda = false
 
bool cfg_bUpdateZeta = false
 
int cfg_nAvgBeg = 0
 
float cfg_fRegL2 = 0
 
bool cfg_bInitValue = false
 
bool cfg_bZeroInit = false
 
int cfg_nPrintPerIter = 100
 
bool cfg_bUnprintTrain = false
 
bool cfg_bUnprintValid = false
 
bool cfg_bUnprintTest = false
 
char * cfg_strWriteAtIter = NULL
 
char * cfg_pathWriteMean = NULL
 
char * cfg_pathWriteVar = NULL
 
AISConfig cfg_AIS_for_LL
 
Option opt
 
 _wbMain
 
pFunc Reset & m
 
pFunc m_nMinibatch = cfg_nMiniBatch
 
pFunc m_nAvgBeg = cfg_nAvgBeg
 
pFunc m_fRegL2 = cfg_fRegL2
 
pFunc m_aPrint [0] = !cfg_bUnprintTrain
 
pFunc m_nPrintPerIter = cfg_nPrintPerIter
 
pFunc m_nIterMax = cfg_nIterTotalNum
 
pFunc m_AISConfigForP = cfg_AIS_for_LL
 
pFunc m_AISConfigForZ = cfg_AIS_for_LL
 
bool bInit = (!cfg_pathModelRead) || cfg_bInitValue
 
 return
 

Function Documentation

§ Add() [1/37]

opt Add ( wbOPT_STRING  ,
"feat"  ,
cfg_pathFeatStyle,
"a feature style file. Set this value will disable -order"   
)

§ Add() [2/37]

opt Add ( wbOPT_INT  ,
"order"  ,
cfg_nFeatOrder,
"the ngram feature order"   
)

§ Add() [3/37]

opt Add ( wbOPT_INT  ,
"len"  ,
cfg_nMaxLen,
"the maximum length of HRF"   
)

§ Add() [4/37]

opt Add ( wbOPT_INT  ,
"layer"  ,
cfg_nHLayer,
"the hidden layer of HRF"   
)

§ Add() [5/37]

opt Add ( wbOPT_INT  ,
"node"  ,
cfg_nHNode,
"the hidden node of each hidden layer of HRF"   
)

§ Add() [6/37]

opt Add ( wbOPT_STRING  ,
"train"  ,
cfg_pathTrain,
"Training corpus (TXT)"   
)

§ Add() [7/37]

opt Add ( wbOPT_STRING  ,
"valid"  ,
cfg_pathValid,
"valid corpus (TXT)"   
)

§ Add() [8/37]

opt Add ( wbOPT_STRING  ,
"test"  ,
cfg_pathTest,
"test corpus (TXT)"   
)

§ Add() [9/37]

opt Add ( wbOPT_STRING  ,
"read"  ,
cfg_pathModelRead,
"Read the init model to train"   
)

§ Add() [10/37]

opt Add ( wbOPT_STRING  ,
"write"  ,
cfg_pathModelWrite,
"Output model"   
)

§ Add() [11/37]

opt Add ( wbOPT_INT  ,
"iter"  ,
cfg_nIterTotalNum,
"iter total number"   
)

§ Add() [12/37]

opt Add ( wbOPT_INT  ,
"thread"  ,
cfg_nThread,
"The thread number"   
)

§ Add() [13/37]

opt Add ( wbOPT_INT  ,
"mini-batch"  ,
cfg_nMiniBatch,
"mini-batch"   
)

§ Add() [14/37]

opt Add ( wbOPT_INT  ,
"t0"  ,
cfg_t0,
"t0"   
)

§ Add() [15/37]

opt Add ( wbOPT_STRING  ,
"gamma-lambda"  ,
cfg_gamma_lambda,
"learning rate of lambda"   
)

§ Add() [16/37]

opt Add ( wbOPT_STRING  ,
"gamma-hidden"  ,
cfg_gamma_hidden,
"learning rate of VHmatrix"   
)

§ Add() [17/37]

opt Add ( wbOPT_STRING  ,
"gamma-zeta"  ,
cfg_gamma_zeta,
"learning rate of zeta"   
)

§ Add() [18/37]

opt Add ( wbOPT_STRING  ,
"gamma-var"  ,
cfg_gamma_var,
"learning rate of variance"   
)

§ Add() [19/37]

opt Add ( wbOPT_FLOAT  ,
"momentum"  ,
cfg_fMomentum,
"the momentum"   
)

§ Add() [20/37]

opt Add ( wbOPT_TRUE  ,
"update-lambda"  ,
cfg_bUpdateLambda,
"update lambda"   
)

§ Add() [21/37]

opt Add ( wbOPT_TRUE  ,
"update-zeta"  ,
cfg_bUpdateZeta,
"update zeta"   
)

§ Add() [22/37]

opt Add ( wbOPT_INT  ,
"tavg"  ,
cfg_nAvgBeg,
">0 then apply averaging"   
)

§ Add() [23/37]

opt Add ( wbOPT_FLOAT  ,
"vgap"  ,
cfg_var_gap,
"the threshold of variance"   
)

§ Add() [24/37]

opt Add ( wbOPT_FLOAT  ,
"dgap"  ,
cfg_dir_gap,
"the threshold for parameter update"   
)

§ Add() [25/37]

opt Add ( wbOPT_FLOAT  ,
"zgap"  ,
cfg_zeta_gap,
"the threshold for zeta update"   
)

§ Add() [26/37]

opt Add ( wbOPT_FLOAT  ,
"L2"  ,
cfg_fRegL2,
"regularization L2"   
)

§ Add() [27/37]

opt Add ( wbOPT_TRUE  ,
"init"  ,
cfg_bInitValue,
"Re-init the parameters"   
)

§ Add() [28/37]

opt Add ( wbOPT_TRUE  ,
"zero-init"  ,
cfg_bZeroInit,
"Set the init parameters Zero. Otherwise random init the parameters"   
)

§ Add() [29/37]

opt Add ( wbOPT_INT  ,
"print-per-iter"  ,
cfg_nPrintPerIter,
"print the LL per iterations"   
)

§ Add() [30/37]

opt Add ( wbOPT_TRUE  ,
"not-print-train"  ,
cfg_bUnprintTrain,
"donot print LL on training set"   
)

§ Add() [31/37]

opt Add ( wbOPT_TRUE  ,
"not-print-valid"  ,
cfg_bUnprintValid,
"donot print LL on valid set"   
)

§ Add() [32/37]

opt Add ( wbOPT_TRUE  ,
"not-print-test"  ,
cfg_bUnprintTest,
"donot print LL on test set"   
)

§ Add() [33/37]

opt Add ( wbOPT_STRING  ,
"write-at-iter"  ,
cfg_strWriteAtIter,
"write the LL per iteration, such as [1:100:1000]"   
)

§ Add() [34/37]

opt Add ( wbOPT_STRING  ,
"write-mean"  ,
cfg_pathWriteMean,
"write the expecataion on training set"   
)

§ Add() [35/37]

opt Add ( wbOPT_STRING  ,
"write-var"  ,
cfg_pathWriteVar,
"write the variance on training set"   
)

§ Add() [36/37]

opt Add ( wbOPT_INT  ,
"AIS-chain"  ,
&cfg_AIS_for_LL.nChain,
"AIS chain number"   
)

§ Add() [37/37]

opt Add ( wbOPT_INT  ,
"AIS-inter"  ,
&cfg_AIS_for_LL.nInter,
"AIS intermediate distribution number"   
)

§ if() [1/2]

if ( cfg_bUpdateZeta  )

Definition at line 169 of file main-sa-train.cpp.

§ if() [2/2]

else if ( cfg_bUpdateLambda  )

Definition at line 174 of file main-sa-train.cpp.

§ OpenTempFile()

lout<< "*********************************************" << endl; lout << " TRF_SAtrain.exe { by Bin Wang } " << endl; lout << "\t" << __DATE__ << "\t" << __TIME__ << "\t" << endl; lout << "**********************************************" << endl; omp_set_num_threads(cfg_nThread); lout << "[OMP] omp_thread = " << omp_get_max_threads() << endl; trf::omp_rand(cfg_nThread); Vocab *pv = new Vocab(cfg_pathVocab); Model m(pv, cfg_nHLayer, cfg_nHNode, cfg_nMaxLen); if (cfg_pathModelRead) { m.ReadT(cfg_pathModelRead); } else { m.LoadFromCorpus(cfg_pathTrain, cfg_pathFeatStyle, cfg_nFeatOrder); } lout_variable(m.m_hlayer); lout_variable(m.m_hnode); lout_variable(m.GetParamNum()); trf::CorpusTxt *pTrain = (cfg_pathTrain) ? new trf::CorpusTxt(cfg_pathTrain) : NULL; trf::CorpusTxt *pValid = (cfg_pathValid) ? new trf::CorpusTxt(cfg_pathValid) : NULL; trf::CorpusTxt *pTest = (cfg_pathTest) ? new trf::CorpusTxt(cfg_pathTest) : NULL; Train *pFunc; if (cfg_bUpdateZeta) { pFunc = new SAMSZeta; } else if (cfg_bUpdateLambda) { pFunc = new SALambda; } pFunc-> OpenTempFile ( cfg_pathModelWrite  )

§ Parse()

opt Parse ( _argc  ,
_argv   
)

§ Run()

pFunc Run ( bInit  )

§ SAFE_DELETE() [1/4]

SAFE_DELETE ( pTrain  )

§ SAFE_DELETE() [2/4]

SAFE_DELETE ( pValid  )

§ SAFE_DELETE() [3/4]

SAFE_DELETE ( pTest  )

§ SAFE_DELETE() [4/4]

SAFE_DELETE ( pv  )

§ VecUnfold()

VecUnfold ( cfg_strWriteAtIter  ,
pFunc->m_aWriteAtIter 
)

§ WriteT()

m WriteT ( cfg_pathModelWrite  )

Variable Documentation

§ _wbMain

_wbMain
Initial value:
{
opt.Add(wbOPT_STRING, "vocab", &cfg_pathVocab, "The vocabulary")
Option opt
char * cfg_pathVocab

Definition at line 72 of file main-sa-train.cpp.

§ bInit

bool bInit = (!cfg_pathModelRead) || cfg_bInitValue

Definition at line 186 of file main-sa-train.cpp.

§ cfg_AIS_for_LL

AISConfig cfg_AIS_for_LL

Definition at line 67 of file main-sa-train.cpp.

§ cfg_bInitValue

bool cfg_bInitValue = false

Definition at line 56 of file main-sa-train.cpp.

§ cfg_bUnprintTest

bool cfg_bUnprintTest = false

Definition at line 61 of file main-sa-train.cpp.

§ cfg_bUnprintTrain

bool cfg_bUnprintTrain = false

Definition at line 59 of file main-sa-train.cpp.

§ cfg_bUnprintValid

bool cfg_bUnprintValid = false

Definition at line 60 of file main-sa-train.cpp.

§ cfg_bUpdateLambda

bool cfg_bUpdateLambda = false

Definition at line 50 of file main-sa-train.cpp.

§ cfg_bUpdateZeta

bool cfg_bUpdateZeta = false

Definition at line 51 of file main-sa-train.cpp.

§ cfg_bZeroInit

bool cfg_bZeroInit = false

Definition at line 57 of file main-sa-train.cpp.

§ cfg_dir_gap

float cfg_dir_gap = 1

Definition at line 48 of file main-sa-train.cpp.

§ cfg_fMomentum

float cfg_fMomentum = 0

Definition at line 46 of file main-sa-train.cpp.

§ cfg_fRegL2

float cfg_fRegL2 = 0

Definition at line 54 of file main-sa-train.cpp.

§ cfg_gamma_hidden

char* cfg_gamma_hidden = "100,0.8"

Definition at line 43 of file main-sa-train.cpp.

§ cfg_gamma_lambda

char* cfg_gamma_lambda = "0,0.8"

Definition at line 42 of file main-sa-train.cpp.

§ cfg_gamma_var

char* cfg_gamma_var = "0,0.8"

Definition at line 45 of file main-sa-train.cpp.

§ cfg_gamma_zeta

char* cfg_gamma_zeta = "0,0.6"

Definition at line 44 of file main-sa-train.cpp.

§ cfg_nAvgBeg

int cfg_nAvgBeg = 0

Definition at line 52 of file main-sa-train.cpp.

§ cfg_nFeatOrder

int cfg_nFeatOrder = 2

Definition at line 26 of file main-sa-train.cpp.

§ cfg_nHLayer

int cfg_nHLayer = 1

Definition at line 24 of file main-sa-train.cpp.

§ cfg_nHNode

int cfg_nHNode = 2

Definition at line 25 of file main-sa-train.cpp.

§ cfg_nIterTotalNum

int cfg_nIterTotalNum = 1000

Definition at line 39 of file main-sa-train.cpp.

§ cfg_nMaxLen

int cfg_nMaxLen = 0

Definition at line 28 of file main-sa-train.cpp.

§ cfg_nMiniBatch

int cfg_nMiniBatch = 300

Definition at line 40 of file main-sa-train.cpp.

§ cfg_nPrintPerIter

int cfg_nPrintPerIter = 100

Definition at line 58 of file main-sa-train.cpp.

§ cfg_nThread

int cfg_nThread = 1

Definition at line 37 of file main-sa-train.cpp.

§ cfg_pathFeatStyle

char* cfg_pathFeatStyle = NULL

Definition at line 27 of file main-sa-train.cpp.

§ cfg_pathModelRead

char* cfg_pathModelRead = NULL

Definition at line 34 of file main-sa-train.cpp.

§ cfg_pathModelWrite

char* cfg_pathModelWrite = NULL

Definition at line 35 of file main-sa-train.cpp.

§ cfg_pathTest

char* cfg_pathTest = NULL

Definition at line 32 of file main-sa-train.cpp.

§ cfg_pathTrain

char* cfg_pathTrain = NULL

Definition at line 30 of file main-sa-train.cpp.

§ cfg_pathValid

char* cfg_pathValid = NULL

Definition at line 31 of file main-sa-train.cpp.

§ cfg_pathVocab

char* cfg_pathVocab = NULL

Definition at line 23 of file main-sa-train.cpp.

§ cfg_pathWriteMean

char* cfg_pathWriteMean = NULL

Definition at line 64 of file main-sa-train.cpp.

§ cfg_pathWriteVar

char* cfg_pathWriteVar = NULL

Definition at line 65 of file main-sa-train.cpp.

§ cfg_strWriteAtIter

char* cfg_strWriteAtIter = NULL

Definition at line 62 of file main-sa-train.cpp.

§ cfg_t0

int cfg_t0 = 500

Definition at line 41 of file main-sa-train.cpp.

§ cfg_var_gap

float cfg_var_gap = 1e-4

Definition at line 47 of file main-sa-train.cpp.

§ cfg_zeta_gap

float cfg_zeta_gap = 10

Definition at line 49 of file main-sa-train.cpp.

§ m

pFunc Reset& m

Definition at line 156 of file main-sa-train.cpp.

§ m_AISConfigForP

pFunc m_AISConfigForP = cfg_AIS_for_LL

Definition at line 166 of file main-sa-train.cpp.

§ m_AISConfigForZ

pFunc m_AISConfigForZ = cfg_AIS_for_LL

Definition at line 167 of file main-sa-train.cpp.

§ m_aPrint

pFunc m_aPrint [0] = !cfg_bUnprintTrain

Definition at line 160 of file main-sa-train.cpp.

§ m_fRegL2

pFunc m_fRegL2 = cfg_fRegL2

Definition at line 159 of file main-sa-train.cpp.

§ m_nAvgBeg

pFunc m_nAvgBeg = cfg_nAvgBeg

Definition at line 158 of file main-sa-train.cpp.

§ m_nIterMax

pFunc m_nIterMax = cfg_nIterTotalNum

Definition at line 165 of file main-sa-train.cpp.

§ m_nMinibatch

pFunc m_nMinibatch = cfg_nMiniBatch

Definition at line 157 of file main-sa-train.cpp.

§ m_nPrintPerIter

pFunc m_nPrintPerIter = cfg_nPrintPerIter

Definition at line 163 of file main-sa-train.cpp.

§ opt

Option opt

Definition at line 69 of file main-sa-train.cpp.

§ return

return

Definition at line 198 of file main-sa-train.cpp.