TRF Language Model
#include <hrf-sa-train.h>
Public Member Functions
| | SAtrain (SAfunc *pfunc=NULL) | |
| virtual bool | Run (const double *pInitParams=NULL) | Run the iteration, starting from the given initial parameters. |
| void | UpdateGamma (int nIterNum) | Update the learning rate. |
| void | UpdateDir (double *pDir, double *pGradient, const double *pParam) | Compute the update direction. |
| virtual void | Update (double *pdParam, const double *pdDir, double dStep) | Update the parameters. |
| void | PrintInfo () | Print information. |
| int | CutValue (double *p, int num, double gap) | Cut the array. |
Public Member Functions inherited from wb::Solve
| | Solve (Func *pfunc=NULL, double dtol=1e-5) | |
| virtual void | IterInit () | Initialize the iteration; intended to be overridden by derived classes. |
| virtual void | ComputeDir (int k, const double *pdParam, const double *pdGradient, double *pdDir) | Calculate the update direction p_k. |
| virtual double | LineSearch (double *pdDir, double dValue, const double *pdParam, const double *pdGradient) | Line search. |
| virtual bool | StopDecision (int k, double dValue, const double *pdGradient) | Stop decision. |
Public Attributes
| LearningRate | m_gain_lambda | |
| LearningRate | m_gain_hidden | |
| LearningRate | m_gain_zeta | |
| bool | m_bUpdate_lambda | |
| bool | m_bUpdate_zeta | |
| double | m_dir_gap | |
| double | m_zeta_gap | |
| float | m_fMomentum | The momentum. |
| int | m_nAvgBeg | If > 0, the average is calculated. |
| double | m_fEpochNum | The current epoch number. |
| int | m_nPrintPerIter | Output the LL every this many iterations; -1 disables the output. |
| wb::Array< int > | m_aWriteAtIter | Iterations at which a temporary model is written out. |
Public Attributes inherited from wb::Solve
| Func * | m_pfunc | Pointer to the function. |
| double * | m_pdRoot | Saves the root of the function. |
| int | m_nIterNum | Current iteration number; runs from m_nIterMin to m_nIterMax. |
| int | m_nIterMin | Minimum iteration number. |
| int | m_nIterMax | Maximum iteration number. |
| double | m_dSpendMinute | Time spent on the iterations, in minutes. |
| double | m_dStop | Stop threshold. |
| double | m_dGain | Iteration step size; 0 means a line search is used. |
Protected Attributes
| double | m_gamma_lambda | |
| double | m_gamma_hidden | |
| double | m_gamma_zeta | |
Protected Attributes inherited from wb::Solve
| const char * | m_pAlgorithmName | The algorithm name. |
Additional Inherited Members
Static Public Member Functions inherited from wb::Solve
| static double | VecProduct (const double *pdVec1, const double *pdVec2, int nSize) | Calculate the dot product of two vectors. |
| static double | VecNorm (const double *pdVec, int nSize) | Calculate the norm of a vector. |
| static double | VecDist (const double *pdVec1, const double *pdVec2, int nSize) | Calculate the distance between two vectors. |
Definition at line 225 of file hrf-sa-train.h.
| hrf::SAtrain::SAtrain | ( | SAfunc * | pfunc = NULL | ) | [inline]
Definition at line 258 of file hrf-sa-train.h.
| int hrf::SAtrain::CutValue | ( | double * | p, |
| int | num, | ||
| double | gap | ||
| ) |
Cut the array.
Definition at line 1291 of file hrf-sa-train.cpp.
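The page documents CutValue only as "cut array". One plausible reading, which is an assumption not confirmed by the source, is that it clips each of the num entries of p into [-gap, gap] and returns how many entries were changed. The sketch below uses the hypothetical name CutValueSketch to keep it distinct from the real member.

```cpp
#include <cassert>

// Hypothetical reading of hrf::SAtrain::CutValue (not confirmed by the docs):
// clip p[0..num) into [-gap, gap] and return the number of modified entries.
int CutValueSketch(double *p, int num, double gap) {
    assert(gap >= 0.0);
    int nCut = 0;
    for (int i = 0; i < num; ++i) {
        if (p[i] > gap) {
            p[i] = gap;    // clip from above
            ++nCut;
        } else if (p[i] < -gap) {
            p[i] = -gap;   // clip from below
            ++nCut;
        }
    }
    return nCut;
}
```

Under this reading, clipping {0.5, -3.0, 2.0} with gap = 1.0 yields {0.5, -1.0, 1.0} and a return value of 2.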
| void hrf::SAtrain::PrintInfo | ( | ) |
Print Information.
Definition at line 1277 of file hrf-sa-train.cpp.
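| bool hrf::SAtrain::Run | ( | const double * | pInitParams = NULL | ) | [virtual]
Run the iteration, starting from the given initial parameters.
Inline notes in the implementation mark the steps "init", "set average" and "set back".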
Reimplemented from wb::Solve.
Definition at line 1015 of file hrf-sa-train.cpp.
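| void hrf::SAtrain::Update | ( | double * | pdParam, |
| const double * | pdDir, | ||
| double | dStep | ||
| ) | [virtual]
Update the parameters.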
Reimplemented from wb::Solve.
Definition at line 1227 of file hrf-sa-train.cpp.
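Given the signature Update(pdParam, pdDir, dStep), a natural assumption is a step of size dStep along the direction pdDir. The sketch below shows that basic update; UpdateSketch and its explicit nSize argument are hypothetical (the real member presumably knows the parameter count), and the real Update may treat the lambda, zeta and hidden parameter groups with different rates.

```cpp
#include <cassert>

// Hypothetical sketch: move the parameters a step of size dStep along pdDir,
// i.e. pdParam[i] += dStep * pdDir[i]. Not the library's implementation.
void UpdateSketch(double *pdParam, const double *pdDir, double dStep, int nSize) {
    assert(nSize >= 0);
    for (int i = 0; i < nSize; ++i) {
        pdParam[i] += dStep * pdDir[i];
    }
}
```

For instance, with parameters {1, 2}, direction {0.5, -1} and dStep = 2, the parameters become {2, 0}.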
| void hrf::SAtrain::UpdateDir | ( | double * | pDir, |
| double * | pGradient, | ||
| const double * | pParam | ||
| ) |
compute the update direction
Definition at line 1182 of file hrf-sa-train.cpp.
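The class exposes a momentum attribute, m_fMomentum, alongside UpdateDir. A common way to combine momentum with a stochastic gradient is dir = momentum * dir + gradient; the sketch below shows that technique under the assumption that UpdateDir works this way, which this page does not confirm. UpdateDirSketch is a hypothetical name, and any use of m_dir_gap is left out.

```cpp
#include <cassert>

// Hypothetical momentum update (assumed, not confirmed for hrf::SAtrain):
// on entry pDir holds the previous direction; on exit it holds
// fMomentum * previous direction + current gradient.
void UpdateDirSketch(double *pDir, const double *pGradient,
                     float fMomentum, int nSize) {
    assert(fMomentum >= 0.0f && fMomentum < 1.0f);
    for (int i = 0; i < nSize; ++i) {
        pDir[i] = fMomentum * pDir[i] + pGradient[i];
    }
}
```

With a momentum of 0 this reduces to using the raw gradient as the direction; larger values smooth the noisy stochastic gradients across iterations.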
| void hrf::SAtrain::UpdateGamma | ( | int | nIterNum | ) |
Update the learning rate.
Definition at line 1161 of file hrf-sa-train.cpp.
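UpdateGamma refreshes the learning rates (presumably the protected m_gamma_lambda, m_gamma_hidden and m_gamma_zeta, one per parameter group) for iteration nIterNum. The exact LearningRate schedule is not documented here; the sketch below assumes a typical stochastic-approximation schedule that stays constant for t <= t0 and then decays as gamma0 / (t - t0)^beta. The names gamma0, t0 and beta are hypothetical and are not LearningRate members.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stochastic-approximation step size:
//   gamma_t = gamma0                        for t <= t0
//   gamma_t = gamma0 / (t - t0)^beta        for t >  t0
// With 0.5 < beta <= 1 the sum of gamma_t diverges while the sum of
// gamma_t^2 converges, the usual Robbins-Monro conditions.
double LearningRateSketch(int nIterNum, double gamma0, int t0, double beta) {
    assert(nIterNum >= 1 && beta > 0.5 && beta <= 1.0);
    if (nIterNum <= t0) return gamma0;
    return gamma0 / std::pow(static_cast<double>(nIterNum - t0), beta);
}
```

For example, with gamma0 = 1, t0 = 10 and beta = 1, the rate is 1 up to iteration 10 and 0.25 at iteration 14.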
| wb::Array<int> hrf::SAtrain::m_aWriteAtIter |
Iterations at which a temporary model is written out.
Definition at line 250 of file hrf-sa-train.h.
| bool hrf::SAtrain::m_bUpdate_lambda |
Definition at line 238 of file hrf-sa-train.h.
| bool hrf::SAtrain::m_bUpdate_zeta |
Definition at line 239 of file hrf-sa-train.h.
| double hrf::SAtrain::m_dir_gap |
Definition at line 241 of file hrf-sa-train.h.
| double hrf::SAtrain::m_fEpochNum |
the current epoch number
Definition at line 247 of file hrf-sa-train.h.
| float hrf::SAtrain::m_fMomentum |
the momentum
Definition at line 244 of file hrf-sa-train.h.
| LearningRate hrf::SAtrain::m_gain_hidden |
Definition at line 235 of file hrf-sa-train.h.
| LearningRate hrf::SAtrain::m_gain_lambda |
Definition at line 234 of file hrf-sa-train.h.
| LearningRate hrf::SAtrain::m_gain_zeta |
Definition at line 236 of file hrf-sa-train.h.
| double hrf::SAtrain::m_gamma_hidden | [protected]
Definition at line 229 of file hrf-sa-train.h.
| double hrf::SAtrain::m_gamma_lambda | [protected]
Definition at line 228 of file hrf-sa-train.h.
| double hrf::SAtrain::m_gamma_zeta | [protected]
Definition at line 230 of file hrf-sa-train.h.
| int hrf::SAtrain::m_nAvgBeg |
If > 0, the average is calculated.
Definition at line 245 of file hrf-sa-train.h.
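The name m_nAvgBeg suggests the iteration at which averaging begins, and the Run notes mention "set average" and "set back". The sketch below shows running-mean parameter averaging (Polyak-style) under that assumption; ParamAverager is a hypothetical helper, not a class from this codebase.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical running-mean averager: once nIter >= nAvgBeg (and nAvgBeg > 0),
// each call folds the current parameters into the mean of all averaged iterates.
class ParamAverager {
public:
    ParamAverager(int nSize, int nAvgBeg) : avg_(nSize, 0.0), nAvgBeg_(nAvgBeg) {}

    void Add(int nIter, const std::vector<double> &param) {
        assert(param.size() == avg_.size());
        if (nAvgBeg_ <= 0 || nIter < nAvgBeg_) return;  // disabled or not started yet
        ++count_;
        for (std::size_t i = 0; i < avg_.size(); ++i) {
            avg_[i] += (param[i] - avg_[i]) / count_;  // incremental mean update
        }
    }

    const std::vector<double> &Average() const { return avg_; }
    int Count() const { return count_; }

private:
    std::vector<double> avg_;
    int nAvgBeg_;
    int count_ = 0;
};
```

With nAvgBeg = 2, feeding {10} at iteration 1, {2} at iteration 2 and {4} at iteration 3 skips the first value and leaves an average of 3 over two iterates.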
| int hrf::SAtrain::m_nPrintPerIter |
Output the LL every this many iterations; if == -1, the output is disabled.
Definition at line 249 of file hrf-sa-train.h.
| double hrf::SAtrain::m_zeta_gap |
Definition at line 242 of file hrf-sa-train.h.