From 3c83d64c84602de055f503af7d4e2761c829ec2e Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期二, 12 十二月 2023 11:11:02 +0800
Subject: [PATCH] fst: support eng hotword
---
runtime/onnxruntime/src/vocab.cpp | 40 ++++++++++++++++++++++++++++++++++++----
1 files changed, 36 insertions(+), 4 deletions(-)
diff --git a/runtime/onnxruntime/src/vocab.cpp b/runtime/onnxruntime/src/vocab.cpp
index 20571c9..6991376 100644
--- a/runtime/onnxruntime/src/vocab.cpp
+++ b/runtime/onnxruntime/src/vocab.cpp
@@ -16,6 +16,12 @@
ifstream in(filename);
LoadVocabFromYaml(filename);
}
+Vocab::Vocab(const char *filename, const char *lex_file)
+{
+ ifstream in(filename);
+ LoadVocabFromYaml(filename);
+ LoadLex(lex_file);
+}
Vocab::~Vocab()
{
}
@@ -37,11 +43,37 @@
}
}
-int Vocab::GetIdByToken(const std::string &token) {
- if (token_id.count(token)) {
- return token_id[token];
+void Vocab::LoadLex(const char* filename){
+ std::ifstream file(filename);
+ std::string line;
+ while (std::getline(file, line)) {
+ std::string key, value;
+ std::istringstream iss(line);
+ std::getline(iss, key, '\t');
+ std::getline(iss, value);
+
+ if (!key.empty() && !value.empty()) {
+ lex_map[key] = value;
+ }
}
- return 0;
+
+ file.close();
+}
+
+string Vocab::Word2Lex(const std::string &word) const {
+ auto it = lex_map.find(word);
+ if (it != lex_map.end()) {
+ return it->second;
+ }
+ return "";
+}
+
+int Vocab::GetIdByToken(const std::string &token) const {
+ auto it = token_id.find(token);
+ if (it != token_id.end()) {
+ return it->second;
+ }
+ return -1;
}
void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
--
Gitblit v1.9.1