jmwang66
2023-05-09 8dab6d184a034ca86eafa644ea0d2100aadfe27d
funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -1,36 +1,19 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#include "precomp.h"
using namespace std;
using namespace paraformer;
Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc)
namespace funasr {
Paraformer::Paraformer()
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
    string model_path;
    string cmvn_path;
    string config_path;
}
    // VAD model
    if(use_vad){
        string vad_path = PathAppend(path, "vad_model.onnx");
        string mvn_path = PathAppend(path, "vad.mvn");
        vad_handle = make_unique<FsmnVad>();
        vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
    }
    // PUNC model
    if(use_punc){
        punc_handle = make_unique<CTTransformer>(path, thread_num);
    }
    if(quantize)
    {
        model_path = PathAppend(path, "model_quant.onnx");
    }else{
        model_path = PathAppend(path, "model.onnx");
    }
    cmvn_path = PathAppend(path, "am.mvn");
    config_path = PathAppend(path, "config.yaml");
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
    // knf options
    fbank_opts.frame_opts.dither = 0;
    fbank_opts.mel_opts.num_bins = 80;
@@ -48,12 +31,12 @@
    // DisableCpuMemArena can improve performance
    session_options.DisableCpuMemArena();
#ifdef _WIN32
    wstring wstrPath = strToWstr(model_path);
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
#else
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
#endif
    try {
        m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load am onnx model: " << e.what();
        exit(0);
    }
    string strName;
    GetInputName(m_session.get(), strName);
@@ -70,8 +53,8 @@
        m_szInputNames.push_back(item.c_str());
    for (auto& item : m_strOutputNames)
        m_szOutputNames.push_back(item.c_str());
    vocab = new Vocab(config_path.c_str());
    LoadCmvn(cmvn_path.c_str());
    vocab = new Vocab(am_config.c_str());
    LoadCmvn(am_cmvn.c_str());
}
Paraformer::~Paraformer()
@@ -82,14 +65,6 @@
void Paraformer::Reset()
{
}
vector<std::vector<int>> Paraformer::VadSeg(std::vector<float>& pcm_data){
    return vad_handle->Infer(pcm_data);
}
string Paraformer::AddPunc(const char* sz_input){
    return punc_handle->AddPunc(sz_input);
}
vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
@@ -113,6 +88,10 @@
void Paraformer::LoadCmvn(const char *filename)
{
    ifstream cmvn_stream(filename);
    if (!cmvn_stream.is_open()) {
        LOG(ERROR) << "Failed to open file: " << filename;
        exit(0);
    }
    string line;
    while (getline(cmvn_stream, line)) {
@@ -143,14 +122,14 @@
    }
}
string Paraformer::GreedySearch(float * in, int n_len )
string Paraformer::GreedySearch(float * in, int n_len,  int64_t token_nums)
{
    vector<int> hyps;
    int Tmax = n_len;
    for (int i = 0; i < Tmax; i++) {
        int max_idx;
        float max_val;
        FindMax(in + i * 8404, 8404, max_val, max_idx);
        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
        hyps.push_back(max_idx);
    }
@@ -238,11 +217,11 @@
        int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
        float* floatData = outputTensor[0].GetTensorMutableData<float>();
        auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
        result = GreedySearch(floatData, *encoder_out_lens);
        result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
    }
    catch (std::exception const &e)
    {
        printf(e.what());
        LOG(ERROR)<<e.what();
    }
    return result;
@@ -251,12 +230,13 @@
string Paraformer::ForwardChunk(float* din, int len, int flag)
{
    printf("Not Imp!!!!!!\n");
    return "Hello";
    LOG(ERROR)<<"Not Imp!!!!!!";
    return "";
}
string Paraformer::Rescoring()
{
    printf("Not Imp!!!!!!\n");
    return "Hello";
    LOG(ERROR)<<"Not Imp!!!!!!";
    return "";
}
} // namespace funasr