雾聪
2023-12-12 3c83d64c84602de055f503af7d4e2761c829ec2e
fst: support eng hotword
9个文件已修改
98 ■■■■ 已修改文件
runtime/onnxruntime/include/com-define.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/model.h 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/bias-lm.h 30 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/offline-stream.cpp 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer.cpp 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer.h 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/vocab.cpp 40 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/vocab.h 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/funasr-wss-server.cpp 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/com-define.h
@@ -68,6 +68,7 @@
#define QUANT_DECODER_NAME "decoder_quant.onnx"
#define LM_FST_RES "TLG.fst"
#define LEX_PATH "lexicon.txt"
// vad
#ifndef VAD_SILENCE_DURATION
runtime/onnxruntime/include/model.h
@@ -15,7 +15,7 @@
    virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
    virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
    virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
    virtual void InitLm(const std::string &lm_file, const std::string &lm_config){};
    virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
    virtual void InitFstDecoder(){};
    virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
    virtual std::string Rescoring() = 0;
runtime/onnxruntime/src/bias-lm.h
@@ -65,12 +65,17 @@
      if (text.size() > 1) {
        score = std::stof(text[1]);
      }
      Utf8ToCharset(text[0], split_str);
      SplitChiEngCharacters(text[0], split_str);
      for (auto &str : split_str) {
        split_id.push_back(phn_set_.String2Id(str));
        if (!phn_set_.Find(str)) {
          is_oov = true;
          break;
        std::vector<string> lex_vec;
        std::string lex_str = vocab_.Word2Lex(str);
        SplitStringToVector(lex_str, " ", true, &lex_vec);
        for (auto &token : lex_vec) {
          split_id.push_back(phn_set_.String2Id(token));
          if (!phn_set_.Find(token)) {
            is_oov = true;
            break;
          }
        }
      }
      if (!is_oov) {
@@ -103,12 +108,17 @@
      std::vector<std::string> split_str;
      std::vector<int> split_id;
      score = kv.second;
      Utf8ToCharset(kv.first, split_str);
      SplitChiEngCharacters(kv.first, split_str);
      for (auto &str : split_str) {
        split_id.push_back(phn_set_.String2Id(str));
        if (!phn_set_.Find(str)) {
          is_oov = true;
          break;
        std::vector<string> lex_vec;
        std::string lex_str = vocab_.Word2Lex(str);
        SplitStringToVector(lex_str, " ", true, &lex_vec);
        for (auto &token : lex_vec) {
          split_id.push_back(phn_set_.String2Id(token));
          if (!phn_set_.Find(token)) {
            is_oov = true;
            break;
          }
        }
      }
      if (!is_oov) {
runtime/onnxruntime/src/offline-stream.cpp
@@ -63,10 +63,16 @@
    // Lm resource
    if (model_path.find(LM_DIR) != model_path.end() && model_path.at(LM_DIR) != "") {
        string fst_path, lm_config_path, hws_path;
        string fst_path, lm_config_path, lex_path;
        fst_path = PathAppend(model_path.at(LM_DIR), LM_FST_RES);
        lm_config_path = PathAppend(model_path.at(LM_DIR), LM_CONFIG_NAME);
        asr_handle->InitLm(fst_path, lm_config_path);
        lex_path = PathAppend(model_path.at(LM_DIR), LEX_PATH);
        if (access(lex_path.c_str(), F_OK) != 0 )
        {
            LOG(ERROR) << "Lexicon.txt file is not exist, please use the latest version. Skip load LM model.";
        }else{
            asr_handle->InitLm(fst_path, lm_config_path, lex_path);
        }
    }
    // PUNC model
runtime/onnxruntime/src/paraformer.cpp
@@ -187,13 +187,14 @@
}
void Paraformer::InitLm(const std::string &lm_file, 
                        const std::string &lm_cfg_file) {
                        const std::string &lm_cfg_file,
                        const std::string &lex_file) {
    try {
        lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
            fst::Fst<fst::StdArc>::Read(lm_file));
        if (lm_){
            if (vocab) { delete vocab; }
            vocab = new Vocab(lm_cfg_file.c_str());
            vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
            LOG(INFO) << "Successfully load lm file " << lm_file;
        }else{
            LOG(ERROR) << "Failed to load lm file " << lm_file;
runtime/onnxruntime/src/paraformer.h
@@ -60,7 +60,7 @@
        
        void StartUtterance();
        void EndUtterance();
        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file);
        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
        string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
        string FinalizeDecode(WfstDecoder* &wfst_decoder,
                          bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
runtime/onnxruntime/src/vocab.cpp
@@ -16,6 +16,12 @@
    ifstream in(filename);
    LoadVocabFromYaml(filename);
}
Vocab::Vocab(const char *filename, const char *lex_file)
{
    ifstream in(filename);
    LoadVocabFromYaml(filename);
    LoadLex(lex_file);
}
Vocab::~Vocab()
{
}
@@ -37,11 +43,37 @@
    }
}
int Vocab::GetIdByToken(const std::string &token) {
    if (token_id.count(token)) {
        return token_id[token];
void Vocab::LoadLex(const char* filename){
    std::ifstream file(filename);
    std::string line;
    while (std::getline(file, line)) {
        std::string key, value;
        std::istringstream iss(line);
        std::getline(iss, key, '\t');
        std::getline(iss, value);
        if (!key.empty() && !value.empty()) {
            lex_map[key] = value;
        }
    }
    return 0;
    file.close();
}
string Vocab::Word2Lex(const std::string &word) const {
    auto it = lex_map.find(word);
    if (it != lex_map.end()) {
        return it->second;
    }
    return "";
}
int Vocab::GetIdByToken(const std::string &token) const {
    auto it = token_id.find(token);
    if (it != token_id.end()) {
        return it->second;
    }
    return -1;
}
void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
runtime/onnxruntime/src/vocab.h
@@ -13,11 +13,14 @@
  private:
    vector<string> vocab;
    std::map<string, int> token_id;
    std::map<string, string> lex_map;
    bool IsEnglish(string ch);
    void LoadVocabFromYaml(const char* filename);
    void LoadLex(const char* filename);
  public:
    Vocab(const char *filename);
    Vocab(const char *filename, const char *lex_file);
    ~Vocab();
    int Size() const;
    bool IsChinese(string ch);
@@ -26,7 +29,8 @@
    string Vector2StringV2(vector<int> in, std::string language="");
    string Id2String(int id) const;
    string WordFormat(std::string word);
    int GetIdByToken(const std::string &token);
    int GetIdByToken(const std::string &token) const;
    string Word2Lex(const std::string &word) const;
};
} // namespace funasr
runtime/websocket/bin/funasr-wss-server.cpp
@@ -111,7 +111,7 @@
    TCLAP::ValueArg<std::string> lm_dir("", LM_DIR,
        "the LM model path, which contains compiled models: TLG.fst, config.yaml ", false, "damo/speech_ngram_lm_zh-cn-ai-wesp-fst", "string");
    TCLAP::ValueArg<std::string> lm_revision(
        "", "lm-revision", "LM model revision", false, "v1.0.1", "string");
        "", "lm-revision", "LM model revision", false, "v1.0.2", "string");
    TCLAP::ValueArg<std::string> hotword("", HOTWORD,
        "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", 
        false, "/workspace/resources/hotwords.txt", "string");