From a64b7d8d8aeb2bb543ca703045a45f42470e9a63 Mon Sep 17 00:00:00 2001
From: 彭震东 <zhendong.peng@qq.com>
Date: Thu, 30 May 2024 15:12:53 +0800
Subject: [PATCH] keep empty speech result (#1772)

---
 runtime/onnxruntime/src/paraformer-torch.cpp |   22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
index e7fbadf..a5f7194 100644
--- a/runtime/onnxruntime/src/paraformer-torch.cpp
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -16,7 +16,7 @@
 }
 
 // offline
-void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
     LoadConfigFromYaml(am_config.c_str());
     // knf options
     fbank_opts_.frame_opts.dither = 0;
@@ -28,8 +28,8 @@
     fbank_opts_.energy_floor = 0;
     fbank_opts_.mel_opts.debug_mel = false;
 
-    vocab = new Vocab(am_config.c_str());
-	phone_set_ = new PhoneSet(am_config.c_str());
+    vocab = new Vocab(token_file.c_str());
+	phone_set_ = new PhoneSet(token_file.c_str());
     LoadCmvn(am_cmvn.c_str());
 
     torch::DeviceType device = at::kCPU;
@@ -281,13 +281,18 @@
         if(asr_feats.size() != 0){
             LfrCmvn(asr_feats);
         }
-        feats_batch.emplace_back(asr_feats);
-        int32_t num_frames  = asr_feats.size() / feature_dim;
+        int32_t num_frames  = asr_feats.size();
         paraformer_length.emplace_back(num_frames);
-        if(max_size < asr_feats.size()){
-            max_size = asr_feats.size();
+        if(max_size < asr_feats.size()*feature_dim){
+            max_size = asr_feats.size()*feature_dim;
             max_frames = num_frames;
         }
+
+        std::vector<float> flattened;
+        for (const auto& sub_vector : asr_feats) {
+            flattened.insert(flattened.end(), sub_vector.begin(), sub_vector.end());
+        }
+        feats_batch.emplace_back(flattened);
     }
 
     torch::NoGradGuard no_grad;
@@ -368,6 +373,9 @@
                 }
             }
             results.push_back(result);
+			if (wfst_decoder){
+				wfst_decoder->StartUtterance();
+			}
         }
     }
     catch (std::exception const &e)

--
Gitblit v1.9.1