TRF Language Model
trf::SAtrain Class Reference

#include <trf-sa-train.h>

Inheritance diagram for trf::SAtrain: inherits from wb::Solve.

Public Member Functions

 SAtrain (SAfunc *pfunc=NULL)
 
virtual bool Run (const double *pInitParams=NULL)
 Run the iteration, starting from the input initial parameters.
 
void UpdateGamma (int nIterNum)
 Update the learning rate.
 
void UpdateDir (double *pDir, double *pGradient, const double *pParam)
 compute the update direction
 
virtual void Update (double *pdParam, const double *pdDir, double dStep)
 Update the parameters.
 
void PrintInfo ()
 Print Information.
 
- Public Member Functions inherited from wb::Solve
 Solve (Func *pfunc=NULL, double dtol=1e-5)
 
virtual void IterInit ()
 Initialize the iteration; intended to be overridden by derived classes.
 
virtual void ComputeDir (int k, const double *pdParam, const double *pdGradient, double *pdDir)
 Calculate the update direction p_k.
 
virtual double LineSearch (double *pdDir, double dValue, const double *pdParam, const double *pdGradient)
 Line search.
 
virtual bool StopDecision (int k, double dValue, const double *pdGradient)
 Stop decision.
 

Public Attributes

LearningRate m_gain_lambda
 
LearningRate m_gain_zeta
 
bool m_bUpdate_lambda
 
bool m_bUpdate_zeta
 
double m_dir_gap
 Controls the values of the update direction.
 
int m_nAvgBeg
 If > 0, the average is calculated.
 
float m_fEpochNun
 The current epoch number (real-valued).
 
int m_nPrintPerIter
 Output the LL per iteration; if -1, output is disabled.
 
wb::Array< int > m_aWriteAtIter
 Output a temporary model at the specified iterations.
 
- Public Attributes inherited from wb::Solve
Func * m_pfunc
 Pointer to the function.
 
double * m_pdRoot
 Saves the root of the function.
 
int m_nIterNum
 Current iteration number; iterates from m_nIterMin to m_nIterMax.
 
int m_nIterMin
 Minimum iteration number.
 
int m_nIterMax
 Maximum iteration number.
 
double m_dSpendMinute
 Time spent on the iterations (in minutes).
 
double m_dStop
 Stop threshold.
 
double m_dGain
 Iteration step size; 0 means line search is used.
 

Protected Attributes

double m_gamma_lambda
 
double m_gamma_zeta
 
- Protected Attributes inherited from wb::Solve
const char * m_pAlgorithmName
 the algorithm name.
 

Additional Inherited Members

- Static Public Member Functions inherited from wb::Solve
static double VecProduct (const double *pdVec1, const double *pdVec2, int nSize)
 Calculate the dot product of two vectors.
 
static double VecNorm (const double *pdVec, int nSize)
 Calculate the norm of a vector.
 
static double VecDist (const double *pdVec1, const double *pdVec2, int nSize)
 Calculate the distance between two vectors.
 
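The three static helpers above are standard vector operations. The sketch below shows one plausible implementation as free functions. The names VecProductSketch, VecNormSketch and VecDistSketch are hypothetical, and the Euclidean (L2) definitions of norm and distance are an assumption, since the reference only says "norm" and "distance".

```cpp
#include <cassert>
#include <cmath>

// Hypothetical free-function counterparts of wb::Solve::VecProduct,
// VecNorm and VecDist. The L2 (Euclidean) definitions are an assumption.
double VecProductSketch(const double *pdVec1, const double *pdVec2, int nSize)
{
    double dSum = 0;
    for (int i = 0; i < nSize; i++)
        dSum += pdVec1[i] * pdVec2[i];   // accumulate element-wise products
    return dSum;
}

double VecNormSketch(const double *pdVec, int nSize)
{
    assert(nSize >= 0);
    return std::sqrt(VecProductSketch(pdVec, pdVec, nSize));  // ||v||_2
}

double VecDistSketch(const double *pdVec1, const double *pdVec2, int nSize)
{
    double dSum = 0;
    for (int i = 0; i < nSize; i++) {
        double d = pdVec1[i] - pdVec2[i];
        dSum += d * d;                   // squared component difference
    }
    return std::sqrt(dSum);              // ||v1 - v2||_2
}
```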

Detailed Description

Definition at line 188 of file trf-sa-train.h.

Constructor & Destructor Documentation

§ SAtrain()

trf::SAtrain::SAtrain ( SAfunc *  pfunc = NULL)
inline

Definition at line 222 of file trf-sa-train.h.

Member Function Documentation

§ PrintInfo()

void trf::SAtrain::PrintInfo ( )

Print Information.

Definition at line 795 of file trf-sa-train.cpp.

§ Run()

bool trf::SAtrain::Run ( const double *  pInitParams = NULL)
virtual

Run the iteration, starting from the input initial parameters.

In-body comments: init; set average; set back.

Reimplemented from wb::Solve.

Definition at line 550 of file trf-sa-train.cpp.
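To illustrate how the iteration-control members (m_nIterMin, m_nIterMax, m_dStop) and a decaying gain can fit together, here is a generic stochastic-approximation loop. It is a sketch under stated assumptions, not trf's Run: the function name RunSketch, the gain schedule dGain / k, and the gradient-norm stop test are all hypothetical, and the update is written as minimization (ascent on a log-likelihood only flips the sign).

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using Vec = std::vector<double>;

// Hypothetical SA-style driver loop (not trf's actual Run): at iteration k,
//   theta <- theta - (dGain / k) * grad(theta),
// stopping once k >= nIterMin and ||grad|| < dStop, or after nIterMax iterations.
// Returns the iteration at which the loop stopped.
int RunSketch(Vec &theta, const std::function<Vec(const Vec &)> &grad,
              double dGain, int nIterMin, int nIterMax, double dStop)
{
    assert(dGain > 0 && nIterMax >= 1);
    for (int k = 1; k <= nIterMax; k++) {
        Vec g = grad(theta);
        assert(g.size() == theta.size());
        double dNorm = 0;
        for (double v : g)
            dNorm += v * v;
        dNorm = std::sqrt(dNorm);              // L2 norm of the gradient
        if (k >= nIterMin && dNorm < dStop)    // stop decision
            return k;
        double dGamma = dGain / k;             // decaying learning rate
        for (std::size_t i = 0; i < theta.size(); i++)
            theta[i] -= dGamma * g[i];         // parameter update
    }
    return nIterMax;
}
```

With the gradient of (theta - 3)^2 / 2 and dGain = 1, the first step lands exactly on 3, and the loop then stops at the first iteration that satisfies both the m_nIterMin bound and the threshold.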

§ Update()

void trf::SAtrain::Update ( double *  pdParam,
const double *  pdDir,
double  dStep 
)
virtual

Update the parameters.

Reimplemented from wb::Solve.

Definition at line 761 of file trf-sa-train.cpp.
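Update applies a step of size dStep along the direction pdDir. A minimal sketch, assuming the parameter vector stores a lambda block followed by a zeta block, and that flags in the spirit of m_bUpdate_lambda and m_bUpdate_zeta gate each block. UpdateSketch, nLambda and nZeta are hypothetical names; the block layout is not stated in this reference.

```cpp
#include <cassert>

// Hypothetical layout: pdParam = [lambda_0 .. lambda_{nLambda-1}, zeta_0 .. zeta_{nZeta-1}].
// bUpdateLambda / bUpdateZeta mimic the roles of m_bUpdate_lambda / m_bUpdate_zeta.
void UpdateSketch(double *pdParam, const double *pdDir, double dStep,
                  int nLambda, int nZeta, bool bUpdateLambda, bool bUpdateZeta)
{
    assert(nLambda >= 0 && nZeta >= 0);
    if (bUpdateLambda)
        for (int i = 0; i < nLambda; i++)
            pdParam[i] += dStep * pdDir[i];          // step the lambda block
    if (bUpdateZeta)
        for (int i = nLambda; i < nLambda + nZeta; i++)
            pdParam[i] += dStep * pdDir[i];          // step the zeta block
}
```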

§ UpdateDir()

void trf::SAtrain::UpdateDir ( double *  pDir,
double *  pGradient,
const double *  pParam 
)

compute the update direction

Definition at line 701 of file trf-sa-train.cpp.
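One plausible way to compute the update direction, assuming it scales the gradient by per-block rates (in the spirit of m_gamma_lambda and m_gamma_zeta) and that m_dir_gap bounds each component when it is positive. Both the clipping rule and the name UpdateDirSketch are assumptions, not documented behavior.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical: direction = gamma * gradient, with separate rates for the
// lambda block and the zeta block, then clipped to [-dDirGap, dDirGap]
// when dDirGap > 0 (a guess at what "control the dir values" means).
void UpdateDirSketch(double *pDir, const double *pGradient, int nLambda, int nZeta,
                     double dGammaLambda, double dGammaZeta, double dDirGap)
{
    assert(nLambda >= 0 && nZeta >= 0);
    for (int i = 0; i < nLambda + nZeta; i++) {
        double gamma = (i < nLambda) ? dGammaLambda : dGammaZeta;
        pDir[i] = gamma * pGradient[i];
        if (dDirGap > 0)
            pDir[i] = std::max(-dDirGap, std::min(dDirGap, pDir[i]));  // clip
    }
}
```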

§ UpdateGamma()

void trf::SAtrain::UpdateGamma ( int  nIterNum)

Update the learning rate.

Definition at line 692 of file trf-sa-train.cpp.
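This reference does not state the learning-rate schedule. The sketch below assumes a common stochastic-approximation form, gamma_t = a / (t0 + t)^beta. LearningRateSketch is a hypothetical stand-in for trf's LearningRate class (used by m_gain_lambda and m_gain_zeta), and UpdateGammaSketch is likewise hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical stand-in for LearningRate: a polynomially decaying gain
//   gamma_t = a / (t0 + t)^beta.
struct LearningRateSketch {
    double a = 1.0;     // initial scale
    double t0 = 0.0;    // offset that delays the decay
    double beta = 1.0;  // decay exponent, typically in (0.5, 1]
    double Get(int t) const {
        assert(t0 + t > 0);
        return a / std::pow(t0 + t, beta);
    }
};

// Sketch of UpdateGamma: refresh both rates for iteration nIterNum.
void UpdateGammaSketch(int nIterNum, const LearningRateSketch &gainLambda,
                       const LearningRateSketch &gainZeta,
                       double &gammaLambda, double &gammaZeta)
{
    gammaLambda = gainLambda.Get(nIterNum);
    gammaZeta = gainZeta.Get(nIterNum);
}
```

With beta in (0.5, 1], these step sizes satisfy the usual Robbins-Monro conditions: their sum diverges while the sum of their squares converges.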

Member Data Documentation

§ m_aWriteAtIter

wb::Array<int> trf::SAtrain::m_aWriteAtIter

Output a temporary model at the specified iterations.

Definition at line 206 of file trf-sa-train.h.
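A minimal sketch of how such a list might be consulted at each iteration. std::vector<int> stands in for wb::Array<int>, and ShouldWriteAtIter is a hypothetical helper.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical check: write a temporary model when the current iteration
// appears in the list (std::vector stands in for wb::Array<int>).
bool ShouldWriteAtIter(const std::vector<int> &aWriteAtIter, int nIterNum)
{
    assert(nIterNum >= 0);
    return std::find(aWriteAtIter.begin(), aWriteAtIter.end(), nIterNum)
           != aWriteAtIter.end();
}
```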

§ m_bUpdate_lambda

bool trf::SAtrain::m_bUpdate_lambda

Definition at line 199 of file trf-sa-train.h.

§ m_bUpdate_zeta

bool trf::SAtrain::m_bUpdate_zeta

Definition at line 200 of file trf-sa-train.h.

§ m_dir_gap

double trf::SAtrain::m_dir_gap

Controls the values of the update direction.

Definition at line 201 of file trf-sa-train.h.

§ m_fEpochNun

float trf::SAtrain::m_fEpochNun

The current epoch number (real-valued).

Definition at line 204 of file trf-sa-train.h.
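A real-valued epoch counter usually comes from mini-batch bookkeeping. The sketch assumes each iteration consumes nBatchSize training sentences out of nTrainSize; EpochNumSketch and its parameters are hypothetical.

```cpp
#include <cassert>

// Hypothetical: fractional epoch count after nIterNum mini-batch iterations,
// each consuming nBatchSize sentences from a training set of nTrainSize.
float EpochNumSketch(int nIterNum, int nBatchSize, int nTrainSize)
{
    assert(nTrainSize > 0);
    return static_cast<float>(static_cast<double>(nIterNum) * nBatchSize / nTrainSize);
}
```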

§ m_gain_lambda

LearningRate trf::SAtrain::m_gain_lambda

Definition at line 195 of file trf-sa-train.h.

§ m_gain_zeta

LearningRate trf::SAtrain::m_gain_zeta

Definition at line 196 of file trf-sa-train.h.

§ m_gamma_lambda

double trf::SAtrain::m_gamma_lambda
protected

Definition at line 191 of file trf-sa-train.h.

§ m_gamma_zeta

double trf::SAtrain::m_gamma_zeta
protected

Definition at line 192 of file trf-sa-train.h.

§ m_nAvgBeg

int trf::SAtrain::m_nAvgBeg

If > 0, the average is calculated.

Definition at line 203 of file trf-sa-train.h.
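A minimal sketch of incremental parameter averaging, assuming m_nAvgBeg is the iteration at which averaging begins (the name suggests this, but the reference does not state it). AccumulateAverage is hypothetical. The "set average" and "set back" notes under Run may correspond to swapping the averaged values in and then restoring the current ones, but that is an inference.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical parameter averaging: once nIterNum >= nAvgBeg (and nAvgBeg > 0),
// fold the current parameters into a running mean; nAvgCount tracks how many
// iterates have been averaged so far.
void AccumulateAverage(std::vector<double> &avg, int &nAvgCount,
                       const std::vector<double> &param, int nIterNum, int nAvgBeg)
{
    assert(avg.size() == param.size());
    if (nAvgBeg <= 0 || nIterNum < nAvgBeg)
        return;                                       // disabled or not started
    nAvgCount++;
    for (std::size_t i = 0; i < param.size(); i++)
        avg[i] += (param[i] - avg[i]) / nAvgCount;    // incremental mean
}
```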

§ m_nPrintPerIter

int trf::SAtrain::m_nPrintPerIter

Output the LL per iteration; if -1, output is disabled.

Definition at line 205 of file trf-sa-train.h.
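One reading of this field, assumed here, is an interval: print the LL every nPrintPerIter iterations, with -1 disabling output. ShouldPrintLL is a hypothetical helper, and treating every non-positive value as "disabled" is an extra assumption.

```cpp
#include <cassert>

// Hypothetical reading of m_nPrintPerIter: print every nPrintPerIter
// iterations; -1 (and, in this sketch, any non-positive value) disables it.
bool ShouldPrintLL(int nPrintPerIter, int nIterNum)
{
    assert(nIterNum >= 0);
    if (nPrintPerIter <= 0)
        return false;
    return nIterNum % nPrintPerIter == 0;
}
```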


The documentation for this class was generated from the following files: