From 3c83d64c84602de055f503af7d4e2761c829ec2e Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Tue, 12 Dec 2023 11:11:02 +0800
Subject: [PATCH] fst: support eng hotword
---
runtime/onnxruntime/src/vocab.h | 6 ++
runtime/websocket/bin/funasr-wss-server.cpp | 2
runtime/onnxruntime/src/paraformer.h | 2
runtime/onnxruntime/include/com-define.h | 1
runtime/onnxruntime/src/vocab.cpp | 40 ++++++++++++++++++--
runtime/onnxruntime/include/model.h | 2
runtime/onnxruntime/src/offline-stream.cpp | 10 ++++-
runtime/onnxruntime/src/bias-lm.h | 30 ++++++++++-----
runtime/onnxruntime/src/paraformer.cpp | 5 +-
9 files changed, 76 insertions(+), 22 deletions(-)
diff --git a/runtime/onnxruntime/include/com-define.h b/runtime/onnxruntime/include/com-define.h
index 57908e6..a2745da 100644
--- a/runtime/onnxruntime/include/com-define.h
+++ b/runtime/onnxruntime/include/com-define.h
@@ -68,6 +68,7 @@
#define QUANT_DECODER_NAME "decoder_quant.onnx"
#define LM_FST_RES "TLG.fst"
+#define LEX_PATH "lexicon.txt"
// vad
#ifndef VAD_SILENCE_DURATION
diff --git a/runtime/onnxruntime/include/model.h b/runtime/onnxruntime/include/model.h
index 356fca3..7b58e92 100644
--- a/runtime/onnxruntime/include/model.h
+++ b/runtime/onnxruntime/include/model.h
@@ -15,7 +15,7 @@
virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
- virtual void InitLm(const std::string &lm_file, const std::string &lm_config){};
+ virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
virtual void InitFstDecoder(){};
virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
virtual std::string Rescoring() = 0;
diff --git a/runtime/onnxruntime/src/bias-lm.h b/runtime/onnxruntime/src/bias-lm.h
index e2d28a2..957197a 100644
--- a/runtime/onnxruntime/src/bias-lm.h
+++ b/runtime/onnxruntime/src/bias-lm.h
@@ -65,12 +65,17 @@
if (text.size() > 1) {
score = std::stof(text[1]);
}
- Utf8ToCharset(text[0], split_str);
+ SplitChiEngCharacters(text[0], split_str);
for (auto &str : split_str) {
- split_id.push_back(phn_set_.String2Id(str));
- if (!phn_set_.Find(str)) {
- is_oov = true;
- break;
+ std::vector<string> lex_vec;
+ std::string lex_str = vocab_.Word2Lex(str);
+ SplitStringToVector(lex_str, " ", true, &lex_vec);
+ for (auto &token : lex_vec) {
+ split_id.push_back(phn_set_.String2Id(token));
+ if (!phn_set_.Find(token)) {
+ is_oov = true;
+ break;
+ }
}
}
if (!is_oov) {
@@ -103,12 +108,17 @@
std::vector<std::string> split_str;
std::vector<int> split_id;
score = kv.second;
- Utf8ToCharset(kv.first, split_str);
+ SplitChiEngCharacters(kv.first, split_str);
for (auto &str : split_str) {
- split_id.push_back(phn_set_.String2Id(str));
- if (!phn_set_.Find(str)) {
- is_oov = true;
- break;
+ std::vector<string> lex_vec;
+ std::string lex_str = vocab_.Word2Lex(str);
+ SplitStringToVector(lex_str, " ", true, &lex_vec);
+ for (auto &token : lex_vec) {
+ split_id.push_back(phn_set_.String2Id(token));
+ if (!phn_set_.Find(token)) {
+ is_oov = true;
+ break;
+ }
}
}
if (!is_oov) {
diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp
index 2709ca6..ae8cf18 100644
--- a/runtime/onnxruntime/src/offline-stream.cpp
+++ b/runtime/onnxruntime/src/offline-stream.cpp
@@ -63,10 +63,16 @@
// Lm resource
if (model_path.find(LM_DIR) != model_path.end() && model_path.at(LM_DIR) != "") {
- string fst_path, lm_config_path, hws_path;
+ string fst_path, lm_config_path, lex_path;
fst_path = PathAppend(model_path.at(LM_DIR), LM_FST_RES);
lm_config_path = PathAppend(model_path.at(LM_DIR), LM_CONFIG_NAME);
- asr_handle->InitLm(fst_path, lm_config_path);
+ lex_path = PathAppend(model_path.at(LM_DIR), LEX_PATH);
+ if (access(lex_path.c_str(), F_OK) != 0)
+ {
+ LOG(ERROR) << "Lexicon file '" << lex_path << "' does not exist; please use the latest LM model version. Skipping LM load.";
+ } else {
+ asr_handle->InitLm(fst_path, lm_config_path, lex_path);
+ }
}
// PUNC model
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index b3dc619..3de3e39 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -187,13 +187,14 @@
}
void Paraformer::InitLm(const std::string &lm_file,
- const std::string &lm_cfg_file) {
+ const std::string &lm_cfg_file,
+ const std::string &lex_file) {
try {
lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
fst::Fst<fst::StdArc>::Read(lm_file));
if (lm_){
if (vocab) { delete vocab; }
- vocab = new Vocab(lm_cfg_file.c_str());
+ vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
LOG(INFO) << "Successfully load lm file " << lm_file;
}else{
LOG(ERROR) << "Failed to load lm file " << lm_file;
diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h
index b5bc46d..89c8b09 100644
--- a/runtime/onnxruntime/src/paraformer.h
+++ b/runtime/onnxruntime/src/paraformer.h
@@ -60,7 +60,7 @@
void StartUtterance();
void EndUtterance();
- void InitLm(const std::string &lm_file, const std::string &lm_cfg_file);
+ void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
string FinalizeDecode(WfstDecoder* &wfst_decoder,
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
diff --git a/runtime/onnxruntime/src/vocab.cpp b/runtime/onnxruntime/src/vocab.cpp
index 20571c9..6991376 100644
--- a/runtime/onnxruntime/src/vocab.cpp
+++ b/runtime/onnxruntime/src/vocab.cpp
@@ -16,6 +16,12 @@
ifstream in(filename);
LoadVocabFromYaml(filename);
}
+Vocab::Vocab(const char *filename, const char *lex_file)
+{
+ ifstream in(filename);
+ LoadVocabFromYaml(filename);
+ LoadLex(lex_file);
+}
Vocab::~Vocab()
{
}
@@ -37,11 +43,37 @@
}
}
-int Vocab::GetIdByToken(const std::string &token) {
- if (token_id.count(token)) {
- return token_id[token];
+void Vocab::LoadLex(const char* filename){
+ std::ifstream file(filename);
+ std::string line;
+ while (std::getline(file, line)) {
+ std::string key, value;
+ std::istringstream iss(line);
+ std::getline(iss, key, '\t');
+ std::getline(iss, value);
+
+ if (!key.empty() && !value.empty()) {
+ lex_map[key] = value;
+ }
}
- return 0;
+
+ file.close();
+}
+
+string Vocab::Word2Lex(const std::string &word) const {
+ auto it = lex_map.find(word);
+ if (it != lex_map.end()) {
+ return it->second;
+ }
+ return "";
+}
+
+int Vocab::GetIdByToken(const std::string &token) const {
+ auto it = token_id.find(token);
+ if (it != token_id.end()) {
+ return it->second;
+ }
+ return -1;
}
void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
diff --git a/runtime/onnxruntime/src/vocab.h b/runtime/onnxruntime/src/vocab.h
index 8834b97..19e3648 100644
--- a/runtime/onnxruntime/src/vocab.h
+++ b/runtime/onnxruntime/src/vocab.h
@@ -13,11 +13,14 @@
private:
vector<string> vocab;
std::map<string, int> token_id;
+ std::map<string, string> lex_map;
bool IsEnglish(string ch);
void LoadVocabFromYaml(const char* filename);
+ void LoadLex(const char* filename);
public:
Vocab(const char *filename);
+ Vocab(const char *filename, const char *lex_file);
~Vocab();
int Size() const;
bool IsChinese(string ch);
@@ -26,7 +29,8 @@
string Vector2StringV2(vector<int> in, std::string language="");
string Id2String(int id) const;
string WordFormat(std::string word);
- int GetIdByToken(const std::string &token);
+ int GetIdByToken(const std::string &token) const;
+ string Word2Lex(const std::string &word) const;
};
} // namespace funasr
diff --git a/runtime/websocket/bin/funasr-wss-server.cpp b/runtime/websocket/bin/funasr-wss-server.cpp
index f54bc5b..67e3309 100644
--- a/runtime/websocket/bin/funasr-wss-server.cpp
+++ b/runtime/websocket/bin/funasr-wss-server.cpp
@@ -111,7 +111,7 @@
TCLAP::ValueArg<std::string> lm_dir("", LM_DIR,
"the LM model path, which contains compiled models: TLG.fst, config.yaml ", false, "damo/speech_ngram_lm_zh-cn-ai-wesp-fst", "string");
TCLAP::ValueArg<std::string> lm_revision(
- "", "lm-revision", "LM model revision", false, "v1.0.1", "string");
+ "", "lm-revision", "LM model revision", false, "v1.0.2", "string");
TCLAP::ValueArg<std::string> hotword("", HOTWORD,
"the hotword file, one hotword perline, Format: Hotword Weight (could be: 闃块噷宸村反 20)",
false, "/workspace/resources/hotwords.txt", "string");
--
Gitblit v1.9.1