From 7ab2e5cf22bbb31808bcacf84c054c710e4e6a93 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Mon, 24 Apr 2023 16:19:17 +0800
Subject: [PATCH] Merge pull request #400 from alibaba-damo-academy/dev_knf

---
 funasr/runtime/onnxruntime/src/paraformer.h |   58 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 58 insertions(+), 0 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h
new file mode 100644
index 0000000..5301932
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/paraformer.h
@@ -0,0 +1,58 @@
+#pragma once
+
+
+#ifndef PARAFORMER_MODELIMP_H
+#define PARAFORMER_MODELIMP_H
+
+#include "precomp.h"
+
+namespace paraformer {
+
+    class Paraformer : public Model {
+    /**
+     * Paraformer recognizer, declared as a public subclass of Model (Model is declared outside this header).
+     * Author: Speech Lab of DAMO Academy, Alibaba Group
+     * Paper: "Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition", https://arxiv.org/pdf/2206.08317.pdf
+    */
+    private:
+        // Disabled member, kept for reference: std::unique_ptr<knf::OnlineFbank> fbank_;
+        knf::FbankOptions fbank_opts;  // fbank configuration (knf type); presumably filled in by the constructor -- TODO confirm
+        // Helpers held by unique_ptr; presumably only created when use_vad / use_punc are requested -- TODO confirm.
+        std::unique_ptr<FsmnVad> vad_handle;
+        std::unique_ptr<CTTransformer> punc_handle;
+        // NOTE(review): vocab is a raw pointer with no in-class initializer; its ownership is not visible in this header.
+        Vocab* vocab;
+        vector<float> means_list;  // presumably CMVN mean terms loaded by LoadCmvn -- TODO confirm
+        vector<float> vars_list;   // presumably CMVN variance terms loaded by LoadCmvn -- TODO confirm
+        const float scale = 22.6274169979695;  // numerically sqrt(512); the double literal is narrowed to float here
+        int32_t lfr_window_size = 7;   // default LFR window length (presumably counted in frames -- TODO confirm)
+        int32_t lfr_window_shift = 6;  // default LFR window shift (presumably counted in frames -- TODO confirm)
+        // Feature helpers: ApplyLfr returns a new vector; ApplyCmvn takes a mutable pointer (presumably edits in place).
+        void LoadCmvn(const char *filename);
+        vector<float> ApplyLfr(const vector<float> &in);
+        void ApplyCmvn(vector<float> *v);
+
+        string GreedySearch( float* in, int n_len, int64_t token_nums);  // presumably turns model output into text -- TODO confirm
+        // ONNX Runtime state: the session is shared_ptr-held; env_ and session_options are stored by value.
+        std::shared_ptr<Ort::Session> m_session;
+        Ort::Env env_;
+        Ort::SessionOptions session_options;
+        // NOTE(review): the const char* vectors presumably point into the string vectors below -- verify lifetimes.
+        vector<string> m_strInputNames, m_strOutputNames;
+        vector<const char*> m_szInputNames;
+        vector<const char*> m_szOutputNames;
+
+    public:  // NOTE(review): no member is marked `override`; whether any overrides a Model virtual is not visible here
+        Paraformer(const char* path, int thread_num=0, bool quantize=false, bool use_vad=false, bool use_punc=false);  // defaults: 0 threads requested, no quantization, VAD and punctuation off
+        ~Paraformer();  // user-declared destructor; copy operations are left implicit
+        void Reset();
+        vector<float> FbankKaldi(float sample_rate, const float* waves, int len);  // presumably computes fbank features from `len` samples -- TODO confirm
+        string ForwardChunk(float* din, int len, int flag);  // meaning of `flag` is not visible in this header
+        string Forward(float* din, int len, int flag);       // meaning of `flag` is not visible in this header
+        string Rescoring();
+        std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data);  // takes pcm_data by non-const reference; presumably uses vad_handle -- TODO confirm
+        string AddPunc(const char* sz_input);  // presumably uses punc_handle -- TODO confirm
+    };
+
+} // namespace paraformer
+#endif

--
Gitblit v1.9.1