From 00cfc36b9a1ad4d114434eb7770c1e67940d4862 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期一, 13 五月 2024 16:52:41 +0800
Subject: [PATCH] c++ runtime adapt to 1.0 (#1724)

---
 runtime/websocket/bin/funasr-wss-server.cpp       |   12 +-
 runtime/onnxruntime/src/paraformer.h              |    6 
 runtime/onnxruntime/include/com-define.h          |    7 +
 runtime/onnxruntime/src/model.cpp                 |    8 +
 runtime/onnxruntime/src/vocab.cpp                 |   21 +++++
 runtime/onnxruntime/include/model.h               |    6 
 runtime/onnxruntime/src/ct-transformer.h          |    2 
 runtime/onnxruntime/src/offline-stream.cpp        |   11 ++
 runtime/onnxruntime/src/tokenizer.cpp             |   55 +++++++++++++
 runtime/onnxruntime/src/ct-transformer-online.h   |    2 
 runtime/onnxruntime/src/paraformer.cpp            |   16 ++--
 runtime/onnxruntime/src/vocab.h                   |    2 
 runtime/onnxruntime/CMakeLists.txt                |   12 +++
 runtime/onnxruntime/src/tokenizer.h               |    2 
 runtime/onnxruntime/src/phone-set.h               |    2 
 runtime/onnxruntime/include/punc-model.h          |    2 
 runtime/onnxruntime/src/tpass-stream.cpp          |   11 ++
 runtime/onnxruntime/src/ct-transformer.cpp        |    4 
 runtime/onnxruntime/src/phone-set.cpp             |   21 +++++
 runtime/onnxruntime/src/ct-transformer-online.cpp |    4 
 runtime/onnxruntime/src/fsmn-vad.cpp              |    2 
 runtime/onnxruntime/src/punc-model.cpp            |    4 
 runtime/websocket/bin/funasr-wss-server-2pass.cpp |   14 +-
 23 files changed, 177 insertions(+), 49 deletions(-)

diff --git a/runtime/onnxruntime/CMakeLists.txt b/runtime/onnxruntime/CMakeLists.txt
index 3450be7..d8e623e 100644
--- a/runtime/onnxruntime/CMakeLists.txt
+++ b/runtime/onnxruntime/CMakeLists.txt
@@ -18,6 +18,17 @@
     message("Little endian system")
 endif()
 
+# json
+include(FetchContent)
+if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/json/ChangeLog.md )
+FetchContent_Declare(json
+  URL   https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz
+SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/json
+)
+
+FetchContent_MakeAvailable(json)
+endif()
+
 # for onnxruntime
 IF(WIN32)
     file(REMOVE ${PROJECT_SOURCE_DIR}/third_party/glog/src/config.h 
@@ -36,6 +47,7 @@
 include_directories(${PROJECT_SOURCE_DIR}/third_party/jieba/include)
 include_directories(${PROJECT_SOURCE_DIR}/third_party/jieba/include/limonp/include)
 include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi)
+include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
 
 if(ENABLE_GLOG)
     include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/src)
diff --git a/runtime/onnxruntime/include/com-define.h b/runtime/onnxruntime/include/com-define.h
index 9cb1f2c..d4edd5b 100644
--- a/runtime/onnxruntime/include/com-define.h
+++ b/runtime/onnxruntime/include/com-define.h
@@ -49,13 +49,14 @@
 // hotword embedding compile model
 #define MODEL_EB_NAME "model_eb.onnx"
 #define QUANT_MODEL_NAME "model_quant.onnx"
-#define VAD_CMVN_NAME "vad.mvn"
-#define VAD_CONFIG_NAME "vad.yaml"
+#define VAD_CMVN_NAME "am.mvn"
+#define VAD_CONFIG_NAME "config.yaml"
 #define AM_CMVN_NAME "am.mvn"
 #define AM_CONFIG_NAME "config.yaml"
 #define LM_CONFIG_NAME "config.yaml"
