From 911d450a596a711d6faea37c2abfba13d3a511fd Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 27 四月 2023 14:15:11 +0800
Subject: [PATCH] Merge branch 'dev_lhn' into dev_websocket

---
 funasr/runtime/onnxruntime/src/fsmn-vad.cpp |  137 +++++++++++++++++++++++++++------------------
 1 files changed, 82 insertions(+), 55 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
index 0f87cb2..fbb682b 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -1,43 +1,63 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
 
 #include <fstream>
 #include "precomp.h"
-//#include "glog/logging.h"
 
-
-void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
-                       float vad_speech_noise_thres) {
+void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config) {  // the three arguments are forwarded as C strings to ReadModel, LoadCmvn and LoadConfigFromYaml
     session_options_.SetIntraOpNumThreads(1);
     session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
     session_options_.DisableCpuMemArena();
-    this->vad_sample_rate_ = vad_sample_rate;
-    this->vad_silence_duration_=vad_silence_duration;
-    this->vad_max_len_=vad_max_len;
-    this->vad_speech_noise_thres_=vad_speech_noise_thres;
 
-    ReadModel(vad_model);
+    ReadModel(vad_model.c_str());  // creates vad_session_ from the model file
     LoadCmvn(vad_cmvn.c_str());
+    LoadConfigFromYaml(vad_config.c_str());  // loads the sample rate, segmentation thresholds and fbank_opts from the YAML file
     InitCache();
-
-    fbank_opts.frame_opts.dither = 0;
-    fbank_opts.mel_opts.num_bins = 80;
-    fbank_opts.frame_opts.samp_freq = vad_sample_rate;
-    fbank_opts.frame_opts.window_type = "hamming";
-    fbank_opts.frame_opts.frame_shift_ms = 10;
-    fbank_opts.frame_opts.frame_length_ms = 25;
-    fbank_opts.energy_floor = 0;
-    fbank_opts.mel_opts.debug_mel = false;
-
 }
 
-void FsmnVad::ReadModel(const std::string &vad_model) {
+void FsmnVad::LoadConfigFromYaml(const char* filename){
+    // Reads the VAD segmentation thresholds and the fbank front-end options from the YAML file `filename`; any failure is logged and ends the process with exit(-1).
+    YAML::Node config;
+    try{
+        config = YAML::LoadFile(filename);
+    }catch(std::exception const &e){
+        LOG(ERROR) << "Error loading file, yaml file error or not exist: " << filename << ", " << e.what();
+        exit(-1);
+    }
+    // None of the as<T>() calls below passes a fallback value, so a missing or mistyped key throws (yaml-cpp) and is reported by the catch at the end.
+    try{
+        YAML::Node frontend_conf = config["frontend_conf"];
+        YAML::Node post_conf = config["vad_post_conf"];
+        // Sample rate ("fs") plus the three values read from the "vad_post_conf" section.
+        this->vad_sample_rate_ = frontend_conf["fs"].as<int>();
+        this->vad_silence_duration_ =  post_conf["max_end_silence_time"].as<int>();
+        this->vad_max_len_ = post_conf["max_single_segment_time"].as<int>();
+        this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as<double>();
+        // Fbank options; samp_freq reuses the "fs" value read above, energy_floor and debug_mel stay fixed.
+        fbank_opts.frame_opts.dither = frontend_conf["dither"].as<float>();
+        fbank_opts.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
+        fbank_opts.frame_opts.samp_freq = static_cast<float>(vad_sample_rate_);
+        fbank_opts.frame_opts.window_type = frontend_conf["window"].as<std::string>();
+        fbank_opts.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
+        fbank_opts.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
+        fbank_opts.energy_floor = 0;
+        fbank_opts.mel_opts.debug_mel = false;
+    }catch(std::exception const &e){
+        LOG(ERROR) << "Error when load argument from vad config YAML: " << e.what();
+        exit(-1);
+    }
+}
+
+void FsmnVad::ReadModel(const char* vad_model) {  // vad_model is passed straight to the Ort::Session constructor
     try {
         vad_session_ = std::make_shared<Ort::Session>(
-                env_, vad_model.c_str(), session_options_);
+                env_, vad_model, session_options_);
     } catch (std::exception const &e) {
-        //LOG(ERROR) << "Error when load onnx model: " << e.what();
+        LOG(ERROR) << "Error when load vad onnx model: " << e.what();  // NOTE(review): the exit(0) that follows ends the process with status 0 although loading failed
         exit(0);
     }
-    //LOG(INFO) << "vad onnx:";
     GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_);
 }
 
