TRF Language Model
hrf::SAtrain Class Reference

#include <hrf-sa-train.h>

hrf::SAtrain inherits from wb::Solve.

Public Member Functions

 SAtrain (SAfunc *pfunc=NULL)
 
virtual bool Run (const double *pInitParams=NULL)
 Run the iteration, starting from the initial parameters.
 
void UpdateGamma (int nIterNum)
 Update the learning rate.
 
void UpdateDir (double *pDir, double *pGradient, const double *pParam)
 Compute the update direction.
 
virtual void Update (double *pdParam, const double *pdDir, double dStep)
 Update the parameters.
 
void PrintInfo ()
 Print information.
 
int CutValue (double *p, int num, double gap)
 Cut array values.
 
- Public Member Functions inherited from wb::Solve
 Solve (Func *pfunc=NULL, double dtol=1e-5)
 
virtual void IterInit ()
 Initialize the iteration (for use by derived classes).
 
virtual void ComputeDir (int k, const double *pdParam, const double *pdGradient, double *pdDir)
 Calculate the update direction p_k.
 
virtual double LineSearch (double *pdDir, double dValue, const double *pdParam, const double *pdGradient)
 Line search.
 
virtual bool StopDecision (int k, double dValue, const double *pdGradient)
 Decide whether to stop the iteration.
 
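StopDecision is listed here without further detail. As a minimal sketch, assuming the decision combines the iteration bounds (m_nIterMin, m_nIterMax) with the stop threshold m_dStop applied to the Euclidean norm of the gradient, a stand-alone version might look like this. The function StopDecisionSketch and its parameter list are hypothetical, not the wb::Solve implementation.

```cpp
#include <cmath>

// Hypothetical stop test in the spirit of wb::Solve::StopDecision.
// iterMin, iterMax and stopThreshold mirror m_nIterMin, m_nIterMax and
// m_dStop; using the gradient norm as the criterion is an assumption.
bool StopDecisionSketch(int k, const double *gradient, int size,
                        int iterMin, int iterMax, double stopThreshold) {
    if (k >= iterMax) return true;   // hard upper bound on iterations
    if (k < iterMin) return false;   // always run at least iterMin iterations
    double sq = 0.0;
    for (int i = 0; i < size; ++i) sq += gradient[i] * gradient[i];
    return std::sqrt(sq) < stopThreshold; // converged when the gradient is small
}
```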

Public Attributes

LearningRate m_gain_lambda
 
LearningRate m_gain_hidden
 
LearningRate m_gain_zeta
 
bool m_bUpdate_lambda
 
bool m_bUpdate_zeta
 
double m_dir_gap
 
double m_zeta_gap
 
float m_fMomentum
 The momentum.
 
int m_nAvgBeg
 If > 0, calculate the average.
 
double m_fEpochNum
 The current epoch number.
 
int m_nPrintPerIter
 Interval (in iterations) at which the LL is output; -1 disables the output.
 
wb::Array< int > m_aWriteAtIter
 Iterations at which a temporary model is written.
 
- Public Attributes inherited from wb::Solve
Func * m_pfunc
 Pointer to the function.
 
double * m_pdRoot
 Stores the root of the function.
 
int m_nIterNum
 Current iteration number, running from m_nIterMin to m_nIterMax.
 
int m_nIterMin
 Minimum iteration number.
 
int m_nIterMax
 Maximum iteration number.
 
double m_dSpendMinute
 Time spent on the iterations (in minutes).
 
double m_dStop
 Stop threshold.
 
double m_dGain
 Iteration step size; 0 means a line search is used.
 
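The m_dGain convention (a positive value is the step, 0 selects a line search) can be illustrated with a small helper. This is a sketch for minimization, assuming a backtracking line search with the Armijo sufficient-decrease rule; that rule, the helper ChooseStep, and its constants are assumptions and not necessarily what wb::Solve::LineSearch does.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical helper: a positive gain is used directly as the step size;
// a gain of 0 falls back to a backtracking (Armijo) line search along dir.
double ChooseStep(double gain,
                  const std::function<double(const std::vector<double> &)> &f,
                  const std::vector<double> &x,
                  const std::vector<double> &grad,
                  const std::vector<double> &dir) {
    if (gain > 0) return gain;          // fixed step
    assert(gain == 0);
    const double c = 1e-4, shrink = 0.5;
    double fx = f(x);
    double slope = 0.0;                 // directional derivative grad . dir
    for (std::size_t i = 0; i < x.size(); ++i) slope += grad[i] * dir[i];
    double step = 1.0;
    std::vector<double> trial(x.size());
    for (int iter = 0; iter < 50; ++iter) {
        for (std::size_t i = 0; i < x.size(); ++i) trial[i] = x[i] + step * dir[i];
        if (f(trial) <= fx + c * step * slope) return step; // sufficient decrease
        step *= shrink;                 // otherwise halve the step and retry
    }
    return step;
}
```

For f(x) = x^2 at x = 1 with direction -2, the full step overshoots to x = -1, so the search halves it once and returns 0.5.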

Protected Attributes

double m_gamma_lambda
 
double m_gamma_hidden
 
double m_gamma_zeta
 
- Protected Attributes inherited from wb::Solve
const char * m_pAlgorithmName
 The algorithm name.
 

Additional Inherited Members

- Static Public Member Functions inherited from wb::Solve
static double VecProduct (const double *pdVec1, const double *pdVec2, int nSize)
 Calculate the dot product of two vectors.
 
static double VecNorm (const double *pdVec, int nSize)
 Calculate the norm of a vector.
 
static double VecDist (const double *pdVec1, const double *pdVec2, int nSize)
 Calculate the distance between two vectors.
 
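The three static helpers have simple contracts. The following free functions are a minimal sketch with the same signatures, assuming that "norm" and "distance" mean the Euclidean (L2) norm and distance; the actual wb::Solve code may differ.

```cpp
#include <cmath>

// Dot product of two vectors of length nSize.
double VecProduct(const double *pdVec1, const double *pdVec2, int nSize) {
    double sum = 0.0;
    for (int i = 0; i < nSize; ++i) sum += pdVec1[i] * pdVec2[i];
    return sum;
}

// Euclidean norm (assumed): sqrt of the vector's dot product with itself.
double VecNorm(const double *pdVec, int nSize) {
    return std::sqrt(VecProduct(pdVec, pdVec, nSize));
}

// Euclidean distance (assumed) between two vectors of length nSize.
double VecDist(const double *pdVec1, const double *pdVec2, int nSize) {
    double sum = 0.0;
    for (int i = 0; i < nSize; ++i) {
        double d = pdVec1[i] - pdVec2[i];
        sum += d * d;
    }
    return std::sqrt(sum);
}
```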

Detailed Description

Definition at line 225 of file hrf-sa-train.h.

Constructor & Destructor Documentation

§ SAtrain()

hrf::SAtrain::SAtrain ( SAfunc *  pfunc = NULL)
inline

Definition at line 258 of file hrf-sa-train.h.

Member Function Documentation

§ CutValue()

int hrf::SAtrain::CutValue ( double *  p,
int  num,
double  gap 
)

Cut array values.

Definition at line 1291 of file hrf-sa-train.cpp.
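CutValue is documented only as "cut array". One plausible reading, given its (p, num, gap) parameters and int return type, is element-wise clipping. The sketch below assumes that every element is clamped into [-gap, gap] and that the return value counts the clipped elements; both the symmetric range and the meaning of the return value are assumptions, and CutValueSketch is a hypothetical name.

```cpp
#include <cassert>

// Hypothetical "cut array": clamp each of the num elements of p into
// [-gap, gap] in place and return how many elements were changed.
int CutValueSketch(double *p, int num, double gap) {
    assert(gap > 0);
    int nCut = 0;
    for (int i = 0; i < num; ++i) {
        if (p[i] > gap) { p[i] = gap; ++nCut; }
        else if (p[i] < -gap) { p[i] = -gap; ++nCut; }
    }
    return nCut;
}
```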

§ PrintInfo()

void hrf::SAtrain::PrintInfo ( )

Print information.

Definition at line 1277 of file hrf-sa-train.cpp.

§ Run()

bool hrf::SAtrain::Run ( const double *  pInitParams = NULL)
virtual

Run the iteration, starting from the initial parameters pInitParams.

The implementation marks three steps with inline comments:
- init
- set average
- set back

Reimplemented from wb::Solve.

Definition at line 1015 of file hrf-sa-train.cpp.
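Only the step markers of Run are documented. A simplified sketch of the general shape of a stochastic approximation loop with iterate averaging follows, assuming an ascent step with a 1/(t+1) learning rate and plain running-mean averaging that begins at iteration avgBeg (playing the role of m_nAvgBeg). The gradient callback, the schedule, and RunSketch itself are hypothetical; the real SAtrain::Run uses the SAfunc object and the LearningRate members.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical outline of an SA training loop with iterate averaging.
std::vector<double> RunSketch(
    std::vector<double> params,
    const std::function<std::vector<double>(const std::vector<double> &)> &gradient,
    int iterMax, int avgBeg) {
    std::vector<double> avg = params;   // running average of the iterates
    int nAvg = 0;                       // number of iterates averaged so far
    for (int t = 0; t < iterMax; ++t) {
        double rate = 1.0 / (t + 1);    // decreasing learning rate (assumed)
        std::vector<double> g = gradient(params);
        for (std::size_t i = 0; i < params.size(); ++i)
            params[i] += rate * g[i];   // ascent step on the objective
        if (avgBeg > 0 && t + 1 >= avgBeg) {
            ++nAvg;
            for (std::size_t i = 0; i < params.size(); ++i)
                avg[i] += (params[i] - avg[i]) / nAvg; // incremental mean
        }
    }
    return nAvg > 0 ? avg : params;     // averaged parameters if averaging ran
}
```

The incremental update avg += (x - avg) / n keeps an exact running mean without storing past iterates, which matters when the parameter vector is large.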

§ Update()

void hrf::SAtrain::Update ( double *  pdParam,
const double *  pdDir,
double  dStep 
)
virtual

Update the parameters.

Reimplemented from wb::Solve.

Definition at line 1227 of file hrf-sa-train.cpp.
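In its generic form, an update of this kind moves every parameter along the direction by the given step. The sketch below adds an explicit nSize argument because the documented signature has none (the real member presumably takes the dimension from the wrapped function). Whether SAtrain::Update also applies separate rates to the lambda, hidden, and zeta parameter groups is not documented here, so only the generic form is shown; UpdateSketch is a hypothetical name.

```cpp
// Hypothetical stand-alone form of Update(pdParam, pdDir, dStep):
// x_{k+1} = x_k + dStep * p_k, applied element by element.
void UpdateSketch(double *pdParam, const double *pdDir, double dStep, int nSize) {
    for (int i = 0; i < nSize; ++i)
        pdParam[i] += dStep * pdDir[i];
}
```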

§ UpdateDir()

void hrf::SAtrain::UpdateDir ( double *  pDir,
double *  pGradient,
const double *  pParam 
)

Compute the update direction.

Definition at line 1182 of file hrf-sa-train.cpp.
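Two public members plausibly feed into the direction computation: the momentum m_fMomentum and the gap m_dir_gap. As a sketch only, assuming a standard heavy-ball momentum term followed by element-wise clipping of the direction to [-dirGap, dirGap], it might look like this. The formula, the treatment of dirGap <= 0 as "no clipping", and UpdateDirSketch are assumptions; the real UpdateDir also receives pParam, which is unused here.

```cpp
#include <algorithm>

// Hypothetical direction update: blend the previous direction with the new
// gradient (momentum), then clip each component if dirGap > 0.
void UpdateDirSketch(double *pDir, const double *pGradient, int nSize,
                     double momentum, double dirGap) {
    for (int i = 0; i < nSize; ++i) {
        double d = momentum * pDir[i] + pGradient[i];
        if (dirGap > 0) d = std::max(-dirGap, std::min(d, dirGap));
        pDir[i] = d;
    }
}
```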

§ UpdateGamma()

void hrf::SAtrain::UpdateGamma ( int  nIterNum)

Update the learning rate.

Definition at line 1161 of file hrf-sa-train.cpp.
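UpdateGamma recomputes the learning rates stored in the protected members m_gamma_lambda, m_gamma_hidden, and m_gamma_zeta from the corresponding LearningRate objects. The schedule itself is not documented here; a common choice in stochastic approximation is gamma_t = 1 / (tc + t^beta), which the following sketch assumes. The parameters tc and beta and the function LearningRateSketch are hypothetical.

```cpp
#include <cmath>

// Hypothetical decreasing learning-rate schedule: gamma_t = 1 / (tc + t^beta).
double LearningRateSketch(int t, double tc, double beta) {
    return 1.0 / (tc + std::pow(static_cast<double>(t), beta));
}
```

With tc = 1 and beta = 1 this gives 0.5 at t = 1 and 0.25 at t = 3; a beta between 0.5 and 1 decays more slowly while still going to zero.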

Member Data Documentation

§ m_aWriteAtIter

wb::Array<int> hrf::SAtrain::m_aWriteAtIter

Iterations at which a temporary model is written.

Definition at line 250 of file hrf-sa-train.h.

§ m_bUpdate_lambda

bool hrf::SAtrain::m_bUpdate_lambda

Definition at line 238 of file hrf-sa-train.h.

§ m_bUpdate_zeta

bool hrf::SAtrain::m_bUpdate_zeta

Definition at line 239 of file hrf-sa-train.h.

§ m_dir_gap

double hrf::SAtrain::m_dir_gap

Definition at line 241 of file hrf-sa-train.h.

§ m_fEpochNum

double hrf::SAtrain::m_fEpochNum

The current epoch number.

Definition at line 247 of file hrf-sa-train.h.

§ m_fMomentum

float hrf::SAtrain::m_fMomentum

The momentum.

Definition at line 244 of file hrf-sa-train.h.

§ m_gain_hidden

LearningRate hrf::SAtrain::m_gain_hidden

Definition at line 235 of file hrf-sa-train.h.

§ m_gain_lambda

LearningRate hrf::SAtrain::m_gain_lambda

Definition at line 234 of file hrf-sa-train.h.

§ m_gain_zeta

LearningRate hrf::SAtrain::m_gain_zeta

Definition at line 236 of file hrf-sa-train.h.

§ m_gamma_hidden

double hrf::SAtrain::m_gamma_hidden
protected

Definition at line 229 of file hrf-sa-train.h.

§ m_gamma_lambda

double hrf::SAtrain::m_gamma_lambda
protected

Definition at line 228 of file hrf-sa-train.h.

§ m_gamma_zeta

double hrf::SAtrain::m_gamma_zeta
protected

Definition at line 230 of file hrf-sa-train.h.

§ m_nAvgBeg

int hrf::SAtrain::m_nAvgBeg

If > 0, calculate the average.

Definition at line 245 of file hrf-sa-train.h.

§ m_nPrintPerIter

int hrf::SAtrain::m_nPrintPerIter

Interval (in iterations) at which the LL is output; -1 disables the output.

Definition at line 249 of file hrf-sa-train.h.
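The two output settings, m_nPrintPerIter and m_aWriteAtIter, can be illustrated with two small predicates. This sketch assumes that m_nPrintPerIter is a print interval and that m_aWriteAtIter lists the exact iterations at which a temporary model is written; std::vector<int> stands in for wb::Array<int>, and both function names are hypothetical.

```cpp
#include <algorithm>
#include <vector>

// Print the LL when the iteration is a multiple of the interval; -1 disables it.
bool ShouldPrintLL(int iter, int printPerIter) {
    if (printPerIter == -1) return false;
    return printPerIter > 0 && iter % printPerIter == 0;
}

// Write a temporary model when the iteration appears in the list.
bool ShouldWriteModel(int iter, const std::vector<int> &writeAtIter) {
    return std::find(writeAtIter.begin(), writeAtIter.end(), iter) != writeAtIter.end();
}
```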

§ m_zeta_gap

double hrf::SAtrain::m_zeta_gap

Definition at line 242 of file hrf-sa-train.h.


The documentation for this class was generated from the files hrf-sa-train.h and hrf-sa-train.cpp.