-#define PUNC_CONFIG_NAME "punc.yaml"
+#define PUNC_CONFIG_NAME "config.yaml"
 #define MODEL_SEG_DICT "seg_dict"
+#define TOKEN_PATH "tokens.json"
 #define HOTWORD "hotword"
 // #define NN_HOTWORD "nn-hotword"
 
diff --git a/runtime/onnxruntime/include/model.h b/runtime/onnxruntime/include/model.h
index 33caec8..f5c4027 100644
--- a/runtime/onnxruntime/include/model.h
+++ b/runtime/onnxruntime/include/model.h
@@ -12,9 +12,9 @@
     virtual void StartUtterance() = 0;
     virtual void EndUtterance() = 0;
     virtual void Reset() = 0;
-    virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
-    virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
-    virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
+    virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
+    virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
+    virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
     virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
     virtual void InitFstDecoder(){};
     virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
diff --git a/runtime/onnxruntime/include/punc-model.h b/runtime/onnxruntime/include/punc-model.h
index 214c770..3cec2c1 100644
--- a/runtime/onnxruntime/include/punc-model.h
+++ b/runtime/onnxruntime/include/punc-model.h
@@ -11,7 +11,7 @@
 class PuncModel {
   public:
     virtual ~PuncModel(){};
-	  virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0;
+	  virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num)=0;
 	  virtual std::string AddPunc(const char* sz_input, std::string language="zh-cn"){return "";};
 	  virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache, std::string language="zh-cn"){return "";};
 };
diff --git a/runtime/onnxruntime/src/ct-transformer-online.cpp b/runtime/onnxruntime/src/ct-transformer-online.cpp
index 4e9136e..92fe41e 100644
--- a/runtime/onnxruntime/src/ct-transformer-online.cpp
+++ b/runtime/onnxruntime/src/ct-transformer-online.cpp
@@ -11,7 +11,7 @@
 {
 }
 
-void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num){
+void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num){
     session_options.SetIntraOpNumThreads(thread_num);
     session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
     session_options.DisableCpuMemArena();
@@ -43,7 +43,7 @@
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
 
-	m_tokenizer.OpenYaml(punc_config.c_str());
+	m_tokenizer.OpenYaml(punc_config.c_str(), token_file.c_str());
 }
 
 CTTransformerOnline::~CTTransformerOnline()
diff --git a/runtime/onnxruntime/src/ct-transformer-online.h b/runtime/onnxruntime/src/ct-transformer-online.h
index ea7edb7..13f40a0 100644
--- a/runtime/onnxruntime/src/ct-transformer-online.h
+++ b/runtime/onnxruntime/src/ct-transformer-online.h
@@ -26,7 +26,7 @@
 public:
 
 	CTTransformerOnline();
-	void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
+	void InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num);
 	~CTTransformerOnline();
 	vector<int>  Infer(vector<int32_t> input_data, int nCacheSize);
 	string AddPunc(const char* sz_input, vector<string> &arr_cache, std::string language="zh-cn");
diff --git a/runtime/onnxruntime/src/ct-transformer.cpp b/runtime/onnxruntime/src/ct-transformer.cpp
index 8f8d953..d1a7813 100644
--- a/runtime/onnxruntime/src/ct-transformer.cpp
+++ b/runtime/onnxruntime/src/ct-transformer.cpp
@@ -11,7 +11,7 @@
 {
 }
 
-void CTTransformer::InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num){
+void CTTransformer::InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num){
     session_options.SetIntraOpNumThreads(thread_num);
     session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
     session_options.DisableCpuMemArena();
@@ -39,7 +39,7 @@
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
 
-	m_tokenizer.OpenYaml(punc_config.c_str());
+	m_tokenizer.OpenYaml(punc_config.c_str(), token_file.c_str());
     m_tokenizer.JiebaInit(punc_config);
 }
 
diff --git a/runtime/onnxruntime/src/ct-transformer.h b/runtime/onnxruntime/src/ct-transformer.h
index b33dcf5..f38fe12 100644
--- a/runtime/onnxruntime/src/ct-transformer.h
+++ b/runtime/onnxruntime/src/ct-transformer.h
@@ -26,7 +26,7 @@
 public:
 
 	CTTransformer();
