From d4aaa84ad16c2c862ffcb5d73bf7852c8ee90d24 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 21 三月 2024 14:17:22 +0800
Subject: [PATCH] fix func FunASRWfstDecoderInit

---
 runtime/onnxruntime/src/paraformer-torch.cpp |   55 +++++++++++++++++++++++++++++++++++++++++++++----------
 1 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
index 1f15ec7..934a0ea 100644
--- a/runtime/onnxruntime/src/paraformer-torch.cpp
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -38,7 +38,7 @@
         LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
         exit(-1);
     } else {
-        LOG(INFO) << "CUDA available! Running on GPU";
+        LOG(INFO) << "CUDA is available, running on GPU";
         device = at::kCUDA;
     }
     #endif
@@ -280,6 +280,7 @@
     paraformer_length.emplace_back(num_frames);
 
     torch::NoGradGuard no_grad;
+    model_->eval();
     torch::Tensor feats =
         torch::from_blob(wav_feats.data(),
                 {1, num_frames, feat_dim}, torch::kFloat).contiguous();
@@ -305,15 +306,49 @@
         am_scores = outputs[0].toTensor();
         valid_token_lens = outputs[1].toTensor();
         #endif
-        
-        if (lm_ == nullptr) {
-            result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
-        } else {
-            result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
-            if (input_finished) {
-                result = FinalizeDecode(wfst_decoder);
+        // timestamp
+        if(outputs.size() == 4){
+            torch::Tensor us_alphas_tensor;
+            torch::Tensor us_peaks_tensor;
+            #ifdef USE_GPU
+            us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
+            us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
+            #else
+            us_alphas_tensor = outputs[2].toTensor();
+            us_peaks_tensor = outputs[3].toTensor();
+            #endif
+
+            int us_alphas_shape_1 = us_alphas_tensor.size(1);
+            float* us_alphas_data = us_alphas_tensor.data_ptr<float>();
+            std::vector<float> us_alphas(us_alphas_shape_1);
+            for (int i = 0; i < us_alphas_shape_1; i++) {
+                us_alphas[i] = us_alphas_data[i];
             }
-        }        
+
+            int us_peaks_shape_1 = us_peaks_tensor.size(1);
+            float* us_peaks_data = us_peaks_tensor.data_ptr<float>();
+            std::vector<float> us_peaks(us_peaks_shape_1);
+            for (int i = 0; i < us_peaks_shape_1; i++) {
+                us_peaks[i] = us_peaks_data[i];
+            }
+			if (lm_ == nullptr) {
+                result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
+			} else {
+			    result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
+                if (input_finished) {
+                    result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
+                }
+			}
+        }else{
+            if (lm_ == nullptr) {
+                result = GreedySearch(am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
+            } else {
+                result = BeamSearch(wfst_decoder, am_scores[0].data_ptr<float>(), valid_token_lens[0].item<int>(), am_scores.size(2));
+                if (input_finished) {
+                    result = FinalizeDecode(wfst_decoder);
+                }
+            }
+        }
     }
     catch (std::exception const &e)
     {
@@ -324,7 +359,7 @@
 }
 
 std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
-    std::vector<std::vector<float>> result;
+    std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
     return result;
 }
 

--
Gitblit v1.9.1