TRF Language Model
#include <trf-sa-train.h>
Public Member Functions | |
SAtrain (SAfunc *pfunc=NULL) | |
virtual bool | Run (const double *pInitParams=NULL) |
Run the iterations, starting from the given initial parameters.
void | UpdateGamma (int nIterNum) |
Update the learning rate.
void | UpdateDir (double *pDir, double *pGradient, const double *pParam) |
Compute the update direction.
virtual void | Update (double *pdParam, const double *pdDir, double dStep) |
Update the parameters.
void | PrintInfo () |
Print information.
Public Member Functions inherited from wb::Solve | |
Solve (Func *pfunc=NULL, double dtol=1e-5) | |
virtual void | IterInit () |
Initialize the iteration; intended for derived classes.
virtual void | ComputeDir (int k, const double *pdParam, const double *pdGradient, double *pdDir) |
Calculate the update direction p_k.
virtual double | LineSearch (double *pdDir, double dValue, const double *pdParam, const double *pdGradient) |
Line search.
virtual bool | StopDecision (int k, double dValue, const double *pdGradient) |
Stop decision.
Public Attributes | |
LearningRate | m_gain_lambda |
LearningRate | m_gain_zeta |
bool | m_bUpdate_lambda |
bool | m_bUpdate_zeta |
double | m_dir_gap |
Controls the direction values.
int | m_nAvgBeg |
If > 0, calculate the average.
float | m_fEpochNun |
The current epoch number, as a floating-point value.
int | m_nPrintPerIter |
Output the LL per iteration; if == -1, output is disabled.
wb::Array< int > | m_aWriteAtIter |
Output a temporary model at selected iterations.
Public Attributes inherited from wb::Solve | |
Func * | m_pfunc |
Pointer to the function.
double * | m_pdRoot |
Saves the root of the function.
int | m_nIterNum |
Current iteration number; iterates from m_nIterMin to m_nIterMax.
int | m_nIterMin |
Minimum iteration number.
int | m_nIterMax |
Maximum iteration number.
double | m_dSpendMinute |
Records the time spent iterating (minutes).
double | m_dStop |
Stop threshold.
double | m_dGain |
Iteration step; == 0 means the line search is used.
Protected Attributes | |
double | m_gamma_lambda |
double | m_gamma_zeta |
Protected Attributes inherited from wb::Solve | |
const char * | m_pAlgorithmName |
The algorithm name.
Additional Inherited Members | |
Static Public Member Functions inherited from wb::Solve | |
static double | VecProduct (const double *pdVec1, const double *pdVec2, int nSize) |
Calculate the dot product of two vectors.
static double | VecNorm (const double *pdVec, int nSize) |
Calculate the norm of a vector.
static double | VecDist (const double *pdVec1, const double *pdVec2, int nSize) |
Calculate the distance between two vectors.
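The three static helpers can be illustrated with a standalone sketch. The signatures follow the listing above, but the bodies are assumptions: in particular, the norm and distance are taken to be Euclidean, which this page does not state. The namespace sketch is hypothetical and only keeps the names apart from wb::Solve.

```cpp
#include <cmath>

namespace sketch {

// Dot product of two vectors of length nSize.
double VecProduct(const double *pdVec1, const double *pdVec2, int nSize) {
    double sum = 0.0;
    for (int i = 0; i < nSize; ++i) {
        sum += pdVec1[i] * pdVec2[i];
    }
    return sum;
}

// Norm of a vector; assumed here to be the Euclidean (L2) norm.
double VecNorm(const double *pdVec, int nSize) {
    return std::sqrt(VecProduct(pdVec, pdVec, nSize));
}

// Distance between two vectors; assumed here to be the Euclidean distance.
double VecDist(const double *pdVec1, const double *pdVec2, int nSize) {
    double sum = 0.0;
    for (int i = 0; i < nSize; ++i) {
        const double diff = pdVec1[i] - pdVec2[i];
        sum += diff * diff;
    }
    return std::sqrt(sum);
}

}  // namespace sketch
```

Helpers of this kind are typically used in a stop decision, for example by comparing the gradient norm against a threshold such as m_dStop.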
Definition at line 188 of file trf-sa-train.h.
trf::SAtrain::SAtrain | ( | SAfunc * | pfunc = NULL | ) |
inline |
Definition at line 222 of file trf-sa-train.h.
void trf::SAtrain::PrintInfo | ( | ) |
Print Information.
Definition at line 795 of file trf-sa-train.cpp.
virtual bool trf::SAtrain::Run | ( | const double * | pInitParams = NULL | ) |
virtual |
Run the iterations, starting from the given initial parameters.
Stages marked in the source comments: init, set average, set back.
Reimplemented from wb::Solve.
Definition at line 550 of file trf-sa-train.cpp.
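virtual void trf::SAtrain::Update | ( | double * | pdParam, |
const double * | pdDir, | ||
double | dStep | ||
) |
virtual |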
Update the parameters.
Reimplemented from wb::Solve.
Definition at line 761 of file trf-sa-train.cpp.
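A parameter update of this shape is commonly a step along the direction vector. The following is a minimal sketch under that assumption, not the code in trf-sa-train.cpp: the additive form, the sign convention, and the extra nSize argument are all hypothetical (the documented method takes no size argument).

```cpp
namespace sketch {

// Assumed update rule: move each parameter by dStep along pdDir.
// nSize is a hypothetical extra argument so the sketch is self-contained.
void Update(double *pdParam, const double *pdDir, double dStep, int nSize) {
    for (int i = 0; i < nSize; ++i) {
        pdParam[i] += dStep * pdDir[i];
    }
}

}  // namespace sketch
```

Calling sketch::Update(p, d, 2.0, 2) with p = {1, 2} and d = {0.5, -1} leaves p = {2, 0}.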
void trf::SAtrain::UpdateDir | ( | double * | pDir, |
double * | pGradient, | ||
const double * | pParam | ||
) |
compute the update direction
Definition at line 701 of file trf-sa-train.cpp.
void trf::SAtrain::UpdateGamma | ( | int | nIterNum | ) |
Update the learning rate.
Definition at line 692 of file trf-sa-train.cpp.
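This page does not show how the learning rate depends on nIterNum. As an illustration only, a decaying schedule often used in stochastic approximation is gamma_t = 1 / (tc + t^beta). The struct DecaySchedule and its parameters tc and beta are hypothetical and are not the library's LearningRate type.

```cpp
#include <cmath>

namespace sketch {

// Hypothetical decaying schedule; not the LearningRate type used by SAtrain.
struct DecaySchedule {
    double tc;    // assumed offset that damps the earliest steps
    double beta;  // assumed decay exponent

    // Learning rate at iteration t (t >= 1): 1 / (tc + t^beta).
    double Get(int t) const {
        return 1.0 / (tc + std::pow(static_cast<double>(t), beta));
    }
};

}  // namespace sketch
```

With tc = 0 and beta = 1 this reduces to the classic 1/t schedule; a smaller beta makes the rate decay more slowly.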
wb::Array<int> trf::SAtrain::m_aWriteAtIter |
output temp model at some iteration
Definition at line 206 of file trf-sa-train.h.
bool trf::SAtrain::m_bUpdate_lambda |
Definition at line 199 of file trf-sa-train.h.
bool trf::SAtrain::m_bUpdate_zeta |
Definition at line 200 of file trf-sa-train.h.
double trf::SAtrain::m_dir_gap |
Controls the direction values.
Definition at line 201 of file trf-sa-train.h.
float trf::SAtrain::m_fEpochNun |
The current epoch number, as a floating-point value.
Definition at line 204 of file trf-sa-train.h.
LearningRate trf::SAtrain::m_gain_lambda |
Definition at line 195 of file trf-sa-train.h.
LearningRate trf::SAtrain::m_gain_zeta |
Definition at line 196 of file trf-sa-train.h.
double trf::SAtrain::m_gamma_lambda |
protected |
Definition at line 191 of file trf-sa-train.h.
double trf::SAtrain::m_gamma_zeta |
protected |
Definition at line 192 of file trf-sa-train.h.
int trf::SAtrain::m_nAvgBeg |
if >0, then calculate the average
Definition at line 203 of file trf-sa-train.h.
int trf::SAtrain::m_nPrintPerIter |
Output the LL per iteration; if == -1, output is disabled.
Definition at line 205 of file trf-sa-train.h.