hnluo
2023-04-17 24f73665e2d8ea8e4de2fe4f900bc539d7f7b989
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -4,18 +4,24 @@
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
{
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
    string model_path;
    string vocab_path;
    string cmvn_path;
    string config_path;
    if(quantize)
    {
        model_path = pathAppend(path, "model_quant.onnx");
    }else{
        model_path = pathAppend(path, "model.onnx");
    }
    vocab_path = pathAppend(path, "vocab.txt");
    cmvn_path = pathAppend(path, "am.mvn");
    config_path = pathAppend(path, "config.yaml");
    fe = new FeatureExtract(3);
    fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
    fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
    memset(fft_input, 0, sizeof(float) * fft_size);
    plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
    //sessionOptions.SetInterOpNumThreads(1);
    sessionOptions.SetIntraOpNumThreads(nNumThread);
@@ -23,45 +29,42 @@
#ifdef _WIN32
    wstring wstrPath = strToWstr(model_path);
    m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#else
    m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
    string strName;
    getInputName(m_session, strName);
    getInputName(m_session.get(), strName);
    m_strInputNames.push_back(strName.c_str());
    getInputName(m_session, strName,1);
    getInputName(m_session.get(), strName,1);
    m_strInputNames.push_back(strName);
    
    getOutputName(m_session, strName);
    getOutputName(m_session.get(), strName);
    m_strOutputNames.push_back(strName);
    getOutputName(m_session, strName,1);
    getOutputName(m_session.get(), strName,1);
    m_strOutputNames.push_back(strName);
    for (auto& item : m_strInputNames)
        m_szInputNames.push_back(item.c_str());
    for (auto& item : m_strOutputNames)
        m_szOutputNames.push_back(item.c_str());
    vocab = new Vocab(vocab_path.c_str());
    vocab = new Vocab(config_path.c_str());
    load_cmvn(cmvn_path.c_str());
}
ModelImp::~ModelImp()
{
    if(fe)
        delete fe;
    if (m_session)
    {
        delete m_session;
        m_session = nullptr;
    }
    if(vocab)
        delete vocab;
    fftwf_free(fft_input);
    fftwf_free(fft_out);
    fftwf_destroy_plan(plan);
    fftwf_cleanup();
}
void ModelImp::reset()
{
    fe->reset();
}
void ModelImp::apply_lfr(Tensor<float>*& din)
@@ -88,16 +91,49 @@
    din = tmp;
}
void ModelImp::load_cmvn(const char *filename)
{
    ifstream cmvn_stream(filename);
    string line;
    while (getline(cmvn_stream, line)) {
        istringstream iss(line);
        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
        if (line_item[0] == "<AddShift>") {
            getline(cmvn_stream, line);
            istringstream means_lines_stream(line);
            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
            if (means_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < means_lines.size() - 1; j++) {
                    means_list.push_back(stof(means_lines[j]));
                }
                continue;
            }
        }
        else if (line_item[0] == "<Rescale>") {
            getline(cmvn_stream, line);
            istringstream vars_lines_stream(line);
            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
            if (vars_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < vars_lines.size() - 1; j++) {
                    vars_list.push_back(stof(vars_lines[j])*scale);
                }
                continue;
            }
        }
    }
}
void ModelImp::apply_cmvn(Tensor<float>* din)
{
    const float* var;
    const float* mean;
    float scale = 22.6274169979695;
    var = vars_list.data();
    mean= means_list.data();
    int m = din->size[2];
    int n = din->size[3];
    var = (const float*)paraformer_cmvn_var_hex;
    mean = (const float*)paraformer_cmvn_mean_hex;
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            int idx = i * n + j;
@@ -122,13 +158,20 @@
string ModelImp::forward(float* din, int len, int flag)
{
    Tensor<float>* in;
    fe->insert(din, len, flag);
    FeatureExtract* fe = new FeatureExtract(3);
    fe->reset();
    fe->insert(plan, din, len, flag);
    fe->fetch(in);
    apply_lfr(in);
    apply_cmvn(in);
    Ort::RunOptions run_option;
#ifdef _WIN_X86
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif
    std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
    Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
@@ -155,7 +198,6 @@
        auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
        std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
        int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
        float* floatData = outputTensor[0].GetTensorMutableData<float>();
        auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
@@ -166,9 +208,14 @@
        result = "";
    }
    if(in)
    if(in){
        delete in;
        in = nullptr;
    }
    if(fe){
        delete fe;
        fe = nullptr;
    }
    return result;
}