-	void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
+	void InitPunc(const std::string &punc_model, const std::string &punc_config, const std::string &token_file, int thread_num);
 	~CTTransformer();
 	vector<int>  Infer(vector<int32_t> input_data);
 	string AddPunc(const char* sz_input, std::string language="zh-cn");
diff --git a/runtime/onnxruntime/src/fsmn-vad.cpp b/runtime/onnxruntime/src/fsmn-vad.cpp
index c832274..42ce83b 100644
--- a/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -30,7 +30,7 @@
 
     try{
         YAML::Node frontend_conf = config["frontend_conf"];
-        YAML::Node post_conf = config["vad_post_conf"];
+        YAML::Node post_conf = config["model_conf"];
 
         this->vad_sample_rate_ = frontend_conf["fs"].as<int>();
         this->vad_silence_duration_ =  post_conf["max_end_silence_time"].as<int>();
diff --git a/runtime/onnxruntime/src/model.cpp b/runtime/onnxruntime/src/model.cpp
index 646f260..8b5e33f 100644
--- a/runtime/onnxruntime/src/model.cpp
+++ b/runtime/onnxruntime/src/model.cpp
@@ -8,6 +8,7 @@
         string am_model_path;
         string am_cmvn_path;
         string am_config_path;
+        string token_path;
 
         am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
         if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
@@ -15,10 +16,11 @@
         }
         am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
         am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+        token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
 
         Model *mm;
         mm = new Paraformer();
-        mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
+        mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
         return mm;
     }else if(type == ASR_ONLINE){
         // online
@@ -26,6 +28,7 @@
         string de_model_path;
         string am_cmvn_path;
         string am_config_path;
+        string token_path;
 
         en_model_path = PathAppend(model_path.at(MODEL_DIR), ENCODER_NAME);
         de_model_path = PathAppend(model_path.at(MODEL_DIR), DECODER_NAME);
@@ -35,10 +38,11 @@
         }
         am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
         am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+        token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
 
         Model *mm;
         mm = new Paraformer();
-        mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
+        mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
         return mm;
     }else{
         LOG(ERROR)<<"Wrong ASR_TYPE : " << type;
diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp
index ae8cf18..7d86f9b 100644
--- a/runtime/onnxruntime/src/offline-stream.cpp
+++ b/runtime/onnxruntime/src/offline-stream.cpp
@@ -32,6 +32,7 @@
         string am_model_path;
         string am_cmvn_path;
         string am_config_path;
+        string token_path;
         string hw_compile_model_path;
         string seg_dict_path;
     
@@ -57,8 +58,9 @@
         }
         am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
         am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+        token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
 
-        asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
+        asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
     }
 
     // Lm resource
@@ -79,20 +81,23 @@
     if(model_path.find(PUNC_DIR) != model_path.end()){
         string punc_model_path;
         string punc_config_path;
+        string token_path;
     
         punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
         if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
             punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
         }
         punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
+        token_path = PathAppend(model_path.at(PUNC_DIR), TOKEN_PATH);
 
         if (access(punc_model_path.c_str(), F_OK) != 0 ||
-            access(punc_config_path.c_str(), F_OK) != 0 )
+            access(punc_config_path.c_str(), F_OK) != 0 ||
+            access(token_path.c_str(), F_OK) != 0)
         {
             LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
         }else{
             punc_handle = make_unique<CTTransformer>();
-            punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+            punc_handle->InitPunc(punc_model_path, punc_config_path, token_path, thread_num);
             use_punc = true;
         }
     }
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index c56421c..a57fb9b 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -18,7 +18,7 @@
 }
 
 // offline
-void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
     LoadConfigFromYaml(am_config.c_str());
     // knf options
     fbank_opts_.frame_opts.dither = 0;
