From 1d7bbbffb6a024a33859b48a7a656d0455dc0be1 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 16 Oct 2023 11:47:59 +0800
Subject: [PATCH] Update README.md

---
 funasr/runtime/onnxruntime/src/vocab.cpp |   56 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/vocab.cpp b/funasr/runtime/onnxruntime/src/vocab.cpp
index dc40978..2babc40 100644
--- a/funasr/runtime/onnxruntime/src/vocab.cpp
+++ b/funasr/runtime/onnxruntime/src/vocab.cpp
@@ -29,9 +29,19 @@
         exit(-1);
     }
     YAML::Node myList = config["token_list"];
+    int i = 0;
     for (YAML::const_iterator it = myList.begin(); it != myList.end(); ++it) {
         vocab.push_back(it->as<string>());
+        token_id[it->as<string>()] = i;
+        i ++;
     }
+}
+
+// Returns the index stored for `token` in token_id, or 0 when the token is absent.
+int Vocab::GetIdByToken(const std::string &token) {
+    // Single find() instead of count() + operator[] (one lookup, no insert-capable access).
+    const auto it = token_id.find(token);
+    return it != token_id.end() ? it->second : 0;
 }
 
 void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
@@ -65,20 +75,52 @@
     return false;
 }
 
-string Vocab::Vector2StringV2(vector<int> in)
+// Capitalizes the lowercase forms "i", "i'm", "i've" and "i'll" to
+// "I", "I'm", "I've" and "I'll"; every other word is returned unchanged.
+string Vocab::WordFormat(std::string word)
+{
+    static const char *const kPronounForms[][2] = {
+        {"i", "I"}, {"i'm", "I'm"}, {"i've", "I've"}, {"i'll", "I'll"}};
+    for (const auto &form : kPronounForms) {
+        if (word == form[0]) {
+            return form[1];
+        }
+    }
+    // Not one of the listed forms: hand the input back as-is.
+    return word;
+}
+
+string Vocab::Vector2StringV2(vector<int> in, std::string language)
 {
     int i;
     list<string> words;
     int is_pre_english = false;
     int pre_english_len = 0;
     int is_combining = false;
-    string combine = "";
+    std::string combine = "";
+    std::string unicodeChar = "▁";  // U+2581, 3 bytes in UTF-8
 
     for (auto it = in.begin(); it != in.end(); it++) {
         string word = vocab[*it];
         // step1 space character skips
         if (word == "<s>" || word == "</s>" || word == "<unk>")
             continue;
+        if (language == "en-bpe"){
+            size_t found = word.find(unicodeChar);
+            if(found != std::string::npos){
+                if (combine != ""){
+                    combine = WordFormat(combine);
+                    if (words.size() != 0){
+                        combine = " " + combine;
+                    }
+                    words.push_back(combine);
+                }
+                combine = word.substr(3);
+            }else{
+                combine += word;
+            }
+            continue;
+        }
         // step2 combie phoneme to full word
         {
             int sub_word = !(word.find("@@") == string::npos);
@@ -137,6 +179,14 @@
         }
     }
 
+    if (language == "en-bpe" and combine != ""){
+        combine = WordFormat(combine);
+        if (words.size() != 0){
+            combine = " " + combine;
+        }
+        words.push_back(combine);
+    }
+
     stringstream ss;
     for (auto it = words.begin(); it != words.end(); it++) {
         ss << *it;
@@ -150,4 +200,4 @@
     return vocab.size();
 }
 
-} // namespace funasr
\ No newline at end of file
+} // namespace funasr

--
Gitblit v1.9.1