TRF Language Model
wb-mat.h
Go to the documentation of this file.
1 // You may obtain a copy of the License at
2 //
3 // http://www.apache.org/licenses/LICENSE-2.0
4 //
5 // Unless required by applicable law or agreed to in writing, software
6 // distributed under the License is distributed on an "AS IS" BASIS,
7 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 // See the License for the specific language governing permissions and
9 // limitations under the License.
10 //
11 // Copyright 2014-2015 Tsinghua University
12 // Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
13 //
14 // All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
15 // license declaration. Different coding languages may use different comment styles.
16 
17 
18 #ifndef _WB_MAT_H_
19 #define _WB_MAT_H_
20 #include "wb-vector.h"
21 #include "wb-log.h"
22 #include "wb-file.h"
23 
24 namespace wb
25 {
26 
// Forward declarations. The class templates reference each other, and the
// friend declarations inside VecShell/MatShell name these function templates,
// so every one of them must be declared before the class definitions.
template <class T> class VecShell;
template <class T> class MatShell;
template <class T> class Vec;
template <class T> class Mat;

/// calculate V==V (same size and element-wise equal)
template <class T> bool VecEqual(VecShell<T> &v1, VecShell<T> &v2);
/// calculate V + V into res
template <class T> void VecAdd(VecShell<T> &res, VecShell<T> &v1, VecShell<T> &v2);
/// calculate V*V (dot product)
template <class T> T VecDot(VecShell<T> &v1, VecShell<T> &v2);
/// calculate V1*M*V2 (declared once; the original repeated this line)
template <class T> T MatVec2(MatShell<T> &m, VecShell<T> &v1, VecShell<T> &v2);
37 
38 
48  template <class T>
53  class VecShell
54  {
56  friend bool VecEqual<>(VecShell<T> &v1, VecShell<T> &v2);
57  friend void VecAdd<>(VecShell<T> &res, VecShell<T> &v1, VecShell<T> &v2);
58  friend T VecDot<>(VecShell<T> &v1, VecShell<T> &v2);
59  friend T MatVec2<>(MatShell<T> &m, VecShell<T> &v1, VecShell<T> &v2);
60 
61  protected:
62  T *m_pBuf;
63  int m_nSize;
64  public:
65  VecShell() : m_pBuf(NULL), m_nSize(0) {};
66  VecShell(T *p, int size) : m_pBuf(p), m_nSize(size) {};
67  void Fill(T v);
68  T* GetBuf() const { return m_pBuf; }
69  int GetSize() const { return m_nSize; }
70  int ByteSize() const { return GetSize() * sizeof(T); }
71  VecShell<T> GetSub(int nPos, int nLen) { return VecShell<T>(m_pBuf + nPos, nLen); }
72  void Reset(T *p, int size) { m_pBuf = p; m_nSize = size; }
73  T& operator [] (int i);
74  void operator = (VecShell v);
75  void operator += (VecShell v);
76  void operator -= (VecShell v);
77  void operator *= (T n);
78  void operator /= (T n);
79  bool operator == (VecShell v);
80  //operator T*() const { return m_pBuf; }
81  T Sum() {
82  T sum = 0;
83  for (int i = 0; i < m_nSize; i++)
84  sum += m_pBuf[i];
85  return sum;
86  }
87 
88  };
89 
90  template <class T>
95  class Vec : public VecShell<T>
96  {
97  private:
98  int m_nBufSize;
99  public:
100  Vec():m_nBufSize(0) { Reset(); }
101  Vec(int size):m_nBufSize(0) { Reset(size); }
102  ~Vec() { Reset(); }
103  void Reset(int size = 0);
104  void Copy(VecShell<T> v);
105  };
106 
107 
108  template <class T>
113  class MatShell
114  {
115  friend T MatVec2<>(MatShell<T> &m, VecShell<T> &v1, VecShell<T> &v2);
116  protected:
117  T *m_pBuf;
118  int m_nRow;
119  int m_nCol;
120  public:
121  MatShell() : m_pBuf(NULL), m_nRow(0), m_nCol(0) {};
122  MatShell(T *pbuf, int row, int col) :m_pBuf(pbuf), m_nRow(row), m_nCol(col) {};
123  void Fill(T v);
124  T* GetBuf() const { return m_pBuf; }
125  T& Get(unsigned int i, unsigned int j) { return m_pBuf[i*m_nCol + j]; }
126  int GetSize() const { return m_nRow*m_nCol; }
127  int ByteSize() const { return GetSize() * sizeof(T); }
128  int GetRow() const { return m_nRow; }
129  int GetCol() const { return m_nCol; }
130  void Reset(T *pbuf, int row, int col) { m_pBuf = pbuf; m_nRow = row; m_nCol = col; }
131  void Read(File &file);
132  void Write(File &file);
133  VecShell<T> operator [] (int i) { return VecShell<T>(m_pBuf + i*m_nCol, m_nCol);}
134  operator T* () { return m_pBuf; }
135  bool operator== (MatShell &m);
136  };
137 
138  template <class T>
143  class Mat : public MatShell<T>
144  {
145  private:
146  int m_nBufSize;
147  public:
148  Mat() : m_nBufSize(0) { Reset(); }
149  Mat(int row, int col) : m_nBufSize(0) { Reset(row, col); }
150  ~Mat() { Reset(); }
151  void Reset(int row = 0, int col = 0);
152  void Copy(MatShell<T> &m);
153  };
154 
// Mat3dShell: a non-owning view of an m_nXDim x m_nYDim x m_nZDim array.
// It has no destructor, so it never frees m_pBuf.
155  template <class T>
161  {
162  protected:
// element storage; (x, y, z) lives at m_pBuf[x*Y*Z + y*Z + z] (see Get)
163  T* m_pBuf;
164  int m_nXDim;
165  int m_nYDim;
166  int m_nZDim;
167  public:
168  Mat3dShell() :m_pBuf(NULL), m_nXDim(0), m_nYDim(0), m_nZDim(0) {};
169  Mat3dShell(T* p, int xdim, int ydim, int zdim) :m_pBuf(p), m_nXDim(xdim), m_nYDim(ydim), m_nZDim(zdim) {};
170  void Fill(T v);
171  T* GetBuf() const { return m_pBuf; }
// element (x, y, z); no range check
172  T& Get(int x, int y, int z) { return m_pBuf[x*m_nYDim*m_nZDim + y*m_nZDim + z]; }
173  int GetSize() const { return m_nXDim * m_nYDim * m_nZDim; }
174  int ByteSize() const { return GetSize() * sizeof(T); }
175  int GetXDim() const { return m_nXDim; }
176  int GetYDim() const { return m_nYDim; }
177  int GetZDim() const { return m_nZDim; }
// re-point the view at another buffer/shape
178  void Reset(T* p, int xdim, int ydim, int zdim) { m_pBuf = p; m_nXDim = xdim; m_nYDim = ydim; m_nZDim = zdim; }
// slice x as a YDim x ZDim MatShell view into the same buffer
179  MatShell<T> operator[] (int x) { return MatShell<T>(m_pBuf + x*m_nYDim*m_nZDim, m_nYDim, m_nZDim); }
180  void Write(File &file);
181  void Read(File &file);
182  bool operator== (Mat3dShell &m);
183  };
184 
185  template <class T>
190  class Mat3d : public Mat3dShell<T>
191  {
192  private:
193  int m_nBufSize;
194  public:
195  Mat3d() : m_nBufSize(0) { Reset(); }
196  Mat3d(int xdim, int ydim, int zdim) : m_nBufSize(0) { Reset(xdim, ydim, zdim); }
197  ~Mat3d() { Reset(); }
198  void Reset(int xdim=0, int ydim=0, int zdim=0);
199  void Copy(Mat3dShell<T> &m);
200  };
201 
202 
203 
204  /************************************************************************/
205  /* mat * vec / vec * vec */
206  /************************************************************************/
207 
// VecEqual(v1, v2): true when v1 and v2 have the same size and every pair of
// elements compares equal. The comparison is exact; no tolerance is applied for
// floating-point T.
208  template <class T>
211  {
212  if (v1.GetSize() != v2.GetSize())
213  return false;
214  for (int i = 0; i < v1.GetSize(); i++) {
215  if (v1[i] != v2[i]) {
216  return false;
217  }
218  }
219  return true;
220  }
// VecAdd(res, v1, v2): element-wise res[i] = v1[i] + v2[i].
// All three vectors must have the same size; a mismatch is reported via lout_error.
// NOTE(review): if lout_error does not abort (see wb-log.h), the loop still runs over
// res.GetSize() elements and can read past the end of a shorter v1/v2.
221  template <class T>
224  {
225  if (v1.GetSize() != v2.GetSize() || res.GetSize() != v1.GetSize()) {
226  lout_error("[VecAdd] Vec Size are not equal: v1.size="
227  << v1.GetSize() << " v2.size=" << v2.GetSize() << " res.size=" << res.GetSize());
228  }
229  for (int i = 0; i < res.GetSize(); i++) {
230  res.m_pBuf[i] = v1.m_pBuf[i] + v2.m_pBuf[i];
231  }
232  }
// VecDot(v1, v2): dot product sum_i v1[i] * v2[i], accumulated starting from 0.
// The size check exists only in _DEBUG builds; in other builds v2 must hold at
// least v1.GetSize() elements.
233  template <class T>
236  {
237 #ifdef _DEBUG
238  if (v1.GetSize() != v2.GetSize()) {
239  lout_error("[VecDot] v1.size(" << v1.GetSize() << ") != v2.size(" << v2.GetSize() << ")");
240  }
241 #endif
242  T sum = 0;
243  for (int i = 0; i < v1.GetSize(); i++) {
244  sum += v1.m_pBuf[i] * v2.m_pBuf[i];
245  }
246  return sum;
247  }
248 
// MatVec2(m, v1, v2): bilinear form v1^T * m * v2 = sum_{i,j} v1[i] * m(i,j) * v2[j].
// Requires v1.size == m.row and v2.size == m.col; both are checked only in _DEBUG builds.
249  template <class T>
252  {
253 #ifdef _DEBUG
254  if (v1.GetSize() != m.GetRow()) {
255  lout_error("[MatVec2] v1.size(" << v1.GetSize() << ") != m.row(" << m.GetRow() << ")");
256  }
257  if (v2.GetSize() != m.GetCol()) {
258  lout_error("[MatVec2] m.col(" << m.GetCol() << ") != v2.size(" << v2.GetSize() << ")");
259  }
260 #endif
// Terms whose v1 or v2 coefficient is exactly zero are skipped, which saves work for
// sparse vectors. NOTE(review): this also means NaN/Inf entries of m paired with a zero
// coefficient do not propagate into the result (0 * NaN would otherwise be NaN).
261  T sum = 0;
262  for (int i = 0; i < m.GetRow(); i++) {
263  if (v1.m_pBuf[i] == 0)
264  continue;
265  for (int j = 0; j < m.GetCol(); j++) {
266  if (v2.m_pBuf[j] == 0)
267  continue;
268  sum += v1.m_pBuf[i] * m.Get(i,j) * v2.m_pBuf[j];
269  }
270  }
271  return sum;
272  }
273 
274  /************************************************************************/
275  /* VecShell */
276  /************************************************************************/
277 
// VecShell<T>::Fill(v): set every element of the view to v.
// Silently does nothing when no buffer is attached (MatShell::Fill reports an error instead).
278  template <class T>
280  {
281  if (!m_pBuf) {
282  return;
283  }
284  for (int i = 0; i < m_nSize; i++) {
285  m_pBuf[i] = v;
286  }
287  }
// VecShell<T>::operator[](i): reference to element i.
// The NULL-buffer and index-range checks are compiled only in _DEBUG builds.
288  template <class T>
290  {
291 #ifdef _DEBUG
292  if (!m_pBuf) {
293  lout_error("[Vec] op[]: buffer = NULL");
294  }
295  if (i < 0 || i >= m_nSize) {
296  lout_error("[Vec] op[] index i(" << i << ") over the size(" << m_nSize << ")");
297  }
298 #endif
299  return m_pBuf[i];
300  }
// VecShell<T>::operator=(v): copy v's elements into this view's existing buffer.
// Nothing is reallocated, so both views must have the same size (mismatch -> lout_error).
// NOTE(review): memcpy is only correct for trivially copyable T, and it is undefined
// for overlapping source/destination ranges.
301  template <class T>
303  {
304  if (v.GetSize() != GetSize()) {
305  lout_error("[VecShell] op=: the size is not equal (" << v.GetSize() << ")(" << GetSize() << ")");
306  }
307  memcpy(m_pBuf, v.m_pBuf, sizeof(T)*m_nSize);
308  }
// VecShell<T>::operator+=(v): in-place element-wise addition of v.
// Sizes must match; a mismatch is reported via lout_error.
309  template <class T>
311  {
312  if (v.GetSize() != GetSize()) {
313  lout_error("[VecShell] op+=: the size is not equal (" << v.GetSize() << ")(" << GetSize() << ")");
314  }
315  for (int i = 0; i < GetSize(); i++) {
316  m_pBuf[i] += v.m_pBuf[i];
317  }
318  }
// VecShell<T>::operator-=(v): in-place element-wise subtraction of v.
// Sizes must match; a mismatch is reported via lout_error.
// NOTE(review): the error text below says "op+=" (copied from operator+=); it should
// presumably read "op-=".
319  template <class T>
321  {
322  if (v.GetSize() != GetSize()) {
323  lout_error("[VecShell] op+=: the size is not equal (" << v.GetSize() << ")(" << GetSize() << ")");
324  }
325  for (int i = 0; i < GetSize(); i++) {
326  m_pBuf[i] -= v.m_pBuf[i];
327  }
328  }
// VecShell<T>::operator*=(n): multiply every element by n in place.
329  template <class T>
331  {
332  for (int i = 0; i < GetSize(); i++) {
333  m_pBuf[i] *= n;
334  }
335  }
// VecShell<T>::operator/=(n): divide every element by n in place.
// There is no zero check; for integral T, n == 0 is undefined behaviour.
336  template <class T>
338  {
339  for (int i = 0; i < GetSize(); i++) {
340  m_pBuf[i] /= n;
341  }
342  }
// VecShell<T>::operator==(v): true when sizes match and all elements compare equal.
// v is read through operator[], so _DEBUG builds range-check the accesses to v.
343  template <class T>
345  {
346  if (m_nSize != v.m_nSize)
347  return false;
348  for (int i = 0; i < m_nSize; i++) {
349  if (m_pBuf[i] != v[i]) {
350  return false;
351  }
352  }
353  return true;
354  }
355 
356  /************************************************************************/
357  /* Vec */
358  /************************************************************************/
359  template <class T>
360  void Vec<T>::Reset(int size /* = 0 */)
361  {
362  if (size == 0) {
363  // Clean buffer
364  SAFE_DELETE_ARRAY(this->m_pBuf);
365  this->m_nSize = 0;
366  this->m_nBufSize = 0;
367  return;
368  }
369 
370  if (size <= this->m_nBufSize) { // donot re-alloc memory
371  this->m_nSize = size;
372  }
373  else { // Re-alloc
374  T *p = new T[size];
375  if (p == NULL) {
376  lout_error("[Vec] Reset: new buffer error!");
377  }
378  memcpy(p, this->m_pBuf, sizeof(T)*this->m_nSize);
379  SAFE_DELETE_ARRAY(this->m_pBuf);
380  this->m_pBuf = p;
381  this->m_nSize = size;
382  this->m_nBufSize = size;
383  }
384  }
// Vec<T>::Copy(v): resize this vector to v.GetSize() (Reset reallocates only when the
// capacity is too small) and then memcpy v's elements into it.
385  template <class T>
387  {
388  Reset(v.GetSize());
389  memcpy(this->m_pBuf, v.GetBuf(), sizeof(T)*v.GetSize());
390  }
391 
392  /************************************************************************/
393  /* MatShell */
394  /************************************************************************/
395 
// MatShell<T>::Fill(v): set all row*col elements to v. A NULL buffer is reported via lout_error.
// NOTE(review): if lout_error does not abort, the loop still writes through the NULL
// pointer whenever row*col > 0.
396  template <class T>
398  {
399  if (!m_pBuf) {
400  lout_error("[Mat] Fill: buffer = NULL");
401  }
402  for (int i = 0; i < m_nRow*m_nCol; i++) {
403  m_pBuf[i] = v;
404  }
405  }
// MatShell<T>::operator==(m): true when both matrices have the same shape and all
// elements compare equal (exact comparison).
406  template <class T>
408  {
409  if (m_nRow != m.m_nRow || m_nCol != m.m_nCol)
410  return false;
411  for (int i = 0; i < m_nRow*m_nCol; i++) {
412  if (m_pBuf[i] != m.m_pBuf[i])
413  return false;
414  }
415  return true;
416  }
// MatShell<T>::Write(file): write the matrix as text, one row per line, with each element
// followed by a single space (so every row ends with a trailing space).
// NOTE(review): constructing std::ofstream from a FILE* is a non-standard library
// extension (MSVC); confirm the target toolchain supports it. std::endl flushes every row.
417  template <class T>
419  {
420  ofstream os(file.fp);
421  for (int i = 0; i < m_nRow; i++) {
422  for (int j = 0; j < m_nCol; j++) {
423  os << Get(i, j) << " ";
424  }
425  os << endl;
426  }
427  }
// MatShell<T>::Read(file): read row*col whitespace-separated values into the already
// sized matrix (counterpart of Write), consuming the line break after each row via
// file.Scanf("\n").
// NOTE(review): std::ifstream constructed from a FILE* is a non-standard (MSVC) extension.
428  template <class T>
430  {
431  ifstream is(file.fp);
432  for (int i = 0; i < m_nRow; i++) {
433  for (int j = 0; j < m_nCol; j++) {
434  is >> Get(i, j);
435  }
436  file.Scanf("\n");
437  }
438  }
439 
440  /************************************************************************/
441  /* Mat */
442  /************************************************************************/
443 
444  template <class T>
445  void Mat<T>::Reset(int row /* = 0 */, int col /* = 0 */)
446  {
447  if (row * col == 0) {
448  // Clean buffer
449  SAFE_DELETE_ARRAY(this->m_pBuf);
450  this->m_nRow = 0;
451  this->m_nCol = 0;
452  this->m_nBufSize = 0;
453  return;
454  }
455 
456  int size = row * col;
457  if (size <= this->m_nBufSize) {
458  this->m_nRow = row;
459  this->m_nCol = col;
460  }
461  else {
462  T *p = new T[size];
463  if (p == NULL) {
464  lout_error("[Mat] Reset: new buffer error!");
465  }
466  memcpy(p, this->m_pBuf, sizeof(T)*this->m_nRow*this->m_nCol);
467  SAFE_DELETE_ARRAY(this->m_pBuf);
468  this->m_pBuf = p;
469  this->m_nRow = row;
470  this->m_nCol = col;
471  this->m_nBufSize = size;
472  }
473  }
// Mat<T>::Copy(m): reshape this matrix to m's row x col (Reset reallocates only when
// needed) and memcpy m's elements into it.
474  template <class T>
476  {
477  Reset(m.GetRow(), m.GetCol());
478  memcpy(this->m_pBuf, m.GetBuf(), sizeof(T)*this->m_nRow*this->m_nCol);
479  }
480 
// Mat3dShell<T>::Fill(v): set all GetSize() elements to v. A NULL buffer is reported via lout_error.
// NOTE(review): if lout_error does not abort, the loop still writes through NULL when GetSize() > 0.
481  template <class T>
483  {
484  if (!m_pBuf) {
485  lout_error("[Mat3d] Fill: buffer = NULL");
486  }
487  for (int i = 0; i < GetSize(); i++) {
488  m_pBuf[i] = v;
489  }
490  }
// Mat3dShell<T>::Write(file): text output with one line per x index. For each y, the z
// values are written as "[v0 v1 ... vZ-1]".
// NOTE(review): if m_nZDim == 0 while the x/y dims are non-zero, the final Get(x, y, z)
// runs with z == 0, which is outside the empty z range.
// NOTE(review): std::ofstream constructed from a FILE* is a non-standard (MSVC) extension.
491  template <class T>
493  {
494  ofstream os(file.fp);
495  int x, y, z;
496  for (x = 0; x < m_nXDim; x++) {
497  for (y = 0; y < m_nYDim; y++) {
498  os << "[";
499  for (z = 0; z < m_nZDim-1; z++)
500  os<<Get(x, y, z)<<" ";
501  os << Get(x, y, z) << "]";
502  }
503  os << endl;
504  }
505  }
// Mat3dShell<T>::Read(file): parse the format produced by Write into the already sized
// array. For each (x, y) it reads the opening bracket into c, then ZDim values, then the
// closing bracket into c.
// NOTE(review): the same zero-ZDim caveat as Write applies (Get(x, y, 0) on an empty z range).
506  template <class T>
508  {
509  ifstream is(file.fp);
510  int x, y, z;
511  char c;
512  for (x = 0; x < m_nXDim; x++) {
513  for (y = 0; y < m_nYDim; y++) {
514  is >> c;
515  for (z = 0; z < m_nZDim - 1; z++)
516  is >> Get(x, y, z);
517  is >> Get(x, y, z) >> c;
518  }
519  // skip the line break that ends this x-slice
520  file.Scanf("\n");
521  }
522  }
// Mat3dShell<T>::operator==(m): true when all three dimensions match and every element
// compares equal (exact comparison).
523  template <class T>
525  {
526  if (m_nXDim != m.m_nXDim || m_nYDim != m.m_nYDim || m_nZDim != m.m_nZDim)
527  return false;
528 
529  for (int i = 0; i < GetSize(); i++) {
530  if (m_pBuf[i] != m.m_pBuf[i])
531  return false;
532  }
533  return true;
534  }
535  template <class T>
536  void Mat3d<T>::Reset(int xdim/* =0 */, int ydim/* =0 */, int zdim/* =0 */)
537  {
538  if (xdim*ydim*zdim == 0) {
539  // Clean buffer
540  SAFE_DELETE_ARRAY(this->m_pBuf);
541  this->m_nXDim = 0;
542  this->m_nYDim = 0;
543  this->m_nZDim = 0;
544  this->m_nBufSize = 0;
545  return;
546  }
547 
548  int size = xdim * ydim * zdim;
549  if (size <= this->m_nBufSize) {
550  this->m_nXDim = xdim;
551  this->m_nYDim = ydim;
552  this->m_nZDim = zdim;
553  }
554  else {
555  T *p = new T[size];
556  if (p == NULL) {
557  lout_error("[Mat] Reset: new buffer error!");
558  }
559  memcpy(p, this->m_pBuf, sizeof(T)*this->GetSize());
560  SAFE_DELETE_ARRAY(this->m_pBuf);
561  this->m_pBuf = p;
562  this->m_nXDim = xdim;
563  this->m_nYDim = ydim;
564  this->m_nZDim = zdim;
565  this->m_nBufSize = size;
566  }
567  }
// Mat3d<T>::Copy(m): reshape this array to m's dimensions (Reset reallocates only when
// needed) and memcpy m.GetSize() elements from m.
568  template <class T>
570  {
571  Reset(m.GetXDim(), m.GetYDim(), m.GetZDim());
572  memcpy(this->m_pBuf, m.GetBuf(), sizeof(T)*m.GetSize());
573  }
575 }
576 
577 #endif
Vec(int size)
Definition: wb-mat.h:101
Definition: wb-mat.h:30
int GetSize() const
Definition: wb-mat.h:126
void Copy(Mat3dShell< T > &m)
Definition: wb-mat.h:569
VecShell()
Definition: wb-mat.h:65
T & Get(unsigned int i, unsigned int j)
Definition: wb-mat.h:125
T * GetBuf() const
Definition: wb-mat.h:171
MatShell(T *pbuf, int row, int col)
Definition: wb-mat.h:122
void Read(File &file)
Definition: wb-mat.h:507
void Write(File &file)
Definition: wb-mat.h:492
friend T VecDot(VecShell< T > &v1, VecShell< T > &v2)
calculate V*V
Definition: wb-mat.h:235
#define lout_error(x)
Definition: wb-log.h:183
void Fill(T v)
Definition: wb-mat.h:397
void operator/=(T n)
Definition: wb-mat.h:337
void Reset(int xdim=0, int ydim=0, int zdim=0)
Definition: wb-mat.h:536
int ByteSize() const
Definition: wb-mat.h:70
int ByteSize() const
Definition: wb-mat.h:174
T & Get(int x, int y, int z)
Definition: wb-mat.h:172
~Mat()
Definition: wb-mat.h:150
VecShell< T > GetSub(int nPos, int nLen)
Definition: wb-mat.h:71
bool operator==(MatShell &m)
Definition: wb-mat.h:407
int GetZDim() const
Definition: wb-mat.h:177
virtual int Scanf(const char *p_pMessage,...)
scanf
Definition: wb-file.cpp:132
T * m_pBuf
buf pointer
Definition: wb-mat.h:117
FILE * fp
file pointer
Definition: wb-file.h:97
int GetSize() const
Definition: wb-mat.h:173
void Reset(T *p, int size)
Definition: wb-mat.h:72
A definition of class Log, which can output to the cmd window and the log file simultaneously. wb-log.cpp defines a Log variable "lout", which can be used directly, just like "cout". For example:
Definition: wb-mat.h:29
~Mat3d()
Definition: wb-mat.h:197
Vec()
Definition: wb-mat.h:100
friend T MatVec2(MatShell< T > &m, VecShell< T > &v1, VecShell< T > &v2)
calculate V1*M*V2
Definition: wb-mat.h:251
friend void VecAdd(VecShell< T > &res, VecShell< T > &v1, VecShell< T > &v2)
calculate V + V
Definition: wb-mat.h:223
bool operator==(VecShell v)
Definition: wb-mat.h:344
void Fill(T v)
Definition: wb-mat.h:279
int ByteSize() const
Definition: wb-mat.h:127
file class.
Definition: wb-file.h:94
int GetSize() const
Definition: wb-mat.h:69
Mat(int row, int col)
Definition: wb-mat.h:149
int m_nCol
Definition: wb-mat.h:119
Mat3d(int xdim, int ydim, int zdim)
Definition: wb-mat.h:196
#define SAFE_DELETE_ARRAY(p)
Definition: wb-vector.h:50
void VecAdd(VecShell< T > &res, VecShell< T > &v1, VecShell< T > &v2)
calculate V + V
Definition: wb-mat.h:223
T * m_pBuf
buf pointer
Definition: wb-mat.h:62
T * GetBuf() const
Definition: wb-mat.h:68
void operator*=(T n)
Definition: wb-mat.h:330
friend bool VecEqual(VecShell< T > &v1, VecShell< T > &v2)
define the friend function
Definition: wb-mat.h:210
int GetXDim() const
Definition: wb-mat.h:175
~Vec()
Definition: wb-mat.h:102
bool operator==(Mat3dShell &m)
Definition: wb-mat.h:524
int GetYDim() const
Definition: wb-mat.h:176
void Read(File &file)
Definition: wb-mat.h:429
void Reset(int row=0, int col=0)
Definition: wb-mat.h:445
VecShell(T *p, int size)
Definition: wb-mat.h:66
void operator+=(VecShell v)
Definition: wb-mat.h:310
T & operator[](int i)
Definition: wb-mat.h:289
pFunc Reset & m
void Reset(int size=0)
Definition: wb-mat.h:360
Mat3dShell(T *p, int xdim, int ydim, int zdim)
Definition: wb-mat.h:169
T VecDot(VecShell< T > &v1, VecShell< T > &v2)
calculate V*V
Definition: wb-mat.h:235
void Write(File &file)
Definition: wb-mat.h:418
T MatVec2(MatShell< T > &m, VecShell< T > &v1, VecShell< T > &v2)
calculate V1*M*V2
Definition: wb-mat.h:251
bool VecEqual(VecShell< T > &v1, VecShell< T > &v2)
calculate V==V
Definition: wb-mat.h:210
define the file class
int GetCol() const
Definition: wb-mat.h:129
void Copy(MatShell< T > &m)
Definition: wb-mat.h:475
int GetRow() const
Definition: wb-mat.h:128
Mat3d()
Definition: wb-mat.h:195
Mat()
Definition: wb-mat.h:148
void operator-=(VecShell v)
Definition: wb-mat.h:320
int m_nRow
Definition: wb-mat.h:118
void Fill(T v)
Definition: wb-mat.h:482
void Copy(VecShell< T > v)
Definition: wb-mat.h:386
T * GetBuf() const
Definition: wb-mat.h:124
T Sum()
Definition: wb-mat.h:81
void Reset(T *p, int xdim, int ydim, int zdim)
Definition: wb-mat.h:178
define all the code written by Bin Wang.
Definition: wb-file.cpp:21
void Reset(T *pbuf, int row, int col)
Definition: wb-mat.h:130
Definition of simple dynamic arrays, stacks, queues, and so on.
int m_nSize
buf size
Definition: wb-mat.h:63
void operator=(VecShell v)
Definition: wb-mat.h:302