TRF Language Model
trf-vocab.h
Go to the documentation of this file.
1 // You may obtain a copy of the License at
2 //
3 // http://www.apache.org/licenses/LICENSE-2.0
4 //
5 // Unless required by applicable law or agreed to in writing, software
6 // distributed under the License is distributed on an "AS IS" BASIS,
7 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
8 // See the License for the specific language governing permissions and
9 // limitations under the License.
10 //
11 // Copyright 2014-2015 Tsinghua University
12 // Author: wb.th08@gmail.com (Bin Wang), ozj@tsinghua.edu.cn (Zhijian Ou)
13 //
14 // All h, cpp, cc, and script files (e.g. bat, sh, pl, py) should include the above
15 // license declaration. Different coding languages may use different comment styles.
16 
17 
18 #pragma once
19 #include "trf-def.h"
20 
21 namespace trf
22 {
/// Integer id of a vocabulary entry; also used for class ids.
typedef int VocabID;
/// Sentinel id meaning "no word" / "no class".
const VocabID VocabID_none = -1;
/// Reserved id of the sequence-begin symbol (printed as Word_beg).
const VocabID VocabID_seqbeg = -3;
/// Reserved id of the sequence-end symbol (printed as Word_end).
const VocabID VocabID_seqend = -2;
// The pointers themselves are const: this header is included by many
// translation units, and a reassignable pointer would let one unit silently
// change its private copy of the symbol string.
/// String of the sequence-begin symbol.
static const char* const Word_beg = "<s>";
/// String of the sequence-end symbol.
static const char* const Word_end = "</s>";
29 
34  class Vocab
35  {
36  public:
41 
42  public:
43  Vocab();
44  Vocab(const char* pathVocab);
45  ~Vocab();
47  int GetSize() { return m_aWords.GetNum(); }
49  int GetClassNum() { return m_aClass2Word.GetNum(); }
51  const char* GetWordStr(int id) {
52  switch (id) {
53  case VocabID_seqbeg: return Word_beg; break;
54  case VocabID_seqend: return Word_end; break;
55  default: return m_aWords[id].GetBuffer();
56  }
57  return NULL;
58  }
60  VocabID *GetClassMap() { return m_aClass.GetBuffer(); }
62  VocabID GetClass(VocabID wid) {
63  if (wid >= m_aClass.GetNum())
64  return VocabID_none;
65  return m_aClass[wid];
66  }
68  void GetClass(VocabID *pcid, const VocabID *pwid, int nlen);
70  VocabID RandClass() {
71  if (GetClassNum() == 0)
72  return VocabID_none;
73  return rand() % GetClassNum();
74  }
76  Array<int> *GetWord(VocabID cid) {
77  if (cid == VocabID_none) // if no class, then return the word id.
78  return &m_aWordID;
79  return m_aClass2Word[cid];
80  }
82  int IterBeg() const { return 0; }
84  int IterEnd() const { return m_aWords.GetNum() - 1; }
86  bool IsLegalWord(VocabID id) const { return (id >= IterBeg() && id <= IterEnd()); }
87  };
88 }
int GetClassNum()
get the total class number
Definition: trf-vocab.h:49
int VocabID
Definition: trf-vocab.h:23
const int VocabID_none
Definition: trf-vocab.h:24
const int VocabID_seqbeg
Definition: trf-vocab.h:25
Array< VocabID > m_aClass
stores the class of each word. Supports soft and hard classes
Definition: trf-vocab.h:39
Array< String > m_aWords
the string of each vocabulary id
Definition: trf-vocab.h:38
VocabID GetClass(VocabID wid)
get class
Definition: trf-vocab.h:62
VocabID * GetClassMap()
get class map
Definition: trf-vocab.h:60
T * GetBuffer(int i=0) const
get the buffer pointer
Definition: wb-vector.h:97
VocabID RandClass()
pick a random class
Definition: trf-vocab.h:70
const int VocabID_seqend
Definition: trf-vocab.h:26
Array< int > * GetWord(VocabID cid)
get word belonging to a class
Definition: trf-vocab.h:76
int GetNum() const
Get Array number.
Definition: wb-vector.h:240
int IterEnd() const
iterate over all the words, regardless of the beg/end symbols
Definition: trf-vocab.h:84
int GetSize()
get the vocab size, i.e. the word number
Definition: trf-vocab.h:47
const char * GetWordStr(int id)
get word string
Definition: trf-vocab.h:51
int IterBeg() const
iterate over all the words, regardless of the beg/end symbols
Definition: trf-vocab.h:82
Definition: trf-alg.cpp:20
bool IsLegalWord(VocabID id) const
Check if the VocabID is a legal word.
Definition: trf-vocab.h:86
Array< Array< VocabID > * > m_aClass2Word
store the word belonging to each class.
Definition: trf-vocab.h:40
Array< VocabID > m_aWordID
the word ids, i.e. 0,1,2,3,...
Definition: trf-vocab.h:37