@@ -119,13 +139,12 @@
     // 4. Onnx infer
     std::vector<Ort::Value> vad_ort_outputs;
     try {
-        // VLOG(3) << "Start infer";
         vad_ort_outputs = vad_session_->Run(
                 Ort::RunOptions{nullptr}, vad_in_names_.data(), vad_inputs.data(),
                 vad_inputs.size(), vad_out_names_.data(), vad_out_names_.size());
     } catch (std::exception const &e) {
-        // LOG(ERROR) << e.what();
-        return;
+        LOG(ERROR) << "Error when run vad onnx forword: " << (e.what());
+        exit(0);
     }
 
     // 5. Change infer result to output shapes
@@ -163,40 +182,49 @@
 
 void FsmnVad::LoadCmvn(const char *filename)
 {
-    using namespace std;
-    ifstream cmvn_stream(filename);
-    string line;
+    // Fills means_list / vars_list from the CMVN file: the values are taken from the "<LearnRateCoef>" line that follows "<AddShift>" / "<Rescale>"; failures exit(-1).
+    try{
+        std::ifstream cmvn_stream(filename);
+        if (!cmvn_stream.is_open()) {
+            LOG(ERROR) << "Failed to open file: " << filename;
+            exit(-1);
+        }
+        std::string line;
 
-    while (getline(cmvn_stream, line)) {
-        istringstream iss(line);
-        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
-        if (line_item[0] == "<AddShift>") {
-            getline(cmvn_stream, line);
-            istringstream means_lines_stream(line);
-            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
-            if (means_lines[0] == "<LearnRateCoef>") {
-                for (int j = 3; j < means_lines.size() - 1; j++) {
-                    means_list.push_back(stof(means_lines[j]));
+        while (std::getline(cmvn_stream, line)) {
+            std::istringstream iss(line);
+            std::vector<std::string> line_item{std::istream_iterator<std::string>{iss}, std::istream_iterator<std::string>{}};
+            if (!line_item.empty() && line_item[0] == "<AddShift>") {  // blank lines yield no tokens, so check before indexing
+                std::getline(cmvn_stream, line);
+                std::istringstream means_lines_stream(line);
+                std::vector<std::string> means_lines{std::istream_iterator<std::string>{means_lines_stream}, std::istream_iterator<std::string>{}};
+                if (!means_lines.empty() && means_lines[0] == "<LearnRateCoef>") {
+                    for (std::size_t j = 3; j + 1 < means_lines.size(); j++) {  // tokens 0-2 and the last token are not values
+                        means_list.push_back(std::stof(means_lines[j]));
+                    }
+                    continue;
                 }
-                continue;
+            }
+            else if (!line_item.empty() && line_item[0] == "<Rescale>") {
+                std::getline(cmvn_stream, line);
+                std::istringstream vars_lines_stream(line);
+                std::vector<std::string> vars_lines{std::istream_iterator<std::string>{vars_lines_stream}, std::istream_iterator<std::string>{}};
+                if (!vars_lines.empty() && vars_lines[0] == "<LearnRateCoef>") {
+                    // Same token layout as the means line; the values are stored as read, without extra scaling.
+                    for (std::size_t j = 3; j + 1 < vars_lines.size(); j++) {
+                        vars_list.push_back(std::stof(vars_lines[j]));
+                    }
+                    continue;
+                }
             }
         }
-        else if (line_item[0] == "<Rescale>") {
-            getline(cmvn_stream, line);
-            istringstream vars_lines_stream(line);
-            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
-            if (vars_lines[0] == "<LearnRateCoef>") {
-                for (int j = 3; j < vars_lines.size() - 1; j++) {
-                    // vars_list.push_back(stof(vars_lines[j])*scale);
-                    vars_list.push_back(stof(vars_lines[j]));
-                }
-                continue;
-            }
-        }
+    }catch(std::exception const &e) {
+        LOG(ERROR) << "Error when load vad cmvn : " << e.what();
+        exit(-1);
     }
 }
 
-std::vector<std::vector<float>> &FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n) {
+std::vector<std::vector<float>> &FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats) {
 
     std::vector<std::vector<float>> out_feats;
     int T = vad_feats.size();
@@ -243,7 +271,7 @@
     std::vector<std::vector<float>> vad_feats;
     std::vector<std::vector<float>> vad_probs;
     FbankKaldi(vad_sample_rate_, vad_feats, waves);
-    vad_feats = LfrCmvn(vad_feats, 5, 1);
+    vad_feats = LfrCmvn(vad_feats);
     Forward(vad_feats, &vad_probs);
 
     E2EVadModel vad_scorer = E2EVadModel();
@@ -251,7 +279,6 @@
     vad_segments = vad_scorer(vad_probs, waves, true, false, vad_silence_duration_, vad_max_len_,
                               vad_speech_noise_thres_, vad_sample_rate_);
     return vad_segments;
-
 }
 
 void FsmnVad::InitCache(){

--
Gitblit v1.9.1