From b6d0ab4bfba04037203b3b9f6a34951e1525f36a Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期一, 24 四月 2023 15:42:10 +0800
Subject: [PATCH] fix GreedySearch

---
 funasr/runtime/onnxruntime/src/paraformer.cpp |    6 +++---
 1 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 493dd6d..72127f8 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -143,14 +143,14 @@
     }
 }
 
-string Paraformer::GreedySearch(float * in, int n_len )
+string Paraformer::GreedySearch(float * in, int n_len,  int64_t token_nums)
 {
     vector<int> hyps;
     int Tmax = n_len;
     for (int i = 0; i < Tmax; i++) {
         int max_idx;
         float max_val;
-        FindMax(in + i * 8404, 8404, max_val, max_idx);
+        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
         hyps.push_back(max_idx);
     }
 
@@ -238,7 +238,7 @@
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         float* floatData = outputTensor[0].GetTensorMutableData<float>();
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
-        result = GreedySearch(floatData, *encoder_out_lens);
+        result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
     }
     catch (std::exception const &e)
     {

--
Gitblit v1.9.1