TRF Language Model
trf-alg.cpp
Go to the documentation of this file.
1 // You may obtain a copy of the License at
2 //
3 // http://www.apache.org/licenses/LICENSE-2.0
4 //
5 // Unless required by applicable law or agreed to in writing, software
6 // distributed under the License is distributed on an "AS IS" BASIS,
7 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 // See the License for the specific language governing permissions and
9 // limitations under the License.
10 //
11 // Copyright 2014-2015 Tsinghua University
12 // Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
13 //
14 // All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
15 // license declaration. Different coding language may use different comment styles.
16 
17 
18 #include "trf-alg.h"
19 
20 namespace trf
21 {
23  {
24  for (int i = 0; i < m_aAlpha.GetNum(); i++)
26  for (int i = 0; i < m_aBeta.GetNum(); i++)
27  SAFE_DELETE(m_aBeta[i]);
28  m_aAlpha.Clean();
29  m_aBeta.Clean();
30  }
31  void Algfb::Prepare(int nLen, int nOrder, int nValueLimit)
32  {
33  m_nLen = nLen;
34  m_nOrder = nOrder;
35  m_nValueLimit = nValueLimit;
36 
37  for (int i = 0; i < m_aAlpha.GetNum(); i++)
39  for (int i = 0; i < m_aBeta.GetNum(); i++)
40  SAFE_DELETE(m_aBeta[i]);
41 
42  int nClusterNum = nLen - nOrder + 1;
43  if (nLen < nOrder)
44  return;
45 
46  m_aAlpha.SetNum(nClusterNum);
47  m_aBeta.SetNum(nClusterNum);
48 
49  for (int i = 0; i < nClusterNum; i++) {
50  m_aAlpha[i] = new Msg(m_nOrder - 1, m_nValueLimit);
51  m_aBeta[i] = new Msg(m_nOrder - 1, m_nValueLimit);
52  }
53  }
54  void Algfb::ForwardBackward(int nLen, int nOrder, int nValueLimit)
55  {
56  Prepare(nLen, nOrder, nValueLimit);
57  int nClusterNum = nLen - nOrder + 1;
58 
59  if (nClusterNum <= 0)
60  return;
61 
62  Vec<int> nodeSeq(nLen);
63 
64  // Forward
65  m_aAlpha[0]->Fill(0);
66  for (int i = 1; i < nClusterNum; i++) {
67  VecIter iter(nodeSeq.GetBuf() + i, m_nOrder - 1, 0, m_nValueLimit - 1);
68  while (iter.Next()) {
69 
70  LogP dLogSum = LogP_zero;
71  for (nodeSeq[i - 1] = 0; nodeSeq[i - 1] < m_nValueLimit; nodeSeq[i - 1]++) {
72  LogP temp = ClusterSum(nodeSeq.GetBuf(), nLen, i - 1, m_nOrder)
73  + m_aAlpha[i - 1]->Get(nodeSeq.GetBuf() + i - 1, m_nOrder - 1);
74  dLogSum = Log_Sum(dLogSum, temp);
75  }
76 
77  m_aAlpha[i]->Get(nodeSeq.GetBuf() + i, m_nOrder - 1) = dLogSum;
78  }
79  }
80 
81  // Backward
82  m_aBeta[nClusterNum - 1]->Fill(0);
83  for (int i = nClusterNum - 2; i >= 0; i--) {
84  VecIter iter(nodeSeq.GetBuf() + i + 1, m_nOrder - 1, 0, m_nValueLimit - 1);
85  while (iter.Next()) {
86 
87  LogP dLogSum = LogP_zero;
88  for (nodeSeq[i + m_nOrder] = 0; nodeSeq[i + m_nOrder] < m_nValueLimit; nodeSeq[i + m_nOrder]++) {
89  LogP temp = ClusterSum(nodeSeq.GetBuf(), nLen, i + 1, m_nOrder)
90  + m_aBeta[i + 1]->Get(nodeSeq.GetBuf() + i + 2, m_nOrder - 1);
91  dLogSum = Log_Sum(dLogSum, temp);
92  }
93 
94  m_aBeta[i]->Get(nodeSeq.GetBuf() + i + 1, m_nOrder - 1) = dLogSum;
95  }
96  }
97 
98  }
99  LogP Algfb::GetMarginalLogProb(int nPos, int *pSubSeq, int nSubLen, double logz /* = 0 */)
100  {
101  // Forward-backward need be calculate
102 
103  if (nPos + nSubLen > m_nLen) {
104  lout_error("[Model] GetMarginalLogProb: nPos(" << nPos << ")+nOrder(" << nSubLen << ") > seq.len(" << m_nLen << ")!!");
105  }
106 
107  LogP dSum = LogP_zero; // 0 prob
108  Vec<int> nseq(m_nLen); //save the sequence
109 
110  // if the length is very small
111  // then ergodic the sequence of length
112  if (m_nLen < m_nOrder)
113  {
114  VecIter iter(nseq.GetBuf(), m_nLen, 0, m_nValueLimit - 1);
115  while (iter.Next()) {
116  if (nseq.GetSub(nPos, nSubLen) == VecShell<int>(pSubSeq, nSubLen)) {
117  dSum = Log_Sum(dSum, ClusterSum(nseq.GetBuf(), m_nLen, 0, m_nLen));
118  }
119  }
120  }
121  else if (nSubLen == m_nOrder) {
122  VecShell<int>(nseq.GetBuf()+nPos, nSubLen) = VecShell<int>(pSubSeq, nSubLen);
123  dSum = ClusterSum(nseq.GetBuf(), m_nLen, nPos, m_nOrder)
124  + m_aAlpha[nPos]->Get(pSubSeq, m_nOrder - 1)
125  + m_aBeta[nPos]->Get(pSubSeq + 1, m_nOrder - 1);
126  }
127  else {
128  // Choose a cluster
129  if (nPos <= m_nLen - m_nOrder) { // choose the cluster nPos
130  VecShell<int>(nseq.GetBuf() + nPos, nSubLen) = VecShell<int>(pSubSeq,nSubLen);
131  VecIter iter(nseq.GetBuf() + nPos + nSubLen, m_nOrder - nSubLen, 0, m_nValueLimit - 1);
132  while (iter.Next()) {
133  dSum = Log_Sum(dSum,
134  ClusterSum(nseq.GetBuf(), m_nLen, nPos, m_nOrder)
135  + m_aAlpha[nPos]->Get(nseq.GetBuf() + nPos, m_nOrder - 1)
136  + m_aBeta[nPos]->Get(nseq.GetBuf() + nPos + 1, m_nOrder - 1));
137  }
138  }
139  else { // choose the last cluster
140  int nCluster = m_nLen - m_nOrder; // cluster position
141  VecIter iter(nseq.GetBuf() + nCluster, m_nOrder, 0, m_nValueLimit - 1);
142  while (iter.Next()) {
143  if (nseq.GetSub(nPos, nSubLen)==VecShell<int>(pSubSeq, nSubLen)) {
144  dSum = Log_Sum(dSum,
145  ClusterSum(nseq.GetBuf(), m_nLen, nCluster, m_nOrder)
146  + m_aAlpha[nCluster]->Get(nseq.GetBuf() + nCluster, m_nOrder - 1)
147  + m_aBeta[nCluster]->Get(nseq.GetBuf() + nCluster + 1, m_nOrder - 1));
148  }
149  }
150  }
151  }
152 
153  return dSum - logz;
154  }
156  {
157  int nIterDim = min(m_nOrder, m_nLen);
158  Vec<int> nodeSeq(m_nLen);
159  int nSumPos = 0;
160 
162  LogP logSum = LogP_zero;
163  VecIter iter(nodeSeq.GetBuf() + nSumPos, nIterDim, 0, m_nValueLimit - 1);
164  while (iter.Next()) {
165  LogP temp = 0;
166  if (nIterDim == m_nLen) { // no cluster
167  temp = ClusterSum(nodeSeq.GetBuf(), m_nLen, nSumPos, nIterDim);
168  }
169  else {
170  temp = ClusterSum(nodeSeq.GetBuf(), m_nLen, nSumPos, m_nOrder)
171  + m_aAlpha[nSumPos]->Get(nodeSeq.GetBuf() + nSumPos, m_nOrder - 1)
172  + m_aBeta[nSumPos]->Get(nodeSeq.GetBuf() + nSumPos + 1, m_nOrder - 1);
173  }
174 
175  logSum = Log_Sum(logSum, temp);
176  }
177 
178  return logSum;
179  }
180 
181  /************************************************************************/
182  /* class Msg */
183  /************************************************************************/
184 
185  Msg::Msg(int nMsgDim, int nSize)
186  {
187  m_dim = nMsgDim;
188  //m_pmodel = pm;
189  m_size = nSize;//m_pmodel->GetEncodeNodeLimit();
190  int totalsize = pow(m_size, m_dim);
191  m_pbuf = new float[totalsize];
192 
193  Fill(0);
194  }
196  {
197  SAFE_DELETE_ARRAY(m_pbuf);
198  }
199  void Msg::Fill(float v)
200  {
201  int totalsize = pow(m_size, m_dim);
202  for (int i = 0; i < totalsize; i++) {
203  m_pbuf[i] = v;
204  }
205  }
206  void Msg::Copy(Msg &m)
207  {
208  if (GetBufSize() != m.GetBufSize()) {
209  SAFE_DELETE_ARRAY(m_pbuf);
210  m_dim = m.m_dim;
211  m_size = m.m_size;
212  m_pbuf = new float[GetBufSize()];
213  }
214 
215  memcpy(m_pbuf, m.m_pbuf, sizeof(m_pbuf[0]) * GetBufSize());
216  }
217  float& Msg::Get(int *pIdx, int nDim)
218  {
219  lout_assert(nDim == m_dim);
220 
221  int nIndex = pIdx[0];
222  for (int i = 0; i < nDim - 1; i++)
223  {
224  nIndex = nIndex * m_size + pIdx[i + 1];
225  }
226  return m_pbuf[nIndex];
227  }
228 
229 
230 
231  /************************************************************************/
232  /* VecIter */
233  /************************************************************************/
234  VecIter::VecIter(int *p, int nDim, int nMin, int nMax)
235  {
236  m_pBuf = p;
237  m_nDim = nDim;
238  m_nMin = nMin;
239  m_nMax = nMax;
240  Reset();
241  }
243  {
244  for (int i = 0; i < m_nDim; i++)
245  m_pBuf[i] = m_nMin;
246  m_pBuf[0]--;
247  }
249  {
250  m_pBuf[0]++;
251  for (int i = 0; i < m_nDim - 1; i++) {
252  if (m_pBuf[i] > m_nMax) {
253  m_pBuf[i + 1]++;
254  m_pBuf[i] = m_nMin;
255  }
256  else {
257  break;
258  }
259  }
260 
261  return m_pBuf[m_nDim - 1] <= m_nMax;
262  }
263 }
const float LogP_zero
Definition: trf-def.h:30
#define SAFE_DELETE(p)
memory release
Definition: wb-vector.h:49
#define lout_error(x)
Definition: wb-log.h:183
#define lout_assert(p)
Definition: wb-log.h:185
int m_nLen
the sequence length.
Definition: trf-alg.h:38
Msg(int nMsgDim, int nSize)
Definition: trf-alg.cpp:185
LogP Log_Sum(LogP x, LogP y)
Definition: trf-def.h:40
int GetBufSize() const
Definition: trf-alg.h:76
VecShell< T > GetSub(int nPos, int nLen)
Definition: wb-mat.h:71
void Copy(Msg &m)
Definition: trf-alg.cpp:206
void Prepare(int nLen, int nOrder, int nValueLimit)
prepare
Definition: trf-alg.cpp:31
void Reset()
Definition: trf-alg.cpp:242
double LogP
Definition: trf-def.h:27
int m_nValueLimit
the max-value at each position
Definition: trf-alg.h:39
int m_nOrder
the order, i.e. the node number at each cluster {x_1,x_2,...,x_n}
Definition: trf-alg.h:37
Definition: wb-mat.h:29
virtual LogP ClusterSum(int *pSeq, int nLen, int nPos, int nOrder)=0
This function must be implemented by a derived class. Calculate the log probability of each cluster.
Array< Msg * > m_aAlpha
the forward message
Definition: trf-alg.h:40
void Fill(float v)
Definition: trf-alg.cpp:199
#define SAFE_DELETE_ARRAY(p)
Definition: wb-vector.h:50
T * GetBuf() const
Definition: wb-mat.h:68
LogP GetLogSummation()
Get the summation over the sequence, corresponding to the log normalization constant 'logZ'...
Definition: trf-alg.cpp:155
float & Get(int *pIdx, int nDim)
Definition: trf-alg.cpp:217
void ForwardBackward(int nLen, int nOrder, int nValueLimit)
forward-backward calculation
Definition: trf-alg.cpp:54
pFunc Reset & m
VecIter(int *p, int nDim, int nMin, int nMax)
Definition: trf-alg.cpp:234
LogP GetMarginalLogProb(int nPos, int *pSubSeq, int nSubLen, double logz=0)
Get the marginal probability. 'logz' is the input log normalization constant.
Definition: trf-alg.cpp:99
bool Next()
Definition: trf-alg.cpp:248
aLL Fill(0)
Array< Msg * > m_aBeta
the backward message
Definition: trf-alg.h:41
Definition: trf-alg.cpp:20