From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
runtime/onnxruntime/src/vocab.cpp | 82 ++++++++++++++++++++++++++++++++++++-----
1 files changed, 72 insertions(+), 10 deletions(-)
diff --git a/runtime/onnxruntime/src/vocab.cpp b/runtime/onnxruntime/src/vocab.cpp
index bdedd84..1416dd3 100644
--- a/runtime/onnxruntime/src/vocab.cpp
+++ b/runtime/onnxruntime/src/vocab.cpp
@@ -14,7 +14,13 @@
Vocab::Vocab(const char *filename)
{
ifstream in(filename);
+ LoadVocabFromJson(filename);
+}
+Vocab::Vocab(const char *filename, const char *lex_file)
+{
+ ifstream in(filename);
LoadVocabFromYaml(filename);
+ LoadLex(lex_file);
}
Vocab::~Vocab()
{
@@ -37,11 +43,56 @@
}
}
-int Vocab::GetIdByToken(const std::string &token) {
- if (token_id.count(token)) {
- return token_id[token];
+void Vocab::LoadVocabFromJson(const char* filename){
+ nlohmann::json json_array;
+ std::ifstream file(filename);
+ if (file.is_open()) {
+ file >> json_array;
+ file.close();
+ } else {
+ LOG(INFO) << "Error loading token file, token file error or not exist.";
+ exit(-1);
}
- return 0;
+
+ int i = 0;
+ for (const auto& element : json_array) {
+ vocab.push_back(element);
+ token_id[element] = i;
+ i++;
+ }
+}
+
+void Vocab::LoadLex(const char* filename){
+ std::ifstream file(filename);
+ std::string line;
+ while (std::getline(file, line)) {
+ std::string key, value;
+ std::istringstream iss(line);
+ std::getline(iss, key, '\t');
+ std::getline(iss, value);
+
+ if (!key.empty() && !value.empty()) {
+ lex_map[key] = value;
+ }
+ }
+
+ file.close();
+}
+
+string Vocab::Word2Lex(const std::string &word) const {
+ auto it = lex_map.find(word);
+ if (it != lex_map.end()) {
+ return it->second;
+ }
+ return "";
+}
+
+int Vocab::GetIdByToken(const std::string &token) const {
+ auto it = token_id.find(token);
+ if (it != token_id.end()) {
+ return it->second;
+ }
+ return -1;
}
void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
@@ -120,8 +171,8 @@
std::string combine = "";
std::string unicodeChar = "鈻�";
- for (auto it = in.begin(); it != in.end(); it++) {
- string word = vocab[*it];
+ for (i=0; i<in.size(); i++){
+ string word = vocab[in[i]];
// step1 space character skips
if (word == "<s>" || word == "</s>" || word == "<unk>")
continue;
@@ -146,9 +197,20 @@
int sub_word = !(word.find("@@") == string::npos);
// process word start and middle part
if (sub_word) {
- combine += word.erase(word.length() - 2);
- is_combining = true;
- continue;
+ // if badcase: lo@@ chinese
+ if (i == in.size()-1 || i<in.size()-1 && IsChinese(vocab[in[i+1]])){
+ word = word.erase(word.length() - 2) + " ";
+ if (is_combining) {
+ combine += word;
+ is_combining = false;
+ word = combine;
+ combine = "";
+ }
+ }else{
+ combine += word.erase(word.length() - 2);
+ is_combining = true;
+ continue;
+ }
}
// process word end part
else if (is_combining) {
@@ -199,7 +261,7 @@
}
}
- if (language == "en-bpe" and combine != ""){
+ if (language == "en-bpe" && combine != ""){
combine = WordFormat(combine);
if (words.size() != 0){
combine = " " + combine;
--
Gitblit v1.9.1