From b6d0ab4bfba04037203b3b9f6a34951e1525f36a Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: Mon, 24 Apr 2023 15:42:10 +0800
Subject: [PATCH] fix GreedySearch: take token count from output shape instead of hard-coded 8404

---
 funasr/runtime/onnxruntime/include/com-define.h            |    1 -
 funasr/runtime/onnxruntime/src/model.cpp                   |    2 --
 funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp |    1 -
 funasr/runtime/onnxruntime/src/alignedmem.h                |    2 --
 funasr/runtime/onnxruntime/src/online-feature.h            |    5 +----
 funasr/runtime/onnxruntime/src/paraformer.cpp              |    6 +++---
 funasr/runtime/onnxruntime/src/paraformer.h                |    7 ++++++-
 funasr/runtime/onnxruntime/include/libfunasrapi.h          |    9 +--------
 funasr/runtime/onnxruntime/src/commonfunc.h                |    2 --
 9 files changed, 11 insertions(+), 24 deletions(-)

diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h
index 72a843d..e2c22f4 100644
--- a/funasr/runtime/onnxruntime/include/com-define.h
+++ b/funasr/runtime/onnxruntime/include/com-define.h
@@ -28,7 +28,6 @@
 // punc
 #define PUNC_MODEL_FILE  "punc_model.onnx"
 #define PUNC_YAML_FILE "punc.yaml"
-
 #define UNK_CHAR "<unk>"
 
 #define  INPUT_NUM  2
diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h
index f426ffd..6b6e148 100644
--- a/funasr/runtime/onnxruntime/include/libfunasrapi.h
+++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -51,21 +51,14 @@
 
 // if not give a fn_callback ,it should be NULL 
 _FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-
 _FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-
 _FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
-
 _FUNASRAPI FUNASR_RESULT	FunASRRecogFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool use_vad=false, bool use_punc=false);
 
 _FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT result,int n_index);
-
-_FUNASRAPI const int		FunASRGetRetNumber(FUNASR_RESULT result);
-
+_FUNASRAPI const int	FunASRGetRetNumber(FUNASR_RESULT result);
 _FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT result);
-
 _FUNASRAPI void			FunASRUninit(FUNASR_HANDLE handle);
-
 _FUNASRAPI const float	FunASRGetRetSnippetTime(FUNASR_RESULT result);
 
 #ifdef __cplusplus 
diff --git a/funasr/runtime/onnxruntime/src/alignedmem.h b/funasr/runtime/onnxruntime/src/alignedmem.h
index 7ac6987..e2b640a 100644
--- a/funasr/runtime/onnxruntime/src/alignedmem.h
+++ b/funasr/runtime/onnxruntime/src/alignedmem.h
@@ -2,8 +2,6 @@
 #ifndef ALIGNEDMEM_H
 #define ALIGNEDMEM_H
 
-
-
 extern void *AlignedMalloc(size_t alignment, size_t required_bytes);
 extern void AlignedFree(void *p);
 
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index ade6922..fbbda74 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -33,7 +33,6 @@
         {
             auto t = session->GetInputNameAllocated(nIndex, allocator);
             inputName = t.get();
-
         }
     }
 }
@@ -45,7 +44,6 @@
         {
             auto t = session->GetOutputNameAllocated(nIndex, allocator);
             outputName = t.get();
-
         }
     }
 }
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
index aaf1276..1d822a0 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
@@ -58,7 +58,6 @@
         }else{
             cout <<"No return data!";
         }
-
     }
     {
         lock_guard<mutex> guard(mtx);
diff --git a/funasr/runtime/onnxruntime/src/model.cpp b/funasr/runtime/onnxruntime/src/model.cpp
index aead7c9..a582f82 100644
--- a/funasr/runtime/onnxruntime/src/model.cpp
+++ b/funasr/runtime/onnxruntime/src/model.cpp
@@ -3,8 +3,6 @@
 Model *CreateModel(const char *path, int thread_num, bool quantize, bool use_vad, bool use_punc)
 {
     Model *mm;
-
     mm = new paraformer::Paraformer(path, thread_num, quantize, use_vad, use_punc);
-
     return mm;
 }
diff --git a/funasr/runtime/onnxruntime/src/online-feature.h b/funasr/runtime/onnxruntime/src/online-feature.h
index abe9587..78245de 100644
--- a/funasr/runtime/onnxruntime/src/online-feature.h
+++ b/funasr/runtime/onnxruntime/src/online-feature.h
@@ -12,15 +12,12 @@
 
   void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
 
-
 private:
   void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
-
   int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
-
+  
   static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
     int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
-
     if (frame_num >= 1 && sample_length >= frame_sample_length)
       return frame_num;
     else
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 493dd6d..72127f8 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -143,14 +143,14 @@
     }
 }
 
-string Paraformer::GreedySearch(float * in, int n_len )
+string Paraformer::GreedySearch(float * in, int n_len,  int64_t token_nums)
 {
     vector<int> hyps;
     int Tmax = n_len;
     for (int i = 0; i < Tmax; i++) {
         int max_idx;
         float max_val;
-        FindMax(in + i * 8404, 8404, max_val, max_idx);
+        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
         hyps.push_back(max_idx);
     }
 
@@ -238,7 +238,7 @@
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         float* floatData = outputTensor[0].GetTensorMutableData<float>();
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
-        result = GreedySearch(floatData, *encoder_out_lens);
+        result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
     }
     catch (std::exception const &e)
     {
diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h
index e29a4a9..5301932 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.h
+++ b/funasr/runtime/onnxruntime/src/paraformer.h
@@ -9,6 +9,11 @@
 namespace paraformer {
 
     class Paraformer : public Model {
+    /**
+     * Author: Speech Lab of DAMO Academy, Alibaba Group
+     * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+     * https://arxiv.org/pdf/2206.08317.pdf
+    */
     private:
         //std::unique_ptr<knf::OnlineFbank> fbank_;
         knf::FbankOptions fbank_opts;
@@ -27,7 +32,7 @@
         vector<float> ApplyLfr(const vector<float> &in);
         void ApplyCmvn(vector<float> *v);
 
-        string GreedySearch( float* in, int n_len);
+        string GreedySearch( float* in, int n_len, int64_t token_nums);
 
         std::shared_ptr<Ort::Session> m_session;
         Ort::Env env_;

--
Gitblit v1.9.1