From 6ef5ccc784d9f7fcd4072dd83c6ceaf1c324b92c Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期二, 25 四月 2023 20:34:53 +0800
Subject: [PATCH] check am.mvn exists
---
funasr/runtime/onnxruntime/src/fsmn-vad.cpp | 79 +++++++++++++++++++++------------------
1 files changed, 43 insertions(+), 36 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
index 0f87cb2..7360a9a 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -1,8 +1,10 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
#include <fstream>
#include "precomp.h"
-//#include "glog/logging.h"
-
void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
float vad_speech_noise_thres) {
@@ -34,10 +36,10 @@
vad_session_ = std::make_shared<Ort::Session>(
env_, vad_model.c_str(), session_options_);
} catch (std::exception const &e) {
- //LOG(ERROR) << "Error when load onnx model: " << e.what();
+ LOG(ERROR) << "Error when load vad onnx model: " << e.what();
exit(0);
}
- //LOG(INFO) << "vad onnx:";
+ LOG(INFO) << "vad onnx:";
GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_);
}
@@ -59,8 +61,8 @@
shape << j;
shape << " ";
}
- // LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type
- // << " dims=" << shape.str();
+ LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type
+ << " dims=" << shape.str();
(*in_names)[i] = name.get();
name.release();
}
@@ -78,8 +80,8 @@
shape << j;
shape << " ";
}
- // LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type
- // << " dims=" << shape.str();
+ LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type
+ << " dims=" << shape.str();
(*out_names)[i] = name.get();
name.release();
}
@@ -119,12 +121,12 @@
// 4. Onnx infer
std::vector<Ort::Value> vad_ort_outputs;
try {
- // VLOG(3) << "Start infer";
+ VLOG(3) << "Start infer";
vad_ort_outputs = vad_session_->Run(
Ort::RunOptions{nullptr}, vad_in_names_.data(), vad_inputs.data(),
vad_inputs.size(), vad_out_names_.data(), vad_out_names_.size());
} catch (std::exception const &e) {
- // LOG(ERROR) << e.what();
+ LOG(ERROR) << e.what();
return;
}
@@ -163,36 +165,41 @@
void FsmnVad::LoadCmvn(const char *filename)
{
- using namespace std;
- ifstream cmvn_stream(filename);
- string line;
+ try{
+ using namespace std;
+ ifstream cmvn_stream(filename);
+ string line;
- while (getline(cmvn_stream, line)) {
- istringstream iss(line);
- vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
- if (line_item[0] == "<AddShift>") {
- getline(cmvn_stream, line);
- istringstream means_lines_stream(line);
- vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
- if (means_lines[0] == "<LearnRateCoef>") {
- for (int j = 3; j < means_lines.size() - 1; j++) {
- means_list.push_back(stof(means_lines[j]));
+ while (getline(cmvn_stream, line)) {
+ istringstream iss(line);
+ vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
+ if (line_item[0] == "<AddShift>") {
+ getline(cmvn_stream, line);
+ istringstream means_lines_stream(line);
+ vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
+ if (means_lines[0] == "<LearnRateCoef>") {
+ for (int j = 3; j < means_lines.size() - 1; j++) {
+ means_list.push_back(stof(means_lines[j]));
+ }
+ continue;
}
- continue;
+ }
+ else if (line_item[0] == "<Rescale>") {
+ getline(cmvn_stream, line);
+ istringstream vars_lines_stream(line);
+ vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
+ if (vars_lines[0] == "<LearnRateCoef>") {
+ for (int j = 3; j < vars_lines.size() - 1; j++) {
+ // vars_list.push_back(stof(vars_lines[j])*scale);
+ vars_list.push_back(stof(vars_lines[j]));
+ }
+ continue;
+ }
}
}
- else if (line_item[0] == "<Rescale>") {
- getline(cmvn_stream, line);
- istringstream vars_lines_stream(line);
- vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
- if (vars_lines[0] == "<LearnRateCoef>") {
- for (int j = 3; j < vars_lines.size() - 1; j++) {
- // vars_list.push_back(stof(vars_lines[j])*scale);
- vars_list.push_back(stof(vars_lines[j]));
- }
- continue;
- }
- }
+ }catch(std::exception const &e) {
+ LOG(ERROR) << "Error when load vad cmvn : " << e.what();
+ exit(0);
}
}
--
Gitblit v1.9.1