Yabin Li
2023-11-07 702ec03ad89d5c62e97eed770a6882d6412f8d58
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#ifndef BIAS_LM_
#define BIAS_LM_
#include <assert.h>
#include "util.h"
#include "fst/fstlib.h"
#include "phone-set.h"
#include "vocab.h"
#include "util/text-utils.h"
#include <yaml-cpp/yaml.h>
// node type
#define ROOT_NODE 0
#define VALUE_ZERO 0.0f
 
namespace funasr {
typedef fst::StdArc Arc;
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
typedef typename Arc::Label Label;
typedef typename fst::SortedMatcher<fst::StdVectorFst> Matcher;
typedef typename fst::ArcIterator<fst::StdVectorFst> ArcIterator;
 
class Node {
 public:
  Node() : score_(0.0f), is_final_(false), back_off_(-1) {}
  float score_;
  bool is_final_;
  StateId back_off_;
};
 
class BiasLmOption {
 public:
  BiasLmOption() : incre_bias_(20.0f), scale_(1.0f) {}
  float incre_bias_;
  float scale_;
};
 
class BiasLm {
 public:
  BiasLm(const string &hws_file, const string &cfg_file, 
    const PhoneSet& phn_set, const Vocab& vocab) :
    phn_set_(phn_set), vocab_(vocab) {
    std::string line;
    std::ifstream ifs_hws(hws_file.c_str());
    std::vector<float> custom_weight;
    std::vector<std::vector<int>> split_id_vec;
 
    struct timeval start, end;
    gettimeofday(&start, NULL);
 
    LoadCfgFromYaml(cfg_file.c_str(), opt_);
    while (getline(ifs_hws, line)) {
      Trim(&line);
      if (line.empty()) {
        continue;
      }
      float score = 1.0f;
      bool is_oov = false;
      std::vector<std::string> text;
      std::vector<std::string> split_str;
      std::vector<int> split_id;
      SplitStringToVector(line, "\t", true, &text);
      if (text.size() > 1) {
        score = std::stof(text[1]);
      }
      Utf8ToCharset(text[0], split_str);
      for (auto &str : split_str) {
        split_id.push_back(phn_set_.String2Id(str));
        if (!phn_set_.Find(str)) {
          is_oov = true;
          break;
        }
      }
      if (!is_oov) {
        split_id_vec.push_back(split_id);
        custom_weight.push_back(score);
      }
    }
    BuildGraph(split_id_vec, custom_weight);
    ifs_hws.close();
 
    gettimeofday(&end, NULL);
    long seconds = (end.tv_sec - start.tv_sec);
    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    LOG(INFO) << "Build bias lm takes " << (double)modle_init_micros / 1000000 << " s";
  }
 
  BiasLm(unordered_map<string, int> &hws_map, int inc_bias,
    const PhoneSet& phn_set, const Vocab& vocab) :
    phn_set_(phn_set), vocab_(vocab) {
    std::vector<float> custom_weight;
    std::vector<std::vector<int>> split_id_vec;
 
    struct timeval start, end;
    gettimeofday(&start, NULL);
    opt_.incre_bias_ = inc_bias;
    for (const pair<string, int>& kv : hws_map) {
      float score = 1.0f;
      bool is_oov = false;
      std::vector<std::string> text;
      std::vector<std::string> split_str;
      std::vector<int> split_id;
      score = kv.second;
      Utf8ToCharset(kv.first, split_str);
      for (auto &str : split_str) {
        split_id.push_back(phn_set_.String2Id(str));
        if (!phn_set_.Find(str)) {
          is_oov = true;
          break;
        }
      }
      if (!is_oov) {
        split_id_vec.push_back(split_id);
        custom_weight.push_back(score);
      }
    }
    BuildGraph(split_id_vec, custom_weight);
 
    gettimeofday(&end, NULL);
    long seconds = (end.tv_sec - start.tv_sec);
    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    LOG(INFO) << "Build bias lm takes " << (double)modle_init_micros / 1000000 << " s";
  }
 
  void BuildGraph(std::vector<std::vector<int>> &vec, std::vector<float> &wts);
  float BiasLmScore(const StateId &cur_state, const Label &lab, Label &new_state);
  void VocabIdToPhnIdVector(int vocab_id, std::vector<int> &phn_ids);
  void LoadCfgFromYaml(const char* filename, BiasLmOption &opt);
  std::string GetPhoneLabel(int phone_id);
 private:
  const PhoneSet& phn_set_;
  const Vocab& vocab_;
  std::unique_ptr<fst::StdVectorFst> graph_ = nullptr;
  std::vector<Node> node_list_;
  BiasLmOption opt_;
};
} // namespace funasr
#endif // BIAS_LM_