游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// Copyright 2020  Jiayu DU
 
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifdef HAVE_KENLM
#include "lm/kenlm.h"
 
namespace kaldi {
 
void KenLm::ComputeSymbolToWordIndexMapping(std::string symbol_table_filename) {
  // count symbol table size
  int num_syms = 0;
  std::string line;
  std::ifstream is(symbol_table_filename);
  while(std::getline(is, line)) {
    if (!line.empty()) num_syms++;
  }
 
  symid_to_wid_.clear();
  symid_to_wid_.resize(num_syms, 0);
 
  is.clear();
  is.seekg(0);
  int num_mapped = 0;
  while (std::getline(is, line)) {
    std::vector<std::string> fields;
    SplitStringToVector(line, " ", true, &fields);
    if (fields.size() == 2) {
      std::string sym = fields[0];
      int32 symid = 0;  ConvertStringToInteger(fields[1], &symid);
      // mark special LM word
      if (sym == bos_sym_) {
        bos_symid_ = symid;
      } else if (sym == eos_sym_) {
        eos_symid_ = symid;
      } else if (sym == "<unk>" || sym == "<UNK>") {
        unk_sym_ = sym;
        unk_symid_ = symid;
      }
      // check vocabulary consistency between kaldi and kenlm.
      // note we always handle <unk> & <UNK> as a pair, 
      // so don't worry about the literal mismatch
      // between Kaldi and kenlm arpa (<UNK> vs <unk>)
      WordIndex wid = vocab_->Index(sym.c_str());
      if ((wid == vocab_->Index("<unk>") || wid == vocab_->Index("<UNK>")) 
          && sym != "<unk>" && sym != "<UNK>"
          && sym != "<eps>"
          && sym != "#0") {
        KALDI_ERR << "found mismatched symbol: " << sym
                  << ", this symbol is in Kaldi, but is unseen in KenLm"
                  << ", they should have strictly consistent vocabulary.";
      } else {
        symid_to_wid_[symid] = wid;
        num_mapped += 1;
      }
    }
  }
  KALDI_ASSERT(num_mapped == symid_to_wid_.size());
  KALDI_LOG << "Successfully mapped " << num_mapped 
            << " Kaldi symbols to KenLm words";
}
 
int KenLm::Load(std::string kenlm_filename,
                std::string symbol_table_filename,
                util::LoadMethod load_method) {
  if (model_ != nullptr) { delete model_; }
  model_ = nullptr;
  vocab_ = nullptr;
 
  // load KenLm model
  lm::ngram::Config config;
  config.load_method = load_method;
  model_ = lm::ngram::LoadVirtual(kenlm_filename.c_str(), config);
  if (model_ == nullptr) { KALDI_ERR << "Failed to load KenLm model"; }
 
  // KenLm holds vocabulary internally with ownership,
  // vocab_ here is just for concise reference
  vocab_ = &model_->BaseVocabulary();
  if (vocab_ == nullptr) { KALDI_ERR << "Failed to get vocabulary from KenLm model"; }
 
  // compute the index mapping from Kaldi symbol to KenLm word
  ComputeSymbolToWordIndexMapping(symbol_table_filename);
  return 0;
}
 
} // namespace kaldi
#endif