From 8912e0696af069de47646fdb8a9d9c4e086e88b3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 14 一月 2024 23:42:11 +0800
Subject: [PATCH] Resolve merge conflict

---
 runtime/onnxruntime/src/vocab.cpp |   61 ++++++++++++++++++++++++++----
 1 files changed, 52 insertions(+), 9 deletions(-)

diff --git a/runtime/onnxruntime/src/vocab.cpp b/runtime/onnxruntime/src/vocab.cpp
index d29281c..6991376 100644
--- a/runtime/onnxruntime/src/vocab.cpp
+++ b/runtime/onnxruntime/src/vocab.cpp
@@ -16,6 +16,12 @@
     ifstream in(filename);
     LoadVocabFromYaml(filename);
 }
+Vocab::Vocab(const char *filename, const char *lex_file)
+{
+    ifstream in(filename);
+    LoadVocabFromYaml(filename);
+    LoadLex(lex_file);
+}
 Vocab::~Vocab()
 {
 }
@@ -37,11 +43,37 @@
     }
 }
 
-int Vocab::GetIdByToken(const std::string &token) {
-    if (token_id.count(token)) {
-        return token_id[token];
+void Vocab::LoadLex(const char* filename){
+    std::ifstream file(filename);
+    std::string line;
+    while (std::getline(file, line)) {
+        std::string key, value;
+        std::istringstream iss(line);
+        std::getline(iss, key, '\t');
+        std::getline(iss, value);
+
+        if (!key.empty() && !value.empty()) {
+            lex_map[key] = value;
+        }
     }
-    return 0;
+
+    file.close();
+}
+
+string Vocab::Word2Lex(const std::string &word) const {
+    auto it = lex_map.find(word);
+    if (it != lex_map.end()) {
+        return it->second;
+    }
+    return "";
+}
+
+int Vocab::GetIdByToken(const std::string &token) const {
+    auto it = token_id.find(token);
+    if (it != token_id.end()) {
+        return it->second;
+    }
+    return -1;
 }
 
 void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
@@ -120,8 +152,8 @@
     std::string combine = "";
     std::string unicodeChar = "鈻�";
 
-    for (auto it = in.begin(); it != in.end(); it++) {
-        string word = vocab[*it];
+    for (i=0; i<in.size(); i++){
+        string word = vocab[in[i]];
         // step1 space character skips
         if (word == "<s>" || word == "</s>" || word == "<unk>")
             continue;
@@ -146,9 +178,20 @@
             int sub_word = !(word.find("@@") == string::npos);
             // process word start and middle part
             if (sub_word) {
-                combine += word.erase(word.length() - 2);
-                is_combining = true;
-                continue;
+                // if badcase: lo@@ chinese
+                if (i == in.size()-1 || i<in.size()-1 && IsChinese(vocab[in[i+1]])){
+                    word = word.erase(word.length() - 2) + " ";
+                    if (is_combining) {
+                        combine += word;
+                        is_combining = false;
+                        word = combine;
+                        combine = "";
+                    }
+                }else{
+                    combine += word.erase(word.length() - 2);
+                    is_combining = true;
+                    continue;
+                }
             }
             // process word end part
             else if (is_combining) {

--
Gitblit v1.9.1