From 9cc37eaa8af50db2ffad3fc02746547ef995a870 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 28 九月 2023 15:47:43 +0800
Subject: [PATCH] add postprocess for en-bpe

---
 funasr/runtime/onnxruntime/src/paraformer.cpp |   35 +++++++++++++++++++++++------------
 funasr/runtime/onnxruntime/src/vocab.h        |    2 +-
 funasr/runtime/onnxruntime/src/paraformer.h   |    2 ++
 funasr/runtime/onnxruntime/src/vocab.cpp      |   20 ++++++++++++++++++--
 4 files changed, 44 insertions(+), 15 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index dfa2b1f..763d01e 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -65,6 +65,7 @@
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
     vocab = new Vocab(am_config.c_str());
+    LoadConfigFromYaml(am_config.c_str());
     LoadCmvn(am_cmvn.c_str());
 }
 
@@ -181,6 +182,27 @@
         m_szInputNames.push_back(item.c_str());
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
+}
+
+void Paraformer::LoadConfigFromYaml(const char* filename){
+
+    YAML::Node config;
+    try{
+        config = YAML::LoadFile(filename);
+    }catch(exception const &e){
+        LOG(ERROR) << "Error loading file, yaml file error or not exist.";
+        exit(-1);
+    }
+
+    try{
+        YAML::Node lang_conf = config["lang"];
+        if (lang_conf.IsDefined()){
+            language = lang_conf.as<string>();
+        }
+    }catch(exception const &e){
+        LOG(ERROR) << "Error when load argument from vad config YAML.";
+        exit(-1);
+    }
 }
 
 void Paraformer::LoadOnlineConfigFromYaml(const char* filename){
@@ -342,7 +364,7 @@
         hyps.push_back(max_idx);
     }
     if(!is_stamp){
-        return vocab->Vector2StringV2(hyps);
+        return vocab->Vector2StringV2(hyps, language);
     }else{
         std::vector<string> char_list;
         std::vector<std::vector<float>> timestamp_list;
@@ -707,17 +729,6 @@
         }else{
             result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
         }
-//         int pos = 0;
-//         std::vector<std::vector<float>> logits;
-//         for (int j = 0; j < outputShape[1]; j++)
-//         {
-//             std::vector<float> vec_token;
-//             vec_token.insert(vec_token.begin(), floatData + pos, floatData + pos + outputShape[2]);
-//             logits.push_back(vec_token);
-//             pos += outputShape[2];
-//         }
-//         //PrintMat(logits, "logits_out");
-//         result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
     }
     catch (std::exception const &e)
     {
diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h
index 4080881..bac8fad 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.h
+++ b/funasr/runtime/onnxruntime/src/paraformer.h
@@ -20,6 +20,7 @@
         //const float scale = 22.6274169979695;
         const float scale = 1.0;
 
+        void LoadConfigFromYaml(const char* filename);
         void LoadOnlineConfigFromYaml(const char* filename);
         void LoadCmvn(const char *filename);
         vector<float> ApplyLfr(const vector<float> &in);
@@ -32,6 +33,7 @@
         vector<const char*> hw_m_szInputNames;
         vector<const char*> hw_m_szOutputNames;
         bool use_hotword;
+        std::string language="zh-cn";
 
     public:
         Paraformer();
diff --git a/funasr/runtime/onnxruntime/src/vocab.cpp b/funasr/runtime/onnxruntime/src/vocab.cpp
index c29156f..95174c7 100644
--- a/funasr/runtime/onnxruntime/src/vocab.cpp
+++ b/funasr/runtime/onnxruntime/src/vocab.cpp
@@ -75,20 +75,36 @@
     return false;
 }
 
-string Vocab::Vector2StringV2(vector<int> in)
+string Vocab::Vector2StringV2(vector<int> in, std::string language)
 {
     int i;
     list<string> words;
     int is_pre_english = false;
     int pre_english_len = 0;
     int is_combining = false;
-    string combine = "";
+    std::string combine = "";
+    std::string unicodeChar = "鈻�";
 
     for (auto it = in.begin(); it != in.end(); it++) {
         string word = vocab[*it];
         // step1 space character skips
         if (word == "<s>" || word == "</s>" || word == "<unk>")
             continue;
+        if (language == "en-bpe"){
+            size_t found = word.find(unicodeChar);
+            if(found != std::string::npos){
+                if (combine != ""){
+                    if (words.size() != 0){
+                        combine = " " + combine;
+                    }
+                    words.push_back(combine);
+                }
+                combine = word.substr(3);
+            }else{
+                combine += word;
+            }
+            continue;
+        }
         // step2 combie phoneme to full word
         {
             int sub_word = !(word.find("@@") == string::npos);
diff --git a/funasr/runtime/onnxruntime/src/vocab.h b/funasr/runtime/onnxruntime/src/vocab.h
index 808852a..eecb9c8 100644
--- a/funasr/runtime/onnxruntime/src/vocab.h
+++ b/funasr/runtime/onnxruntime/src/vocab.h
@@ -22,7 +22,7 @@
     int Size();
     bool IsChinese(string ch);
     void Vector2String(vector<int> in, std::vector<std::string> &preds);
-    string Vector2StringV2(vector<int> in);
+    string Vector2StringV2(vector<int> in, std::string language="");
     int GetIdByToken(const std::string &token);
 };
 

--
Gitblit v1.9.1