@@ -65,13 +65,13 @@
         m_szInputNames.push_back(item.c_str());
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
-    vocab = new Vocab(am_config.c_str());
-	phone_set_ = new PhoneSet(am_config.c_str());
+    vocab = new Vocab(token_file.c_str());
+	phone_set_ = new PhoneSet(token_file.c_str());
     LoadCmvn(am_cmvn.c_str());
 }
 
 // online
-void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
     
     LoadOnlineConfigFromYaml(am_config.c_str());
     // knf options
@@ -143,15 +143,15 @@
     for (auto& item : de_strOutputNames)
         de_szOutputNames_.push_back(item.c_str());
 
-    vocab = new Vocab(am_config.c_str());
-    phone_set_ = new PhoneSet(am_config.c_str());
+    vocab = new Vocab(token_file.c_str());
+    phone_set_ = new PhoneSet(token_file.c_str());
     LoadCmvn(am_cmvn.c_str());
 }
 
 // 2pass
-void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
     // online
-    InitAsr(en_model, de_model, am_cmvn, am_config, thread_num);
+    InitAsr(en_model, de_model, am_cmvn, am_config, token_file, thread_num);
 
     // offline
     try {
diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h
index 5bb9477..417c2d7 100644
--- a/runtime/onnxruntime/src/paraformer.h
+++ b/runtime/onnxruntime/src/paraformer.h
@@ -42,11 +42,11 @@
     public:
         Paraformer();
         ~Paraformer();
-        void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+        void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
         // online
-        void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+        void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
         // 2pass
-        void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+        void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
         void InitHwCompiler(const std::string &hw_model, int thread_num);
         void InitSegDict(const std::string &seg_dict_model);
         std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
diff --git a/runtime/onnxruntime/src/phone-set.cpp b/runtime/onnxruntime/src/phone-set.cpp
index 167fa01..60eb101 100644
--- a/runtime/onnxruntime/src/phone-set.cpp
+++ b/runtime/onnxruntime/src/phone-set.cpp
@@ -13,7 +13,7 @@
 namespace funasr {
 PhoneSet::PhoneSet(const char *filename) {
   ifstream in(filename);
-  LoadPhoneSetFromYaml(filename);
+  LoadPhoneSetFromJson(filename);
 }
 PhoneSet::~PhoneSet()
 {
@@ -35,6 +35,25 @@
   }
 }
 
+void PhoneSet::LoadPhoneSetFromJson(const char* filename) {
+    nlohmann::json json_array;
+    std::ifstream file(filename);
+    if (file.is_open()) {
+        file >> json_array;
+        file.close();
+    } else {
+        LOG(INFO) << "Error loading token file, token file error or not exist.";
+        exit(-1);
+    }
+
+    int id = 0;
+    for (const auto& element : json_array) {
+        phone_.push_back(element);
+        phn2Id_.emplace(element, id);
+        id++;
+    }
+}
+
 int PhoneSet::Size() const {
   return phone_.size();
 }
diff --git a/runtime/onnxruntime/src/phone-set.h b/runtime/onnxruntime/src/phone-set.h
index 9729104..cb1a2a7 100644
--- a/runtime/onnxruntime/src/phone-set.h
+++ b/runtime/onnxruntime/src/phone-set.h
@@ -5,6 +5,7 @@
 #include <string>
 #include <vector>
 #include <unordered_map>
+#include "nlohmann/json.hpp"
 #define UNIT_BEG_SIL_SYMBOL "<s>"
 #define UNIT_END_SIL_SYMBOL "</s>"
 #define UNIT_BLK_SYMBOL "<blank>"
@@ -28,6 +29,7 @@
     vector<string> phone_;
     unordered_map<string, int> phn2Id_;
     void LoadPhoneSetFromYaml(const char* filename);
+    void LoadPhoneSetFromJson(const char* filename);
 };
 
 } // namespace funasr
diff --git a/runtime/onnxruntime/src/punc-model.cpp b/runtime/onnxruntime/src/punc-model.cpp
index 54b8d6a..9af03db 100644
--- a/runtime/onnxruntime/src/punc-model.cpp
+++ b/runtime/onnxruntime/src/punc-model.cpp
@@ -14,14 +14,16 @@
     }
     string punc_model_path;
     string punc_config_path;
