From bf4b3ef9cb95acaa2b92b98f236c4f3228cdbc2d Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 21 九月 2023 16:30:43 +0800
Subject: [PATCH] Merge pull request #976 from alibaba-damo-academy/dev_lhn
---
funasr/runtime/onnxruntime/src/vocab.cpp | 24 ++++++++++++++++--------
1 files changed, 16 insertions(+), 8 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/vocab.cpp b/funasr/runtime/onnxruntime/src/vocab.cpp
index 70553df..c29156f 100644
--- a/funasr/runtime/onnxruntime/src/vocab.cpp
+++ b/funasr/runtime/onnxruntime/src/vocab.cpp
@@ -29,19 +29,27 @@
exit(-1);
}
YAML::Node myList = config["token_list"];
+ int i = 0;
for (YAML::const_iterator it = myList.begin(); it != myList.end(); ++it) {
vocab.push_back(it->as<string>());
+ token_id[it->as<string>()] = i;
+ i ++;
}
}
-string Vocab::Vector2String(vector<int> in)
-{
- int i;
- stringstream ss;
- for (auto it = in.begin(); it != in.end(); it++) {
- ss << vocab[*it];
+int Vocab::GetIdByToken(const std::string &token) {
+ if (token_id.count(token)) {
+ return token_id[token];
}
- return ss.str();
+ return 0;
+}
+
+void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
+{
+ for (auto it = in.begin(); it != in.end(); it++) {
+ string word = vocab[*it];
+ preds.emplace_back(word);
+ }
}
int Str2Int(string str)
@@ -152,4 +160,4 @@
return vocab.size();
}
-} // namespace funasr
\ No newline at end of file
+} // namespace funasr
--
Gitblit v1.9.1