From 55708e7cebaedefc5f69d61f157993da41848b8f Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期日, 23 四月 2023 19:06:25 +0800
Subject: [PATCH] add offline punc for onnxruntime

---
 funasr/runtime/onnxruntime/src/paraformer_onnx.cpp |   20 +++++++++++++++-----
 1 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 1e4a310..69d1554 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -14,7 +14,12 @@
         string vad_path = pathAppend(path, "vad_model.onnx");
         string mvn_path = pathAppend(path, "vad.mvn");
         vadHandle = make_unique<FsmnVad>();
-        vadHandle->init_vad(vad_path, mvn_path, model_sample_rate, 800, 15000, 0.9);
+        vadHandle->init_vad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
+    }
+
+    // PUNC model
+    if(true){
+        puncHandle = make_unique<CTTransformer>(path, nNumThread);
     }
 
     if(quantize)
@@ -29,7 +34,7 @@
     // knf options
     fbank_opts.frame_opts.dither = 0;
     fbank_opts.mel_opts.num_bins = 80;
-    fbank_opts.frame_opts.samp_freq = model_sample_rate;
+    fbank_opts.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
     fbank_opts.frame_opts.window_type = "hamming";
     fbank_opts.frame_opts.frame_shift_ms = 10;
     fbank_opts.frame_opts.frame_length_ms = 25;
@@ -50,6 +55,7 @@
     m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
 #endif
 
+    vector<string> m_strInputNames, m_strOutputNames;
     string strName;
     getInputName(m_session.get(), strName);
     m_strInputNames.push_back(strName.c_str());
@@ -81,6 +87,10 @@
 
 vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){
     return vadHandle->infer(pcm_data);
+}
+
+string ModelImp::AddPunc(const char* szInput){
+    return puncHandle->AddPunc(szInput);
 }
 
 vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) {
@@ -191,7 +201,7 @@
 {
 
     int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
-    std::vector<float> wav_feats = FbankKaldi(model_sample_rate, din, len);
+    std::vector<float> wav_feats = FbankKaldi(MODEL_SAMPLE_RATE, din, len);
     wav_feats = ApplyLFR(wav_feats);
     ApplyCMVN(&wav_feats);
 
@@ -231,9 +241,9 @@
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
         result = greedy_search(floatData, *encoder_out_lens);
     }
-    catch (...)
+    catch (std::exception const &e)
     {
-        result = "";
+        printf(e.what());
     }
 
     return result;

--
Gitblit v1.9.1