From 3c83d64c84602de055f503af7d4e2761c829ec2e Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期二, 12 十二月 2023 11:11:02 +0800
Subject: [PATCH] fst: support eng hotword

---
 runtime/onnxruntime/src/vocab.h             |    6 ++
 runtime/websocket/bin/funasr-wss-server.cpp |    2 
 runtime/onnxruntime/src/paraformer.h        |    2 
 runtime/onnxruntime/include/com-define.h    |    1 
 runtime/onnxruntime/src/vocab.cpp           |   40 ++++++++++++++++++--
 runtime/onnxruntime/include/model.h         |    2 
 runtime/onnxruntime/src/offline-stream.cpp  |   10 ++++-
 runtime/onnxruntime/src/bias-lm.h           |   30 ++++++++++-----
 runtime/onnxruntime/src/paraformer.cpp      |    5 +-
 9 files changed, 76 insertions(+), 22 deletions(-)

diff --git a/runtime/onnxruntime/include/com-define.h b/runtime/onnxruntime/include/com-define.h
index 57908e6..a2745da 100644
--- a/runtime/onnxruntime/include/com-define.h
+++ b/runtime/onnxruntime/include/com-define.h
@@ -68,6 +68,7 @@
 #define QUANT_DECODER_NAME "decoder_quant.onnx"
 
 #define LM_FST_RES "TLG.fst"
+#define LEX_PATH "lexicon.txt"
 
 // vad
 #ifndef VAD_SILENCE_DURATION
diff --git a/runtime/onnxruntime/include/model.h b/runtime/onnxruntime/include/model.h
index 356fca3..7b58e92 100644
--- a/runtime/onnxruntime/include/model.h
+++ b/runtime/onnxruntime/include/model.h
@@ -15,7 +15,7 @@
     virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
     virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
     virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
-    virtual void InitLm(const std::string &lm_file, const std::string &lm_config){};
+    virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
     virtual void InitFstDecoder(){};
     virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
     virtual std::string Rescoring() = 0;
diff --git a/runtime/onnxruntime/src/bias-lm.h b/runtime/onnxruntime/src/bias-lm.h
index e2d28a2..957197a 100644
--- a/runtime/onnxruntime/src/bias-lm.h
+++ b/runtime/onnxruntime/src/bias-lm.h
@@ -65,12 +65,17 @@
       if (text.size() > 1) {
         score = std::stof(text[1]);
       }
