From b6d0ab4bfba04037203b3b9f6a34951e1525f36a Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期一, 24 四月 2023 15:42:10 +0800
Subject: [PATCH] fix GreedySearch
---
funasr/runtime/onnxruntime/src/paraformer.cpp | 6 +++---
1 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 493dd6d..72127f8 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -143,14 +143,14 @@
}
}
-string Paraformer::GreedySearch(float * in, int n_len )
+string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums)
{
vector<int> hyps;
int Tmax = n_len;
for (int i = 0; i < Tmax; i++) {
int max_idx;
float max_val;
- FindMax(in + i * 8404, 8404, max_val, max_idx);
+ FindMax(in + i * token_nums, token_nums, max_val, max_idx);
hyps.push_back(max_idx);
}
@@ -238,7 +238,7 @@
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
float* floatData = outputTensor[0].GetTensorMutableData<float>();
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
- result = GreedySearch(floatData, *encoder_out_lens);
+ result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
}
catch (std::exception const &e)
{
--
Gitblit v1.9.1