From 9cc37eaa8af50db2ffad3fc02746547ef995a870 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 28 九月 2023 15:47:43 +0800
Subject: [PATCH] add postprocess for en-bpe
---
funasr/runtime/onnxruntime/src/vocab.cpp | 44 ++++++++++++++++++++++++++++++++++----------
1 files changed, 34 insertions(+), 10 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/vocab.cpp b/funasr/runtime/onnxruntime/src/vocab.cpp
index 70553df..95174c7 100644
--- a/funasr/runtime/onnxruntime/src/vocab.cpp
+++ b/funasr/runtime/onnxruntime/src/vocab.cpp
@@ -29,19 +29,27 @@
exit(-1);
}
YAML::Node myList = config["token_list"];
+ int i = 0;
for (YAML::const_iterator it = myList.begin(); it != myList.end(); ++it) {
vocab.push_back(it->as<string>());
+ token_id[it->as<string>()] = i;
+ i ++;
}
}
-string Vocab::Vector2String(vector<int> in)
-{
- int i;
- stringstream ss;
- for (auto it = in.begin(); it != in.end(); it++) {
- ss << vocab[*it];
+int Vocab::GetIdByToken(const std::string &token) {
+ if (token_id.count(token)) {
+ return token_id[token];
}
- return ss.str();
+ return 0;
+}
+
+void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
+{
+ for (auto it = in.begin(); it != in.end(); it++) {
+ string word = vocab[*it];
+ preds.emplace_back(word);
+ }
}
int Str2Int(string str)
@@ -67,20 +75,36 @@
return false;
}
-string Vocab::Vector2StringV2(vector<int> in)
+string Vocab::Vector2StringV2(vector<int> in, std::string language)
{
int i;
list<string> words;
int is_pre_english = false;
int pre_english_len = 0;
int is_combining = false;
- string combine = "";
+ std::string combine = "";
+ std::string unicodeChar = "鈻�";
for (auto it = in.begin(); it != in.end(); it++) {
string word = vocab[*it];
// step1 space character skips
if (word == "<s>" || word == "</s>" || word == "<unk>")
continue;
+ if (language == "en-bpe"){
+ size_t found = word.find(unicodeChar);
+ if(found != std::string::npos){
+ if (combine != ""){
+ if (words.size() != 0){
+ combine = " " + combine;
+ }
+ words.push_back(combine);
+ }
+ combine = word.substr(3);
+ }else{
+ combine += word;
+ }
+ continue;
+ }
// step2 combie phoneme to full word
{
int sub_word = !(word.find("@@") == string::npos);
@@ -152,4 +176,4 @@
return vocab.size();
}
-} // namespace funasr
\ No newline at end of file
+} // namespace funasr
--
Gitblit v1.9.1