-      Utf8ToCharset(text[0], split_str);
+      SplitChiEngCharacters(text[0], split_str);
       for (auto &str : split_str) {
-        split_id.push_back(phn_set_.String2Id(str));
-        if (!phn_set_.Find(str)) {
-          is_oov = true;
-          break;
+        std::vector<string> lex_vec;
+        std::string lex_str = vocab_.Word2Lex(str);
+        SplitStringToVector(lex_str, " ", true, &lex_vec);
+        // A word missing from the lexicon has no phone sequence: treat as OOV.
+        if (lex_vec.empty()) { is_oov = true; break; }
+        for (auto &token : lex_vec) {
+          split_id.push_back(phn_set_.String2Id(token));
+          // Unknown phone token makes the whole hotword OOV.
+          if (!phn_set_.Find(token)) { is_oov = true; break; }
         }
       }
       if (!is_oov) {
@@ -103,12 +108,17 @@
       std::vector<std::string> split_str;
       std::vector<int> split_id;
       score = kv.second;
-      Utf8ToCharset(kv.first, split_str);
+      SplitChiEngCharacters(kv.first, split_str);
       for (auto &str : split_str) {
-        split_id.push_back(phn_set_.String2Id(str));
-        if (!phn_set_.Find(str)) {
-          is_oov = true;
-          break;
+        std::vector<string> lex_vec;
+        std::string lex_str = vocab_.Word2Lex(str);
+        SplitStringToVector(lex_str, " ", true, &lex_vec);
+        // A word missing from the lexicon has no phone sequence: treat as OOV.
+        if (lex_vec.empty()) { is_oov = true; break; }
+        for (auto &token : lex_vec) {
+          split_id.push_back(phn_set_.String2Id(token));
+          // Unknown phone token makes the whole hotword OOV.
+          if (!phn_set_.Find(token)) { is_oov = true; break; }
         }
       }
       if (!is_oov) {
diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp
index 2709ca6..ae8cf18 100644
--- a/runtime/onnxruntime/src/offline-stream.cpp
+++ b/runtime/onnxruntime/src/offline-stream.cpp
@@ -63,10 +63,16 @@
 
     // Lm resource
     if (model_path.find(LM_DIR) != model_path.end() && model_path.at(LM_DIR) != "") {
-        string fst_path, lm_config_path, hws_path;
+        string fst_path, lm_config_path, lex_path;
         fst_path = PathAppend(model_path.at(LM_DIR), LM_FST_RES);
         lm_config_path = PathAppend(model_path.at(LM_DIR), LM_CONFIG_NAME);
-        asr_handle->InitLm(fst_path, lm_config_path);
+        lex_path = PathAppend(model_path.at(LM_DIR), LEX_PATH);
+        if (access(lex_path.c_str(), F_OK) != 0)
+        {
+            LOG(ERROR) << "lexicon.txt file does not exist, please use the latest version. Skip loading LM model.";
+        }else{
+            asr_handle->InitLm(fst_path, lm_config_path, lex_path);
+        }
     }
 
     // PUNC model
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index b3dc619..3de3e39 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -187,13 +187,14 @@
 }
 
 void Paraformer::InitLm(const std::string &lm_file, 
-                        const std::string &lm_cfg_file) {
+                        const std::string &lm_cfg_file, 
+                        const std::string &lex_file) {
     try {
         lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
             fst::Fst<fst::StdArc>::Read(lm_file));
         if (lm_){
             if (vocab) { delete vocab; }
-            vocab = new Vocab(lm_cfg_file.c_str());
+            vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
             LOG(INFO) << "Successfully load lm file " << lm_file;
         }else{
             LOG(ERROR) << "Failed to load lm file " << lm_file;
diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h
index b5bc46d..89c8b09 100644
--- a/runtime/onnxruntime/src/paraformer.h
+++ b/runtime/onnxruntime/src/paraformer.h
@@ -60,7 +60,7 @@
 		
         void StartUtterance();
         void EndUtterance();
-        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file);
+        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
         string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
         string FinalizeDecode(WfstDecoder* &wfst_decoder,
                           bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
diff --git a/runtime/onnxruntime/src/vocab.cpp b/runtime/onnxruntime/src/vocab.cpp
index 20571c9..6991376 100644
--- a/runtime/onnxruntime/src/vocab.cpp
+++ b/runtime/onnxruntime/src/vocab.cpp
@@ -16,6 +16,12 @@
     ifstream in(filename);
     LoadVocabFromYaml(filename);
 }
+Vocab::Vocab(const char *filename, const char *lex_file)
+{
+    ifstream in(filename);
+    LoadVocabFromYaml(filename);
+    LoadLex(lex_file);
+}
 Vocab::~Vocab()
 {
 }
@@ -37,11 +43,37 @@
     }
 }
 
-int Vocab::GetIdByToken(const std::string &token) {
-    if (token_id.count(token)) {
-        return token_id[token];
+void Vocab::LoadLex(const char* filename){
+    std::ifstream file(filename);
+    std::string line;
+    while (std::getline(file, line)) {
+        std::string key, value;
+        std::istringstream iss(line);
+        std::getline(iss, key, '\t');
+        std::getline(iss, value);
+
+        if (!key.empty() && !value.empty()) {
+            lex_map[key] = value;
+        }
     }
-    return 0;
+
+    file.close();
+}
+
+string Vocab::Word2Lex(const std::string &word) const {
+    auto it = lex_map.find(word);
+    if (it != lex_map.end()) {
+        return it->second;
+    }
+    return "";
+}
+
+int Vocab::GetIdByToken(const std::string &token) const {
+    auto it = token_id.find(token);
+    if (it != token_id.end()) {
+        return it->second;
+    }
+    return -1;
 }
 
 void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
diff --git a/runtime/onnxruntime/src/vocab.h b/runtime/onnxruntime/src/vocab.h
index 8834b97..19e3648 100644
--- a/runtime/onnxruntime/src/vocab.h
+++ b/runtime/onnxruntime/src/vocab.h
@@ -13,11 +13,14 @@
   private:
     vector<string> vocab;
     std::map<string, int> token_id;
+    std::map<string, string> lex_map;
     bool IsEnglish(string ch);
     void LoadVocabFromYaml(const char* filename);
+    void LoadLex(const char* filename);
 
   public:
     Vocab(const char *filename);
+    Vocab(const char *filename, const char *lex_file);
     ~Vocab();
     int Size() const;
     bool IsChinese(string ch);
@@ -26,7 +29,8 @@
     string Vector2StringV2(vector<int> in, std::string language="");
     string Id2String(int id) const;
     string WordFormat(std::string word);
-    int GetIdByToken(const std::string &token);
+    int GetIdByToken(const std::string &token) const;
+    string Word2Lex(const std::string &word) const;
 };
 
 } // namespace funasr
diff --git a/runtime/websocket/bin/funasr-wss-server.cpp b/runtime/websocket/bin/funasr-wss-server.cpp
index f54bc5b..67e3309 100644
--- a/runtime/websocket/bin/funasr-wss-server.cpp
+++ b/runtime/websocket/bin/funasr-wss-server.cpp
@@ -111,7 +111,7 @@
     TCLAP::ValueArg<std::string> lm_dir("", LM_DIR,
         "the LM model path, which contains compiled models: TLG.fst, config.yaml ", false, "damo/speech_ngram_lm_zh-cn-ai-wesp-fst", "string");
     TCLAP::ValueArg<std::string> lm_revision(
-        "", "lm-revision", "LM model revision", false, "v1.0.1", "string");
+        "", "lm-revision", "LM model revision", false, "v1.0.2", "string");
     TCLAP::ValueArg<std::string> hotword("", HOTWORD,
         "the hotword file, one hotword perline, Format: Hotword Weight (could be: 闃块噷宸村反 20)", 
         false, "/workspace/resources/hotwords.txt", "string");

--
Gitblit v1.9.1