lyblsgo
2023-04-25 1d205d340ff5129e457fa462eb5b31b152086339
funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -1,36 +1,70 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#include "precomp.h"
using namespace std;
using namespace paraformer;
// Paraformer constructor.
//
// NOTE(review): this span interleaves two revisions of the constructor, so it
// is not compilable as written:
//   * a directory-based revision, taking `path` plus the `quantize`,
//     `use_vad` and `use_punc` flags and deriving every file name with
//     PathAppend(path, ...);
//   * a map-based revision, taking `model_path` (a string-to-string map) and
//     enabling each sub-model only when its key is present in the map.
// The comments below mark which revision each group of lines appears to
// belong to, judged by the parameters those lines reference.
//
// Both signatures share the initializer list that builds `env_` with
// ORT_LOGGING_LEVEL_ERROR and value-initializes `session_options`.
Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc)
Paraformer::Paraformer(std::map<std::string, std::string>& model_path,int thread_num)
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
    // Directory-based revision: locals that later receive the AM model, CMVN
    // and config file names (the local `model_path` would clash with the map
    // parameter of the other signature).
    string model_path;
    string cmvn_path;
    string config_path;
    // VAD model
    // Directory-based revision: VAD files are looked up under `path`.
    if(use_vad){
        string vad_path = PathAppend(path, "vad_model.onnx");
        string mvn_path = PathAppend(path, "vad.mvn");
    // Map-based revision: VAD is enabled only when VAD_MODEL_PATH is a key of
    // the map. `use_vad` is not a parameter of that signature - presumably a
    // class member; confirm against the header.
    if(model_path.find(VAD_MODEL_PATH) != model_path.end()){
        use_vad = true;
        string vad_model_path;
        string vad_cmvn_path;
        try{
            // map::at throws out_of_range when the key is absent.
            vad_model_path = model_path.at(VAD_MODEL_PATH);
            vad_cmvn_path = model_path.at(VAD_CMVN_PATH);
        }catch(const out_of_range& e){
            LOG(ERROR) << "Error when read "<< VAD_CMVN_PATH <<" :" << e.what();
            // NOTE(review): exit status 0 signals success to the caller even
            // though this is a failure path - confirm this is intended.
            exit(0);
        }
        vad_handle = make_unique<FsmnVad>();
        // The next two calls are the directory-based and the map-based form of
        // the same InitVad call; they differ only in the two path arguments.
        vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
        vad_handle->InitVad(vad_model_path, vad_cmvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
    }
    // AM model
    // Map-based revision: the acoustic model is initialized only when
    // AM_MODEL_PATH is a key of the map.
    // NOTE(review): no else-branch is visible, so a map without AM_MODEL_PATH
    // appears to leave the acoustic model uninitialized without any log.
    if(model_path.find(AM_MODEL_PATH) != model_path.end()){
        string am_model_path;
        string am_cmvn_path;
        string am_config_path;
        try{
            am_model_path = model_path.at(AM_MODEL_PATH);
            am_cmvn_path = model_path.at(AM_CMVN_PATH);
            am_config_path = model_path.at(AM_CONFIG_PATH);
        }catch(const out_of_range& e){
            LOG(ERROR) << "Error when read "<< AM_CONFIG_PATH << " or " << AM_CMVN_PATH <<" :" << e.what();
            // NOTE(review): exits with status 0 on failure (see VAD branch).
            exit(0);
        }
        // Session creation, vocabulary and CMVN loading are delegated to InitAM.
        InitAM(am_model_path, am_cmvn_path, am_config_path, thread_num);
    }
    // PUNC model
    // Directory-based revision: CTTransformer is constructed directly from
    // `path` and `thread_num`.
    if(use_punc){
        punc_handle = make_unique<CTTransformer>(path, thread_num);
    }
    // Map-based revision: punctuation is enabled only when PUNC_MODEL_PATH is
    // a key of the map. As with `use_vad`, `use_punc` is presumably a member
    // in this revision - confirm against the header.
    if(model_path.find(PUNC_MODEL_PATH) != model_path.end()){
        use_punc = true;
        string punc_model_path;
        string punc_config_path;
        try{
            punc_model_path = model_path.at(PUNC_MODEL_PATH);
            punc_config_path = model_path.at(PUNC_CONFIG_PATH);
        }catch(const out_of_range& e){
            LOG(ERROR) << "Error when read "<< PUNC_CONFIG_PATH <<" :" << e.what();
            // NOTE(review): exits with status 0 on failure (see VAD branch).
            exit(0);
        }
    // Directory-based revision: choose the quantized or the full-precision
    // AM model file under `path`.
    if(quantize)
    {
        model_path = PathAppend(path, "model_quant.onnx");
    }else{
        model_path = PathAppend(path, "model.onnx");
        // Map-based revision: default-construct the punctuation model, then
        // initialize it from the two paths read from the map.
        punc_handle = make_unique<CTTransformer>();
        punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
    }
    // Directory-based revision: remaining AM file names derived from `path`.
    cmvn_path = PathAppend(path, "am.mvn");
    config_path = PathAppend(path, "config.yaml");
}
// Initializes the acoustic model (AM).
//
// Visible steps: set the fbank options, configure the session options,
// create the ONNX Runtime session from `am_model`, collect input/output name
// pointers, build the vocabulary from `am_config` and load CMVN statistics
// from `am_cmvn`.
//
// @param am_model   path of the AM ONNX model file, passed to Ort::Session.
// @param am_cmvn    path handed to LoadCmvn.
// @param am_config  path handed to the Vocab constructor.
// @param thread_num not referenced in the lines visible here - presumably
//                   consumed in an elided part of the body; TODO confirm.
//
// NOTE(review): the `@@` lines below are hunk headers, so parts of this body
// are elided from this view, and the remaining lines interleave an older and
// a newer revision of the same statements.
void Paraformer::InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
    // knf options
    fbank_opts.frame_opts.dither = 0;
    fbank_opts.mel_opts.num_bins = 80;
@@ -48,12 +82,12 @@
    // DisableCpuMemArena can improve performance
    session_options.DisableCpuMemArena();
    // Older revision: platform-split session creation from `model_path`.
    // NOTE(review): the _WIN32 branch computes `wstrPath` but still passes
    // model_path.c_str() to Ort::Session, so both branches make the same call.
#ifdef _WIN32
    wstring wstrPath = strToWstr(model_path);
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
#else
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
#endif
    // Newer revision: one session-creation call from `am_model`, with any
    // std::exception logged before the process exits.
    // NOTE(review): there is no _WIN32 branch here; check that a narrow
    // char path is accepted by Ort::Session on Windows builds.
    try {
        m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load am onnx model: " << e.what();
        // NOTE(review): exit status 0 signals success although the model
        // failed to load - confirm this is intended.
        exit(0);
    }
    string strName;
    GetInputName(m_session.get(), strName);
@@ -70,8 +104,8 @@
        m_szInputNames.push_back(item.c_str());
    // Store C-string views of the output names; the pointers stay valid only
    // while the strings in m_strOutputNames are neither modified nor destroyed.
    for (auto& item : m_strOutputNames)
        m_szOutputNames.push_back(item.c_str());
    // Older revision: vocabulary and CMVN loaded from constructor-local paths.
    vocab = new Vocab(config_path.c_str());
    LoadCmvn(cmvn_path.c_str());
    // Newer revision: the same two steps using this function's parameters.
    // NOTE(review): `vocab` is a raw owning pointer; its release is not
    // visible in this view - presumably done in the destructor, confirm.
    vocab = new Vocab(am_config.c_str());
    LoadCmvn(am_cmvn.c_str());
}
Paraformer::~Paraformer()