lyblsgo
2023-04-25 1d205d340ff5129e457fa462eb5b31b152086339
funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -1,3 +1,7 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#include <fstream>
#include "precomp.h"
@@ -32,7 +36,7 @@
        vad_session_ = std::make_shared<Ort::Session>(
                env_, vad_model.c_str(), session_options_);
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load onnx model: " << e.what();
        LOG(ERROR) << "Error when load vad onnx model: " << e.what();
        exit(0);
    }
    LOG(INFO) << "vad onnx:";
@@ -161,36 +165,41 @@
void FsmnVad::LoadCmvn(const char *filename)
{
    using namespace std;
    ifstream cmvn_stream(filename);
    string line;
    try{
        using namespace std;
        ifstream cmvn_stream(filename);
        string line;
    while (getline(cmvn_stream, line)) {
        istringstream iss(line);
        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
        if (line_item[0] == "<AddShift>") {
            getline(cmvn_stream, line);
            istringstream means_lines_stream(line);
            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
            if (means_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < means_lines.size() - 1; j++) {
                    means_list.push_back(stof(means_lines[j]));
        while (getline(cmvn_stream, line)) {
            istringstream iss(line);
            vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
            if (line_item[0] == "<AddShift>") {
                getline(cmvn_stream, line);
                istringstream means_lines_stream(line);
                vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
                if (means_lines[0] == "<LearnRateCoef>") {
                    for (int j = 3; j < means_lines.size() - 1; j++) {
                        means_list.push_back(stof(means_lines[j]));
                    }
                    continue;
                }
                continue;
            }
            else if (line_item[0] == "<Rescale>") {
                getline(cmvn_stream, line);
                istringstream vars_lines_stream(line);
                vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
                if (vars_lines[0] == "<LearnRateCoef>") {
                    for (int j = 3; j < vars_lines.size() - 1; j++) {
                        // vars_list.push_back(stof(vars_lines[j])*scale);
                        vars_list.push_back(stof(vars_lines[j]));
                    }
                    continue;
                }
            }
        }
        else if (line_item[0] == "<Rescale>") {
            getline(cmvn_stream, line);
            istringstream vars_lines_stream(line);
            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
            if (vars_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < vars_lines.size() - 1; j++) {
                    // vars_list.push_back(stof(vars_lines[j])*scale);
                    vars_list.push_back(stof(vars_lines[j]));
                }
                continue;
            }
        }
    }catch(std::exception const &e) {
        LOG(ERROR) << "Error when load vad cmvn : " << e.what();
        exit(0);
    }
}