雾聪
2024-03-13 7675a2a0baa30357da00263186964c0d0d814581
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <wfst-decoder.h>
namespace funasr {
WfstDecoder::WfstDecoder(fst::Fst<fst::StdArc>* lm,
                         PhoneSet* phone_set, Vocab* vocab,
                         float glob_beam, float lat_beam, float am_scale)
:dec_opts_(glob_beam, lat_beam, am_scale), decodable_(dec_opts_.acoustic_scale),
 lm_(lm), phone_set_(phone_set), vocab_(vocab) {
  decoder_ = std::shared_ptr<kaldi::LatticeFasterOnlineDecoder>(
             new kaldi::LatticeFasterOnlineDecoder(*lm_, dec_opts_));
}
 
WfstDecoder::~WfstDecoder() {
}
 
void WfstDecoder::StartUtterance() {
  if (decoder_) {
    cur_frame_ = 0;
    cur_token_ = 0;
    decodable_.Reset();
    decoder_->InitDecoding();
  }
}
 
void WfstDecoder::EndUtterance() {
}
 
string WfstDecoder::Search(float *in, int len, int64_t token_num) {
  string result;
  if (len == 0) {
    return "";
  }
  std::vector<std::vector<float>> logp_vec;
  int blk_phn_id = phone_set_->GetBlkPhnId();
  for (int i = 0; i < len - 1; i++) {
    std::vector<float> tmp_logp;
    for (int j = 0; j < token_num; j++) {
      tmp_logp.push_back((in + i * token_num)[j]);
    }
    logp_vec.push_back(tmp_logp);
  }
  for (int i = 0; i < logp_vec.size(); i++) {
    cur_frame_++;
    decodable_.AcceptLoglikes(logp_vec[i]);
    decoder_->AdvanceDecoding(&decodable_, 1);
    cur_token_++;
  }
  if (cur_token_ > 0) {
    std::vector<int> words;
    kaldi::Lattice lattice;
    decoder_->GetBestPath(&lattice, false);
    std::vector<int> alignment;
    kaldi::LatticeWeight weight;
    fst::GetLinearSymbolSequence(lattice, &alignment, &words, &weight);
    result = vocab_->Vector2StringV2(words);
  }
  return result;
}
 
string WfstDecoder::FinalizeDecode(bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak) {
  string result;
  if (cur_token_ > 0) {
    std::vector<int> words;
    kaldi::Lattice lattice;
    decodable_.SetFinished();
    decoder_->FinalizeDecoding();
    decoder_->GetBestPath(&lattice, true);
    std::vector<int> alignment;
    kaldi::LatticeWeight weight;
    fst::GetLinearSymbolSequence(lattice, &alignment, &words, &weight);
    
    if(!is_stamp){
        return vocab_->Vector2StringV2(words);
    }else{
        std::vector<std::string> char_list;
        std::vector<std::vector<float>> timestamp_list;
        std::string res_str;
        vocab_->Vector2String(words, char_list);
        // split chinese word to char
        std::vector<std::string> split_chars;
        for(auto& word:char_list){
          std::vector<std::string> word2char;
          SplitChiEngCharacters(word, word2char);
          split_chars.insert(split_chars.end(), word2char.begin(), word2char.end());
        }
        // std::vector<string> raw_char(char_list);
        TimestampOnnx(us_alphas, us_cif_peak, split_chars, res_str, timestamp_list);
 
        return PostProcess(split_chars, timestamp_list);
    }
  }
  return result;
}
 
void WfstDecoder::LoadHwsRes(int inc_bias, unordered_map<string, int> &hws_map) {
  try {
    if (!hws_map.empty()) {
      bias_lm_ = std::make_shared<BiasLm>(hws_map, inc_bias,
                                          *phone_set_, *vocab_);
      decoder_->SetBiasLm(bias_lm_);
    }
  } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load wfst hotwords resource: " << e.what();
        exit(0);
  }
}
 
void WfstDecoder::UnloadHwsRes() {
  if (bias_lm_) {
    decoder_->ClearBiasLm();
    bias_lm_.reset();
  }
}
 
} // namespace funasr