+    string token_file;
 
     punc_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
     if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
         punc_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
     }
     punc_config_path = PathAppend(model_path.at(MODEL_DIR), PUNC_CONFIG_NAME);
+    token_file = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
 
-    mm->InitPunc(punc_model_path, punc_config_path, thread_num);
+    mm->InitPunc(punc_model_path, punc_config_path, token_file, thread_num);
     return mm;
 }
 
diff --git a/runtime/onnxruntime/src/tokenizer.cpp b/runtime/onnxruntime/src/tokenizer.cpp
index 7618282..06d64d8 100644
--- a/runtime/onnxruntime/src/tokenizer.cpp
+++ b/runtime/onnxruntime/src/tokenizer.cpp
@@ -127,6 +127,61 @@
 	return m_ready;
 }
 
+bool CTokenizer::OpenYaml(const char* sz_yamlfile, const char* token_file)
+{
+	YAML::Node m_Config;
+	try{
+		m_Config = YAML::LoadFile(sz_yamlfile);
+	}catch(exception const &e){
+        LOG(INFO) << "Error loading file, yaml file error or not exist.";
+        exit(-1);
+    }
+
+	try
+	{
+		YAML::Node conf_seg_jieba = m_Config["seg_jieba"];
+        if (conf_seg_jieba.IsDefined()){
+            seg_jieba = conf_seg_jieba.as<bool>();
+        }
+
+		auto Puncs = m_Config["model_conf"]["punc_list"];
+		if (Puncs.IsSequence())
+		{
+			for (size_t i = 0; i < Puncs.size(); ++i)
+			{
+				if (Puncs[i].IsScalar())
+				{ 
+					m_id2punc.push_back(Puncs[i].as<string>());
+					m_punc2id.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
+				}
+			}
+		}
+
+		nlohmann::json json_array;
+		std::ifstream file(token_file);
+		if (file.is_open()) {
+			file >> json_array;
+			file.close();
+		} else {
+			LOG(INFO) << "Error loading token file, token file error or not exist.";
+			return  false;
+		}
+
+		int i = 0;
+		for (const auto& element : json_array) {
+			m_id2token.push_back(element);
+			m_token2id[element] = i;
+			i++;
+		}
+	}
+	catch (YAML::BadFile& e) {
+		LOG(ERROR) << "Read error!";
+		return  false;
+	}
+	m_ready = true;
+	return m_ready;
+}
+
 vector<string> CTokenizer::Id2String(vector<int> input)
 {
 	vector<string> result;
diff --git a/runtime/onnxruntime/src/tokenizer.h b/runtime/onnxruntime/src/tokenizer.h
index 166061b..81aea7e 100644
--- a/runtime/onnxruntime/src/tokenizer.h
+++ b/runtime/onnxruntime/src/tokenizer.h
@@ -8,6 +8,7 @@
 #include "cppjieba/DictTrie.hpp"
 #include "cppjieba/HMMModel.hpp"
 #include "cppjieba/Jieba.hpp"
+#include "nlohmann/json.hpp"
 
 namespace funasr {
 class CTokenizer {
@@ -27,6 +28,7 @@
 	CTokenizer();
 	~CTokenizer();
 	bool OpenYaml(const char* sz_yamlfile);
+	bool OpenYaml(const char* sz_yamlfile, const char* token_file);
 	void ReadYaml(const YAML::Node& node);
 	vector<string> Id2String(vector<int> input);
 	vector<int> String2Ids(vector<string> input);
diff --git a/runtime/onnxruntime/src/tpass-stream.cpp b/runtime/onnxruntime/src/tpass-stream.cpp
index b723e0f..7681a4d 100644
--- a/runtime/onnxruntime/src/tpass-stream.cpp
+++ b/runtime/onnxruntime/src/tpass-stream.cpp
@@ -35,6 +35,7 @@
         string de_model_path;
         string am_cmvn_path;
         string am_config_path;
+        string token_path;
         string hw_compile_model_path;
         string seg_dict_path;
         
@@ -60,8 +61,9 @@
         }
         am_cmvn_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CMVN_NAME);
         am_config_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CONFIG_NAME);
