游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
// Copyright 2020  Jiayu DU
 
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifdef HAVE_KENLM
#ifndef KALDI_LM_KENLM_H
#define KALDI_LM_KENLM_H
 
#include <base/kaldi-common.h>
#include <fst/fstlib.h>
#include <fst/fst-decl.h>
#include <fstext/deterministic-fst.h>
 
#include "lm/model.hh"
#include "util/murmur_hash.hh"
 
namespace kaldi {
 
// KenLm class wraps kenlm model(supporting both "trie" or "probing" models):
//  1. provides interface for loading binary LM, and holds it with ownership
//  2. provides interface for ngram score query at runtime
//  3. handles the index mapping between kaldi's symbols & kenlm's words
// KenLm object is heavy, stateless and thread-safe, 
// can be shared by Fst wrapper class(i.e. KenLmDeterministicOnDemandFst)
class KenLm {
 public:
  typedef lm::WordIndex WordIndex;
  typedef lm::ngram::State State;
 
 public:
  KenLm() : 
    model_(nullptr), vocab_(nullptr),
    bos_sym_("<s>"), eos_sym_("</s>"), unk_sym_("<unk>"),
    bos_symid_(0), eos_symid_(0), unk_symid_(0)
  { }
 
  ~KenLm() {
    if (model_ != nullptr) {
      delete model_;
    }
    model_ = nullptr;
    vocab_ = nullptr;
    symid_to_wid_.clear();
  }
 
  // If you have big LM on SSD hard-drive,
  // you can set load_method to util::LoadMethod::LAZY,
  // which enables "on-demand" model reading(via POSIX mmap) at runtime.
  // Refer to tools/kenlm/util/mmap.hh for more load methods.
  int Load(std::string kenlm_filename, 
           std::string kaldi_symbol_table_filename,
           util::LoadMethod load_method = util::LoadMethod::POPULATE_OR_READ);
 
  inline WordIndex GetWordIndex(std::string word) const {
    return vocab_->Index(word.c_str());
  }
 
  inline WordIndex GetWordIndex(int32 symbol_id) const {
    return symid_to_wid_[symbol_id];
  }
 
  void SetStateToBeginOfSentence(State *s) const { model_->BeginSentenceWrite(s); }
  void SetStateToNull(State *s) const { model_->NullContextWrite(s); }
 
  int32 BosSymbolIndex() const { return bos_symid_; }
  int32 EosSymbolIndex() const { return eos_symid_; }
  int32 UnkSymbolIndex() const { return unk_symid_; }
 
  inline BaseFloat Score(const State *in_state,
                         WordIndex word,
                         State *out_state) const {
    return model_->BaseScore(in_state, word, out_state);
  }
 
  // This provides a fast state hashing, 
  // KenLmDeterministicOnDemandFst needs this for Fst states managing.
  struct StateHasher {
    inline size_t operator()(const State &s) const noexcept {
      return util::MurmurHashNative(s.words, sizeof(WordIndex) * s.Length());
    }
  };
 
 private:
  void ComputeSymbolToWordIndexMapping(std::string symbol_table);
  
 private:
  lm::base::Model *model_; // with ownership
 
  // without ownership, points to internal vocabulary of model_
  const lm::base::Vocabulary* vocab_;
 
  // There are two integerized indexing systems here:
  // 1. Kaldi's fst output *symbol index*(defined in words.txt),
  // 2. KenLm's *word index*(defined by word string hashing).
  // In order to rescore kaldi hypotheses with kenlm ngrams, 
  // we need to know the index mapping from symbol to word.
  // KenLm class precomputes (during model loading) and stores this mapping,
  // and apply the mapping at runtime.
  // This is slower, but at least we don't need
  // to modify/convert runtime resources.(e.g. HCLG/lattices or kenlm models)
  //
  // In the mapping, <eps> and #0 symbols are special:
  // They do not correspond to any word in KenLm,
  // so the mapping of these two symbols are logically undefined,
  // we just map them to KenLm's <unk> to avoid random invalid mapping.
 
  // symid_to_wid_[kaldi_symbol_index] -> kenlm word index
  std::vector<WordIndex> symid_to_wid_;
 
  // special lm symbols
  std::string bos_sym_;
  std::string eos_sym_;
  std::string unk_sym_;
 
  int32 bos_symid_;
  int32 eos_symid_;
  int32 unk_symid_;
}; // class KenLm
 
 
// DeterministicOnDemandFst wraps a KenLm object as a deteministic Fst.
// Internally, it manages dynamically expanded Fst states(so not thread-safe),
// different threads should create their own instances of this class.
// They are lightweight and can share the same KenLm object.
template<class Arc>
class KenLmDeterministicOnDemandFst : public fst::DeterministicOnDemandFst<Arc> {
 public:
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename KenLm::State State;
  typedef typename KenLm::WordIndex WordIndex;
 
  explicit KenLmDeterministicOnDemandFst(const KenLm *lm)
   : lm_(lm), num_states_(0), bos_state_id_(0)
  {
    // create bos to be FST start state
    MapElem e;
    lm->SetStateToBeginOfSentence(&e.first);
    e.second = bos_state_id_;
    std::pair<IterType, bool> r = state_map_.insert(e);
    KALDI_ASSERT(r.second == true); // bos successfully inserted into state map
    state_vec_.push_back(&r.first->first);
    num_states_++;
 
    eos_symbol_id_ = lm_->EosSymbolIndex();
  }
  virtual ~KenLmDeterministicOnDemandFst() { }
 
  virtual StateId Start() { 
    return bos_state_id_;
  }
 
  virtual bool GetArc(StateId s, Label label, Arc *oarc) {
    KALDI_ASSERT(s < static_cast<StateId>(state_vec_.size()));
    const State* istate = state_vec_[s];
    MapElem e;
    WordIndex word = lm_->GetWordIndex(label);
    BaseFloat log_10_prob = lm_->Score(istate, word, &e.first);
    e.second = num_states_;
    std::pair<IterType, bool> r = state_map_.insert(e);
    if (r.second == true) { // new state
      state_vec_.push_back(&(r.first->first));
      num_states_++;
    }
 
    oarc->ilabel = label;
    oarc->olabel = oarc->ilabel;
    oarc->nextstate = r.first->second;
    oarc->weight = Weight(-log_10_prob * M_LN10); // KenLm log10() -> Kaldi ln()
 
    return true;
  }
 
  virtual Weight Final(StateId s) {
    Arc oarc;
    GetArc(s, eos_symbol_id_, &oarc);
    return oarc.weight;
  }
 
 private:
  typedef std::pair<State, StateId> MapElem;
  typedef unordered_map<State, StateId, KenLm::StateHasher> MapType;
  typedef typename MapType::iterator IterType;
 
  const KenLm *lm_; // no ownership
  MapType state_map_;
  std::vector<const State*> state_vec_;
  StateId num_states_; // state vector index range, [0, num_states_)
  StateId bos_state_id_;  // fst start state id
  Label eos_symbol_id_;
}; // class KenLmDeterministicOnDemandFst
} // namespace kaldi
#endif
#endif