TRF Language Model
wb-solve.h
Go to the documentation of this file.
1 // You may obtain a copy of the License at
2 //
3 // http://www.apache.org/licenses/LICENSE-2.0
4 //
5 // Unless required by applicable law or agreed to in writing, software
6 // distributed under the License is distributed on an "AS IS" BASIS,
7 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 // See the License for the specific language governing permissions and
9 // limitations under the License.
10 //
11 // Copyright 2014-2015 Tsinghua University
12 // Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
13 //
14 // All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
15 // license declaration. Different coding language may use different comment styles.
16 
24 #pragma once
25 #include "wb-win.h"
26 #include "wb-log.h"
27 #include "wb-vector.h"
28 #include <algorithm>
29 
30 namespace wb
31 {
class Solve;

/**
 * \class Func
 * \brief the objective function; an abstract base class that concrete objectives derive from.
 *
 * A derived class supplies the parameter setter, the function value f(x) and
 * the gradient g(x). Solve and LBFGS are friends so that they can reach the
 * protected members; the Solve constructor stores itself into m_pSolve.
 */
class Func
{
    friend class Solve;
    friend class LBFGS;
protected:
    Solve *m_pSolve;  ///< the Solve object attached to this function (NULL until a Solve is constructed on it)
    int m_nParamNum;  ///< the parameter number
public:
    /// NOTE(review): presumably the maximum number of extra values GetExtraValues() may write - confirm against the callers.
    static const int cn_exvalue_max_num = 100;

    /// constructor
    /// \param nParamNum the parameter number (defaults to 0; may be changed later with SetParamNum)
    Func(int nParamNum = 0) : m_pSolve(NULL), m_nParamNum(nParamNum) {}
    /// virtual destructor, so that a derived objective can be safely deleted through a Func pointer
    virtual ~Func() {}
    /// setting the parameter number
    void SetParamNum(int n) { m_nParamNum = n; }
    /// get the parameter number
    int GetParamNum() const { return m_nParamNum; }
    /// set the parameter.
    virtual void SetParam(double *pdParams) = 0;
    /// calculate the function value f(x)
    virtual double GetValue() = 0;
    /// calculate the gradient g(x)
    virtual void GetGradient(double *pdGradient) = 0;

    /// calculate extra values which will be printed at each iteration
    /// \param k        NOTE(review): presumably the iteration number - confirm against the caller
    /// \param pdValues buffer receiving the extra values
    /// \return the default implementation writes nothing and returns 0
    virtual int GetExtraValues(int k, double *pdValues) { return 0; }
};
68 
69 #define lout_Solve wb::lout<<m_pAlgorithmName<<" "
70 
77  class Solve
78  {
79  protected:
80  const char *m_pAlgorithmName;
81  public:
83 
84  double *m_pdRoot;
85 
86  int m_nIterNum;
87  int m_nIterMin;
88  int m_nIterMax;
89 
90  double m_dSpendMinute;
91 
92  double m_dStop;
93  double m_dGain;
94 
95  public:
96  Solve(Func *pfunc = NULL, double dtol=1e-5)
97  {
98  m_pAlgorithmName = "[Solve]";
99  m_pfunc = pfunc;
100  m_pfunc->m_pSolve = this;
101  m_pdRoot = NULL;
102 
103  m_nIterNum = 0;
104  m_nIterMin = 1;
105  m_nIterMax = 10000;
106 
107  m_dStop = dtol;
108  m_dGain = 0;
109  }
110 
112  virtual bool Run(const double *pInitParams = NULL);
114  virtual void IterInit() {};
116 
122  virtual void ComputeDir(int k, const double *pdParam, const double *pdGradient, double *pdDir);
124 
131  virtual double LineSearch(double *pdDir, double dValue, const double *pdParam, const double *pdGradient);
133  virtual void Update(double *pdParam, const double *pdDir, double dStep);
135  virtual bool StopDecision(int k, double dValue, const double *pdGradient);
136 
137  public:
139  static double VecProduct(const double *pdVec1, const double *pdVec2, int nSize);
141  static double VecNorm(const double *pdVec, int nSize);
143  static double VecDist(const double *pdVec1, const double *pdVec2, int nSize);
144  };
145 
146  /*
147  * \class LBFGS
148  * \brief LBFGS method
149  */
150  class LBFGS : public Solve
151  {
153  typedef struct {
154  double *s;
155  double *y;
156  } sy;
157  protected:
161  double *m_pd_s, *m_pd_y;
162 
164  double *m_pdPrevParam;
165  double *m_pdAlpha;
166 
167  public:
169  LBFGS(Func *pfunc = NULL, double dtol = 1e-5) :Solve(pfunc, dtol)
170  {
171  m_nLimitiedNum = 8;
172  m_pCirQueueBuf = NULL;
173  m_nCirQueueBufTail = 0;
174  m_pd_s = NULL;
175  m_pd_y = NULL;
176 
177  m_pdPrevGradient = NULL;
178  m_pdPrevParam = NULL;
179  m_pdAlpha = NULL;
180  }
183  {
184  SAFE_DELETE_ARRAY(m_pdPrevParam);
185  SAFE_DELETE_ARRAY(m_pdPrevGradient);
186  SAFE_DELETE_ARRAY(m_pdAlpha);
187 
188  CirQueueBuf_Release();
189  }
191  virtual void IterInit();
193  virtual void ComputeDir(int k, const double *pdParam, const double *pdGradient, double *pdDir);
195  void CirQueueBuf_Init();
197  void CirQueueBuf_Release();
199  void CirQueueBuf_Prev(int i, double *&pd_s, double *&pd_y);
201  void CirQueueBuf_In(double *&pd_s, double *&pd_y);
202  };
203 }
double * m_pdPrevParam
parameter on previous iteration
Definition: wb-solve.h:164
the base class of all the solve classes; it provides a gradient descent algorithm.
Definition: wb-solve.h:77
int m_nIterMin
minimum iteration number
Definition: wb-solve.h:87
sy * m_pCirQueueBuf
the buffer of the circular queue, storing s_k = x_k - x_{k-1} and y_k = g_k - g_{k-1} ...
Definition: wb-solve.h:159
LBFGS(Func *pfunc=NULL, double dtol=1e-5)
constructor
Definition: wb-solve.h:169
virtual void SetParam(double *pdParams)=0
set the parameter.
virtual void IterInit()
initialize the iteration; a hook for derived classes.
Definition: wb-solve.h:114
virtual void GetGradient(double *pdGradient)=0
calculate the gradient g(x)
A definition of the class Log, which can output to the cmd window and the log file simultaneously. In wb-log.cpp there is a Log variable "lout", which can be used directly, just like "cout". For example:
int m_nIterNum
current iteration number, iterates from m_nIterMin to m_nIterMax
Definition: wb-solve.h:86
virtual double GetValue()=0
calculate the function value f(x)
Provides utilities for the cmd window on the Windows platform.
#define SAFE_DELETE_ARRAY(p)
Definition: wb-vector.h:50
Func(int nParamNum=0)
Definition: wb-solve.h:48
int m_nIterMax
maximum iteration number
Definition: wb-solve.h:88
int m_nLimitiedNum
limited number, i.e. m
Definition: wb-solve.h:158
pFunc Run(bInit)
int GetParamNum() const
get the parameter number
Definition: wb-solve.h:52
double * m_pdRoot
save the root of the function
Definition: wb-solve.h:84
double m_dGain
iteration step; ==0 means using the line search.
Definition: wb-solve.h:93
Func * m_pfunc
pointer to the function
Definition: wb-solve.h:82
double m_dStop
stop threshold
Definition: wb-solve.h:92
double * m_pdAlpha
auxiliary factor in ComputeDir
Definition: wb-solve.h:165
const char * m_pAlgorithmName
the algorithm name.
Definition: wb-solve.h:80
static const int cn_exvalue_max_num
Definition: wb-solve.h:59
virtual int GetExtraValues(int k, double *pdValues)
calculate extra values which will be printed at each iteration
Definition: wb-solve.h:66
int m_nParamNum
the parameter number
Definition: wb-solve.h:45
Solve * m_pSolve
Save the Solve pointer.
Definition: wb-solve.h:44
double * m_pd_y
current s_k = x_k - x_{k-1} and y_k
Definition: wb-solve.h:161
double * m_pdPrevGradient
gradient on previous iteration
Definition: wb-solve.h:163
void SetParamNum(int n)
setting the parameter number
Definition: wb-solve.h:50
int m_nCirQueueBufTail
queue tail
Definition: wb-solve.h:160
the objective function; a base class to derive concrete objectives from
Definition: wb-solve.h:39
~LBFGS()
destructor
Definition: wb-solve.h:182
double VecNorm(T *pVec, int len)
[Vec-function] sqrt(v*v^T);
Definition: wb-vector.h:571
double m_dSpendMinute
record the time spent on the iterations (minutes)
Definition: wb-solve.h:90
Solve(Func *pfunc=NULL, double dtol=1e-5)
Definition: wb-solve.h:96
define all the code written by Bin Wang.
Definition: wb-file.cpp:21
Definition of simple dynamic arrays, stacks, queues and so on.