+        token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
 
-        asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
+        asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
     }else{
         LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
         exit(-1);
@@ -85,20 +87,23 @@
     if(model_path.find(PUNC_DIR) != model_path.end()){
         string punc_model_path;
         string punc_config_path;
+        string token_path;
     
         punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
         if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
             punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
         }
         punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
+        token_path = PathAppend(model_path.at(PUNC_DIR), TOKEN_PATH);
 
         if (access(punc_model_path.c_str(), F_OK) != 0 ||
-            access(punc_config_path.c_str(), F_OK) != 0 )
+            access(punc_config_path.c_str(), F_OK) != 0 ||
+            access(token_path.c_str(), F_OK) != 0)
         {
             LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
         }else{
             punc_online_handle = make_unique<CTTransformerOnline>();
-            punc_online_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+            punc_online_handle->InitPunc(punc_model_path, punc_config_path, token_path, thread_num);
             use_punc = true;
         }
     }
diff --git a/runtime/onnxruntime/src/vocab.cpp b/runtime/onnxruntime/src/vocab.cpp
index 6991376..1416dd3 100644
--- a/runtime/onnxruntime/src/vocab.cpp
+++ b/runtime/onnxruntime/src/vocab.cpp
@@ -14,7 +14,7 @@
 Vocab::Vocab(const char *filename)
 {
     ifstream in(filename);
-    LoadVocabFromYaml(filename);
+    LoadVocabFromJson(filename);
 }
 Vocab::Vocab(const char *filename, const char *lex_file)
 {
@@ -43,6 +43,25 @@
     }
 }
 
