From b6d0ab4bfba04037203b3b9f6a34951e1525f36a Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期一, 24 四月 2023 15:42:10 +0800
Subject: [PATCH] fix GreedySearch
---
funasr/runtime/onnxruntime/include/com-define.h | 1 -
funasr/runtime/onnxruntime/src/model.cpp | 2 --
funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp | 1 -
funasr/runtime/onnxruntime/src/alignedmem.h | 2 --
funasr/runtime/onnxruntime/src/online-feature.h | 5 +----
funasr/runtime/onnxruntime/src/paraformer.cpp | 6 +++---
funasr/runtime/onnxruntime/src/paraformer.h | 7 ++++++-
funasr/runtime/onnxruntime/include/libfunasrapi.h | 9 +--------
funasr/runtime/onnxruntime/src/commonfunc.h | 2 --
9 files changed, 11 insertions(+), 24 deletions(-)
diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h
index 72a843d..e2c22f4 100644
--- a/funasr/runtime/onnxruntime/include/com-define.h
+++ b/funasr/runtime/onnxruntime/include/com-define.h
@@ -28,7 +28,6 @@
// punc
#define PUNC_MODEL_FILE "punc_model.onnx"
#define PUNC_YAML_FILE "punc.yaml"
-
#define UNK_CHAR "<unk>"
#define INPUT_NUM 2
diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h
index f426ffd..6b6e148 100644
--- a/funasr/runtime/onnxruntime/include/libfunasrapi.h
+++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -51,21 +51,14 @@
// if not give a fn_callback ,it should be NULL
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-
_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
-
-_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
-
+_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
-
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
-
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result);
#ifdef __cplusplus
diff --git a/funasr/runtime/onnxruntime/src/alignedmem.h b/funasr/runtime/onnxruntime/src/alignedmem.h
index 7ac6987..e2b640a 100644
--- a/funasr/runtime/onnxruntime/src/alignedmem.h
+++ b/funasr/runtime/onnxruntime/src/alignedmem.h
@@ -2,8 +2,6 @@
#ifndef ALIGNEDMEM_H
#define ALIGNEDMEM_H
-
-
extern void *AlignedMalloc(size_t alignment, size_t required_bytes);
extern void AlignedFree(void *p);
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index ade6922..fbbda74 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -33,7 +33,6 @@
{
auto t = session->GetInputNameAllocated(nIndex, allocator);
inputName = t.get();
-
}
}
}
@@ -45,7 +44,6 @@
{
auto t = session->GetOutputNameAllocated(nIndex, allocator);
outputName = t.get();
-
}
}
}
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
index aaf1276..1d822a0 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
@@ -58,7 +58,6 @@
}else{
cout <<"No return data!";
}
-
}
{
lock_guard<mutex> guard(mtx);
diff --git a/funasr/runtime/onnxruntime/src/model.cpp b/funasr/runtime/onnxruntime/src/model.cpp
index aead7c9..a582f82 100644
--- a/funasr/runtime/onnxruntime/src/model.cpp
+++ b/funasr/runtime/onnxruntime/src/model.cpp
@@ -3,8 +3,6 @@
Model *CreateModel(const char *path, int thread_num, bool quantize, bool use_vad, bool use_punc)
{
Model *mm;
-
mm = new paraformer::Paraformer(path, thread_num, quantize, use_vad, use_punc);
-
return mm;
}
diff --git a/funasr/runtime/onnxruntime/src/online-feature.h b/funasr/runtime/onnxruntime/src/online-feature.h
index abe9587..78245de 100644
--- a/funasr/runtime/onnxruntime/src/online-feature.h
+++ b/funasr/runtime/onnxruntime/src/online-feature.h
@@ -12,15 +12,12 @@
void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
-
private:
void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
-
int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
-
+
static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
-
if (frame_num >= 1 && sample_length >= frame_sample_length)
return frame_num;
else
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 493dd6d..72127f8 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -143,14 +143,14 @@
}
}
-string Paraformer::GreedySearch(float * in, int n_len )
+string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums)
{
vector<int> hyps;
int Tmax = n_len;
for (int i = 0; i < Tmax; i++) {
int max_idx;
float max_val;
- FindMax(in + i * 8404, 8404, max_val, max_idx);
+ FindMax(in + i * token_nums, token_nums, max_val, max_idx);
hyps.push_back(max_idx);
}
@@ -238,7 +238,7 @@
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
float* floatData = outputTensor[0].GetTensorMutableData<float>();
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
- result = GreedySearch(floatData, *encoder_out_lens);
+ result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
}
catch (std::exception const &e)
{
diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h
index e29a4a9..5301932 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.h
+++ b/funasr/runtime/onnxruntime/src/paraformer.h
@@ -9,6 +9,11 @@
namespace paraformer {
class Paraformer : public Model {
+ /**
+ * Author: Speech Lab of DAMO Academy, Alibaba Group
+ * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ * https://arxiv.org/pdf/2206.08317.pdf
+ */
private:
//std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions fbank_opts;
@@ -27,7 +32,7 @@
vector<float> ApplyLfr(const vector<float> &in);
void ApplyCmvn(vector<float> *v);
- string GreedySearch( float* in, int n_len);
+ string GreedySearch( float* in, int n_len, int64_t token_nums);
std::shared_ptr<Ort::Session> m_session;
Ort::Env env_;
--
Gitblit v1.9.1