TRF Language Model
#include <hrf-sa-train.h>
Public Member Functions | |
SAtrain (SAfunc *pfunc=NULL) | |
virtual bool | Run (const double *pInitParams=NULL) |
Run the iteration, starting from the given initial parameters.
void | UpdateGamma (int nIterNum) |
Update the learning rate.
void | UpdateDir (double *pDir, double *pGradient, const double *pParam) |
Compute the update direction.
virtual void | Update (double *pdParam, const double *pdDir, double dStep) |
Update the parameters.
void | PrintInfo () |
Print information.
int | CutValue (double *p, int num, double gap) |
Cut array.
Public Member Functions inherited from wb::Solve | |
Solve (Func *pfunc=NULL, double dtol=1e-5) | |
virtual void | IterInit () |
Initialize the iteration; intended for derived classes.
virtual void | ComputeDir (int k, const double *pdParam, const double *pdGradient, double *pdDir) |
Calculate the update direction p_k.
virtual double | LineSearch (double *pdDir, double dValue, const double *pdParam, const double *pdGradient) |
Line search.
virtual bool | StopDecision (int k, double dValue, const double *pdGradient) |
Stop decision.
Public Attributes | |
LearningRate | m_gain_lambda |
LearningRate | m_gain_hidden |
LearningRate | m_gain_zeta |
bool | m_bUpdate_lambda |
bool | m_bUpdate_zeta |
double | m_dir_gap |
double | m_zeta_gap |
float | m_fMomentum |
The momentum.
int | m_nAvgBeg |
If > 0, calculate the average.
double | m_fEpochNum |
The current epoch number.
int | m_nPrintPerIter |
Output the LL per iteration; if == -1, output is disabled.
wb::Array< int > | m_aWriteAtIter |
Output a temporary model at certain iterations.
Public Attributes inherited from wb::Solve | |
Func * | m_pfunc |
Pointer to the function.
double * | m_pdRoot |
Saves the root of the function.
int | m_nIterNum |
Current iteration number; iterates from m_nIterMin to m_nIterMax.
int | m_nIterMin |
Minimum iteration number.
int | m_nIterMax |
Maximum iteration number.
double | m_dSpendMinute |
Records the time spent iterating (minutes).
double | m_dStop |
Stop threshold.
double | m_dGain |
Iteration step; == 0 means the line search is used.
Protected Attributes | |
double | m_gamma_lambda |
double | m_gamma_hidden |
double | m_gamma_zeta |
Protected Attributes inherited from wb::Solve | |
const char * | m_pAlgorithmName |
The algorithm name.
Additional Inherited Members | |
Static Public Member Functions inherited from wb::Solve | |
static double | VecProduct (const double *pdVec1, const double *pdVec2, int nSize) |
Calculate the dot product of two vectors.
static double | VecNorm (const double *pdVec, int nSize) |
Calculate the norm of a vector.
static double | VecDist (const double *pdVec1, const double *pdVec2, int nSize) |
Calculate the distance between two vectors.
Definition at line 225 of file hrf-sa-train.h.
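The three static helpers inherited from wb::Solve are described only in one line each. As a rough illustration of what such helpers usually compute, here is a minimal sketch. The `...Sketch` names are hypothetical stand-ins, and the choice of the Euclidean norm and distance is an assumption, since the reference says only "norm" and "distance".

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-ins for wb::Solve::VecProduct / VecNorm / VecDist.
// Assumption: "norm" and "distance" mean the Euclidean (L2) versions.

// Dot product of two vectors of length nSize.
double VecProductSketch(const double *pdVec1, const double *pdVec2, int nSize) {
    assert(nSize >= 0);
    double sum = 0.0;
    for (int i = 0; i < nSize; ++i) sum += pdVec1[i] * pdVec2[i];
    return sum;
}

// Euclidean norm: square root of the dot product of the vector with itself.
double VecNormSketch(const double *pdVec, int nSize) {
    return std::sqrt(VecProductSketch(pdVec, pdVec, nSize));
}

// Euclidean distance: norm of the element-wise difference.
double VecDistSketch(const double *pdVec1, const double *pdVec2, int nSize) {
    assert(nSize >= 0);
    double sum = 0.0;
    for (int i = 0; i < nSize; ++i) {
        const double d = pdVec1[i] - pdVec2[i];
        sum += d * d;
    }
    return std::sqrt(sum);
}
```

Helpers of this kind are typically used in a stop decision, for example by comparing the gradient norm against a threshold such as m_dStop.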
hrf::SAtrain::SAtrain | ( | SAfunc * | pfunc = NULL | ) | inline
Definition at line 258 of file hrf-sa-train.h.
int hrf::SAtrain::CutValue | ( | double * | p, |
int | num, | ||
double | gap | ||
) |
cut array
Definition at line 1291 of file hrf-sa-train.cpp.
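The description "cut array" does not say what CutValue does with `gap`. One plausible reading, given the m_dir_gap and m_zeta_gap attributes, is that it clips each element to the range [-gap, gap] and reports how many elements were changed. The sketch below illustrates that reading only; both the clipping rule and the meaning of the return value are assumptions, and `CutValueSketch` is a hypothetical name.

```cpp
#include <cassert>

// Hypothetical illustration of a "cut" operation on an array.
// Assumption: values are clipped to [-gap, gap] and the return value is
// the number of elements that were clipped.
int CutValueSketch(double *p, int num, double gap) {
    assert(num >= 0 && gap >= 0.0);
    int nCut = 0;
    for (int i = 0; i < num; ++i) {
        if (p[i] > gap) {
            p[i] = gap;       // clip from above
            ++nCut;
        } else if (p[i] < -gap) {
            p[i] = -gap;      // clip from below
            ++nCut;
        }
    }
    return nCut;
}
```

Clipping of this sort is a common way to keep a single noisy stochastic step from moving the parameters too far.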
void hrf::SAtrain::PrintInfo | ( | ) |
Print Information.
Definition at line 1277 of file hrf-sa-train.cpp.
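bool hrf::SAtrain::Run | ( | const double * | pInitParams = NULL | ) | virtual
Run the iteration, starting from the given initial parameters.
Stages noted in the source comments: init; set average; set back.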
Reimplemented from wb::Solve.
Definition at line 1015 of file hrf-sa-train.cpp.
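void hrf::SAtrain::Update | ( | double * | pdParam, |
const double * | pdDir, | ||
double | dStep | ||
) | virtual
Update the parameters.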
Reimplemented from wb::Solve.
Definition at line 1227 of file hrf-sa-train.cpp.
void hrf::SAtrain::UpdateDir | ( | double * | pDir, |
double * | pGradient, | ||
const double * | pParam | ||
) |
compute the update direction
Definition at line 1182 of file hrf-sa-train.cpp.
void hrf::SAtrain::UpdateGamma | ( | int | nIterNum | ) |
Update the learning rate.
Definition at line 1161 of file hrf-sa-train.cpp.
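This page does not show how the learning rate is computed from the iteration number. A common choice in stochastic approximation is a power-law decay, gamma_t = 1 / (tc + t^beta). The sketch below assumes that form; the `LearningRateSketch` and `GammaSketch` types and the `tc` and `beta` fields are hypothetical and are not taken from the real LearningRate class.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical schedule. The fields of the real LearningRate type are not
// shown in this reference, so tc and beta are illustrative assumptions.
struct LearningRateSketch {
    double tc;    // non-negative offset that damps the earliest steps
    double beta;  // decay exponent
    // gamma_t = 1 / (tc + t^beta), for iteration numbers t >= 1
    double Get(int t) const {
        assert(t >= 1);
        return 1.0 / (tc + std::pow(static_cast<double>(t), beta));
    }
};

// Analogue of UpdateGamma(nIterNum): refresh one rate per parameter group
// from that group's own schedule.
struct GammaSketch {
    double gamma_lambda = 0.0;
    double gamma_zeta = 0.0;
    void Update(const LearningRateSketch &gainLambda,
                const LearningRateSketch &gainZeta, int nIterNum) {
        gamma_lambda = gainLambda.Get(nIterNum);
        gamma_zeta = gainZeta.Get(nIterNum);
    }
};
```

Keeping a separate schedule per parameter group mirrors the separate m_gain_lambda, m_gain_hidden, and m_gain_zeta attributes, each paired with its own m_gamma_* value.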
wb::Array<int> hrf::SAtrain::m_aWriteAtIter |
Output a temporary model at certain iterations.
Definition at line 250 of file hrf-sa-train.h.
bool hrf::SAtrain::m_bUpdate_lambda |
Definition at line 238 of file hrf-sa-train.h.
bool hrf::SAtrain::m_bUpdate_zeta |
Definition at line 239 of file hrf-sa-train.h.
double hrf::SAtrain::m_dir_gap |
Definition at line 241 of file hrf-sa-train.h.
double hrf::SAtrain::m_fEpochNum |
the current epoch number
Definition at line 247 of file hrf-sa-train.h.
float hrf::SAtrain::m_fMomentum |
the momentum
Definition at line 244 of file hrf-sa-train.h.
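How m_fMomentum enters the update is not documented here. For orientation, the sketch below shows the classic heavy-ball form, in which the previous direction is scaled by the momentum and the new gradient is added. The function names are hypothetical, and the sign convention (subtracting the step) is an assumption that may differ from what UpdateDir and Update actually do.

```cpp
#include <cassert>

// Hypothetical heavy-ball momentum step; not the library's UpdateDir.
// pDir holds the previous direction on entry and the new one on exit:
//   dir <- momentum * dir + gradient
void MomentumDirSketch(double *pDir, const double *pGradient, int n,
                       double momentum) {
    assert(n >= 0);
    for (int i = 0; i < n; ++i)
        pDir[i] = momentum * pDir[i] + pGradient[i];
}

// Apply the direction with step size dStep.
// Assumption: the parameters move against the direction (descent).
void ApplyStepSketch(double *pdParam, const double *pdDir, int n,
                     double dStep) {
    assert(n >= 0);
    for (int i = 0; i < n; ++i)
        pdParam[i] -= dStep * pdDir[i];
}
```

With momentum set to 0 the direction reduces to the plain gradient, which makes the effect of the attribute easy to isolate when experimenting.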
LearningRate hrf::SAtrain::m_gain_hidden |
Definition at line 235 of file hrf-sa-train.h.
LearningRate hrf::SAtrain::m_gain_lambda |
Definition at line 234 of file hrf-sa-train.h.
LearningRate hrf::SAtrain::m_gain_zeta |
Definition at line 236 of file hrf-sa-train.h.
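double hrf::SAtrain::m_gamma_hidden | protected
Definition at line 229 of file hrf-sa-train.h.
double hrf::SAtrain::m_gamma_lambda | protected
Definition at line 228 of file hrf-sa-train.h.
double hrf::SAtrain::m_gamma_zeta | protected
Definition at line 230 of file hrf-sa-train.h.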
int hrf::SAtrain::m_nAvgBeg |
if >0, then calculate the average
Definition at line 245 of file hrf-sa-train.h.
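The description suggests that parameter averaging is switched on by a positive m_nAvgBeg. A minimal sketch of iterate averaging is given below, assuming the average is a running mean of the parameters from iteration nAvgBeg onward, with that iteration included. The `ParamAveragerSketch` type is hypothetical, and the actual averaging rule in Run may differ.

```cpp
#include <cassert>
#include <vector>

// Hypothetical running average of the parameter vector.
// Assumption: averaging is disabled when nAvgBeg <= 0 and otherwise covers
// every iteration with nIter >= nAvgBeg.
struct ParamAveragerSketch {
    int nAvgBeg;              // first iteration to include; <= 0 disables
    int nCount;               // number of iterates averaged so far
    std::vector<double> avg;  // current running mean

    ParamAveragerSketch(int beg, int nParams)
        : nAvgBeg(beg), nCount(0), avg(nParams, 0.0) {}

    // Fold the parameters of iteration nIter into the running mean.
    void Accumulate(int nIter, const double *pdParam) {
        if (nAvgBeg <= 0 || nIter < nAvgBeg) return;
        ++nCount;
        for (std::size_t i = 0; i < avg.size(); ++i)
            avg[i] += (pdParam[i] - avg[i]) / nCount;  // incremental mean
    }
};
```

The incremental form avoids storing past iterates: after k accumulated iterations, `avg` equals the mean of those k parameter vectors.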
int hrf::SAtrain::m_nPrintPerIter |
Output the LL per iteration; if == -1, output is disabled.
Definition at line 249 of file hrf-sa-train.h.
double hrf::SAtrain::m_zeta_gap |
Definition at line 242 of file hrf-sa-train.h.