+void Vocab::LoadVocabFromJson(const char* filename){
+    nlohmann::json json_array;
+    std::ifstream file(filename);
+    if (file.is_open()) {
+        file >> json_array;
+        file.close();
+    } else {
+        LOG(INFO) << "Error loading token file, token file error or not exist.";
+        exit(-1);
+    }
+
+    int i = 0;
+    for (const auto& element : json_array) {
+        vocab.push_back(element);
+        token_id[element] = i;
+        i++;
+    }
+}
+
 void Vocab::LoadLex(const char* filename){
     std::ifstream file(filename);
     std::string line;
diff --git a/runtime/onnxruntime/src/vocab.h b/runtime/onnxruntime/src/vocab.h
index 19e3648..36fabf4 100644
--- a/runtime/onnxruntime/src/vocab.h
+++ b/runtime/onnxruntime/src/vocab.h
@@ -6,6 +6,7 @@
 #include <string>
 #include <vector>
 #include <map>
+#include "nlohmann/json.hpp"
 using namespace std;
 
 namespace funasr {
@@ -16,6 +17,7 @@
     std::map<string, string> lex_map;
     bool IsEnglish(string ch);
     void LoadVocabFromYaml(const char* filename);
+    void LoadVocabFromJson(const char* filename);
     void LoadLex(const char* filename);
 
   public:
diff --git a/runtime/websocket/bin/funasr-wss-server-2pass.cpp b/runtime/websocket/bin/funasr-wss-server-2pass.cpp
index 4bc413c..d42679b 100644
--- a/runtime/websocket/bin/funasr-wss-server-2pass.cpp
+++ b/runtime/websocket/bin/funasr-wss-server-2pass.cpp
@@ -55,11 +55,11 @@
 
     TCLAP::ValueArg<std::string> offline_model_revision(
         "", "offline-model-revision", "ASR offline model revision", false,
-        "v1.2.1", "string");
+        "v2.0.4", "string");
 
     TCLAP::ValueArg<std::string> online_model_revision(
         "", "online-model-revision", "ASR online model revision", false,
-        "v1.0.6", "string");
+        "v2.0.4", "string");
 
     TCLAP::ValueArg<std::string> quantize(
         "", QUANTIZE,
@@ -73,7 +73,7 @@
         "model_quant.onnx, vad.yaml, vad.mvn",
         false, "damo/speech_fsmn_vad_zh-cn-16k-common-onnx", "string");
     TCLAP::ValueArg<std::string> vad_revision(
-        "", "vad-revision", "VAD model revision", false, "v1.2.0", "string");
+        "", "vad-revision", "VAD model revision", false, "v2.0.4", "string");
     TCLAP::ValueArg<std::string> vad_quant(
         "", VAD_QUANT,
         "true (Default), load the model of model_quant.onnx in vad_dir. If set "
@@ -85,7 +85,7 @@
         "model_quant.onnx, punc.yaml",
         false, "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx", "string");
     TCLAP::ValueArg<std::string> punc_revision(
-        "", "punc-revision", "PUNC model revision", false, "v1.0.2", "string");
+        "", "punc-revision", "PUNC model revision", false, "v2.0.4", "string");
     TCLAP::ValueArg<std::string> punc_quant(
         "", PUNC_QUANT,
         "true (Default), load the model of model_quant.onnx in punc_dir. If "
@@ -262,17 +262,17 @@
 
         size_t found = s_offline_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
         if (found != std::string::npos) {
-            model_path["offline-model-revision"]="v1.2.4";
+            model_path["offline-model-revision"]="v2.0.4";
         }
 
         found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
         if (found != std::string::npos) {
-            model_path["offline-model-revision"]="v1.0.5";
+            model_path["offline-model-revision"]="v2.0.5";
         }
 
         found = s_offline_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
         if (found != std::string::npos) {
-            model_path["model-revision"]="v1.0.0";
+            model_path["model-revision"]="v2.0.4";
             s_itn_path="";
             s_lm_path="";
         }
diff --git a/runtime/websocket/bin/funasr-wss-server.cpp b/runtime/websocket/bin/funasr-wss-server.cpp
index bff4f66..5bb7def 100644
--- a/runtime/websocket/bin/funasr-wss-server.cpp
+++ b/runtime/websocket/bin/funasr-wss-server.cpp
@@ -50,7 +50,7 @@
     TCLAP::ValueArg<std::string> model_revision(
         "", "model-revision",
         "ASR model revision",
-        false, "v1.2.1", "string");
+        false, "v2.0.4", "string");
     TCLAP::ValueArg<std::string> quantize(
         "", QUANTIZE,
         "true (Default), load the model of model_quant.onnx in model_dir. If set "
@@ -63,7 +63,7 @@
     TCLAP::ValueArg<std::string> vad_revision(
         "", "vad-revision",
         "VAD model revision",
-        false, "v1.2.0", "string");
+        false, "v2.0.4", "string");
     TCLAP::ValueArg<std::string> vad_quant(
         "", VAD_QUANT,
         "true (Default), load the model of model_quant.onnx in vad_dir. If set "
@@ -77,7 +77,7 @@
     TCLAP::ValueArg<std::string> punc_revision(
         "", "punc-revision",
         "PUNC model revision",
-        false, "v1.1.7", "string");
+        false, "v2.0.4", "string");
     TCLAP::ValueArg<std::string> punc_quant(
         "", PUNC_QUANT,
         "true (Default), load the model of model_quant.onnx in punc_dir. If set "
@@ -233,17 +233,17 @@
             // modify model-revision by model name
             size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
             if (found != std::string::npos) {
-                model_path["model-revision"]="v1.2.4";
+                model_path["model-revision"]="v2.0.4";
             }
 
             found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
             if (found != std::string::npos) {
-                model_path["model-revision"]="v1.0.5";
+                model_path["model-revision"]="v2.0.5";
             }
 
             found = s_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
             if (found != std::string::npos) {
-                model_path["model-revision"]="v1.0.0";
+                model_path["model-revision"]="v2.0.4";
                 s_itn_path="";
                 s_lm_path="";
             }

--
Gitblit v1.9.1