From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/onnxruntime/src/fsmn-vad.h |   77 +++++++++++++++++++++-----------------
 1 files changed, 43 insertions(+), 34 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.h b/funasr/runtime/onnxruntime/src/fsmn-vad.h
index e8569f9..adceb1f 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.h
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.h
@@ -1,10 +1,15 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
 
 #ifndef VAD_SERVER_FSMNVAD_H
 #define VAD_SERVER_FSMNVAD_H
 
 #include "precomp.h"
 
-class FsmnVad {
+namespace funasr {
+class FsmnVad : public VadModel {
 /**
  * Author: Speech Lab of DAMO Academy, Alibaba Group
  * Deep-FSMN for Large Vocabulary Continuous Speech Recognition
@@ -13,33 +18,17 @@
 
 public:
     FsmnVad();
+    ~FsmnVad();
     void Test();
-    void InitVad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
-                  float vad_speech_noise_thres);
-
-    std::vector<std::vector<int>> Infer(const std::vector<float> &waves);
-    void Reset();
-
-private:
-
-    void ReadModel(const std::string &vad_model);
-
-    static void GetInputOutputInfo(
-            const std::shared_ptr<Ort::Session> &session,
-            std::vector<const char *> *in_names, std::vector<const char *> *out_names);
-
-    void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
-                    const std::vector<float> &waves);
-
-    std::vector<std::vector<float>> &LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n);
-
+    void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num);
+    std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true);
     void Forward(
-            const std::vector<std::vector<float>> &chunk_feats,
-            std::vector<std::vector<float>> *out_prob);
-
-    void LoadCmvn(const char *filename);
-    void InitCache();
-
+        const std::vector<std::vector<float>> &chunk_feats,
+        std::vector<std::vector<float>> *out_prob,
+        std::vector<std::vector<float>> *in_cache,
+        bool is_final);
+    void Reset();
+    
     std::shared_ptr<Ort::Session> vad_session_ = nullptr;
     Ort::Env env_;
     Ort::SessionOptions session_options_;
@@ -47,14 +36,34 @@
     std::vector<const char *> vad_out_names_;
     std::vector<std::vector<float>> in_cache_;
     
-    knf::FbankOptions fbank_opts;
-    std::vector<float> means_list;
-    std::vector<float> vars_list;
-    int vad_sample_rate_ = 16000;
-    int vad_silence_duration_ = 800;
-    int vad_max_len_ = 15000;
-    double vad_speech_noise_thres_ = 0.9;
+    knf::FbankOptions fbank_opts_;
+    std::vector<float> means_list_;
+    std::vector<float> vars_list_;
+
+    int vad_sample_rate_ = MODEL_SAMPLE_RATE;
+    int vad_silence_duration_ = VAD_SILENCE_DURATION;
+    int vad_max_len_ = VAD_MAX_LEN;
+    double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
+    int lfr_m = VAD_LFR_M;
+    int lfr_n = VAD_LFR_N;
+
+private:
+
+    void ReadModel(const char* vad_model);
+    void LoadConfigFromYaml(const char* filename);
+
+    static void GetInputOutputInfo(
+            const std::shared_ptr<Ort::Session> &session,
+            std::vector<const char *> *in_names, std::vector<const char *> *out_names);
+
+    void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
+                    std::vector<float> &waves);
+
+    void LfrCmvn(std::vector<std::vector<float>> &vad_feats);
+    void LoadCmvn(const char *filename);
+    void InitCache();
+
 };
 
-
+} // namespace funasr
 #endif //VAD_SERVER_FSMNVAD_H

--
Gitblit v1.9.1