lyblsgo
2023-04-24 6406616c2d3fd5eae0201cd4680258f7c4e937dd
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -3,7 +3,7 @@
using namespace std;
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad)
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad, bool use_punc)
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
    string model_path;
    string cmvn_path;
@@ -14,7 +14,12 @@
        string vad_path = pathAppend(path, "vad_model.onnx");
        string mvn_path = pathAppend(path, "vad.mvn");
        vadHandle = make_unique<FsmnVad>();
        vadHandle->init_vad(vad_path, mvn_path, model_sample_rate, 800, 15000, 0.9);
        vadHandle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
    }
    // PUNC model
    if(use_punc){
        puncHandle = make_unique<CTTransformer>(path, nNumThread);
    }
    if(quantize)
@@ -29,7 +34,7 @@
    // knf options
    fbank_opts.frame_opts.dither = 0;
    fbank_opts.mel_opts.num_bins = 80;
    fbank_opts.frame_opts.samp_freq = model_sample_rate;
    fbank_opts.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
    fbank_opts.frame_opts.window_type = "hamming";
    fbank_opts.frame_opts.frame_shift_ms = 10;
    fbank_opts.frame_opts.frame_length_ms = 25;
@@ -80,7 +85,11 @@
}
vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){
    return vadHandle->infer(pcm_data);
    return vadHandle->Infer(pcm_data);
}
string ModelImp::AddPunc(const char* szInput){
    return puncHandle->AddPunc(szInput);
}
vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) {
@@ -191,7 +200,7 @@
{
    int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
    std::vector<float> wav_feats = FbankKaldi(model_sample_rate, din, len);
    std::vector<float> wav_feats = FbankKaldi(MODEL_SAMPLE_RATE, din, len);
    wav_feats = ApplyLFR(wav_feats);
    ApplyCMVN(&wav_feats);
@@ -231,9 +240,9 @@
        auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
        result = greedy_search(floatData, *encoder_out_lens);
    }
    catch (...)
    catch (std::exception const &e)
    {
        result = "";
        printf(e.what());
    }
    return result;