From 2f27b165559cd53afab52047309ebe4ac838ebb8 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 14 五月 2024 09:54:08 +0800
Subject: [PATCH] Add files via upload

---
 runtime/onnxruntime/src/paraformer.cpp |   27 +++++++++++++++++----------
 1 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index bb15ac7..a57fb9b 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -18,7 +18,7 @@
 }
 
 // offline
-void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
     LoadConfigFromYaml(am_config.c_str());
     // knf options
     fbank_opts_.frame_opts.dither = 0;
@@ -65,13 +65,13 @@
         m_szInputNames.push_back(item.c_str());
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
-    vocab = new Vocab(am_config.c_str());
-	phone_set_ = new PhoneSet(am_config.c_str());
+    vocab = new Vocab(token_file.c_str());
+	phone_set_ = new PhoneSet(token_file.c_str());
     LoadCmvn(am_cmvn.c_str());
 }
 
 // online
-void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
     
     LoadOnlineConfigFromYaml(am_config.c_str());
     // knf options
@@ -143,15 +143,15 @@
     for (auto& item : de_strOutputNames)
         de_szOutputNames_.push_back(item.c_str());
 
-    vocab = new Vocab(am_config.c_str());
-    phone_set_ = new PhoneSet(am_config.c_str());
+    vocab = new Vocab(token_file.c_str());
+    phone_set_ = new PhoneSet(token_file.c_str());
     LoadCmvn(am_cmvn.c_str());
 }
 
 // 2pass
-void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
     // online
-    InitAsr(en_model, de_model, am_cmvn, am_config, thread_num);
+    InitAsr(en_model, de_model, am_cmvn, am_config, token_file, thread_num);
 
     // offline
     try {
@@ -193,8 +193,7 @@
         lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
             fst::Fst<fst::StdArc>::Read(lm_file));
         if (lm_){
-            if (vocab) { delete vocab; }
-            vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
+            lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
             LOG(INFO) << "Successfully load lm file " << lm_file;
         }else{
             LOG(ERROR) << "Failed to load lm file " << lm_file;
@@ -309,6 +308,9 @@
 {
     if(vocab){
         delete vocab;
+    }
+    if(lm_vocab){
+        delete lm_vocab;
     }
     if(seg_dict){
         delete seg_dict;
@@ -687,6 +689,11 @@
     return vocab;
 }
 
+Vocab* Paraformer::GetLmVocab()
+{
+    return lm_vocab;
+}
+
 PhoneSet* Paraformer::GetPhoneSet()
 {
     return phone_set_;

--
Gitblit v1.9.1