From bc723ea200144bd6fa8a5dff4b9a780feda144fc Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 29 Jun 2023 18:55:01 +0800
Subject: [PATCH] docs

---
 funasr/runtime/onnxruntime/src/paraformer.cpp |   89 ++++++++++++++++++--------------------------
 1 files changed, 37 insertions(+), 52 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index 493dd6d..b605fff 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -1,36 +1,19 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
+
 #include "precomp.h"
 
 using namespace std;
-using namespace paraformer;
 
-Paraformer::Paraformer(const char* path,int thread_num, bool quantize, bool use_vad, bool use_punc)
+namespace funasr {
+
+Paraformer::Paraformer()
 :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
-    string model_path;
-    string cmvn_path;
-    string config_path;
+}
 
-    // VAD model
-    if(use_vad){
-        string vad_path = PathAppend(path, "vad_model.onnx");
-        string mvn_path = PathAppend(path, "vad.mvn");
-        vad_handle = make_unique<FsmnVad>();
-        vad_handle->InitVad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
-    }
-
-    // PUNC model
-    if(use_punc){
-        punc_handle = make_unique<CTTransformer>(path, thread_num);
-    }
-
-    if(quantize)
-    {
-        model_path = PathAppend(path, "model_quant.onnx");
-    }else{
-        model_path = PathAppend(path, "model.onnx");
-    }
-    cmvn_path = PathAppend(path, "am.mvn");
-    config_path = PathAppend(path, "config.yaml");
-
+void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
     // knf options
     fbank_opts.frame_opts.dither = 0;
     fbank_opts.mel_opts.num_bins = 80;
@@ -48,12 +31,13 @@
     // DisableCpuMemArena can improve performance
     session_options.DisableCpuMemArena();
 
-#ifdef _WIN32
-    wstring wstrPath = strToWstr(model_path);
-    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
-#else
-    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), session_options);
-#endif
+    try {
+        m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
+        LOG(INFO) << "Successfully load model from " << am_model;
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load am onnx model: " << e.what();
+        exit(0);
+    }
 
     string strName;
     GetInputName(m_session.get(), strName);
@@ -70,8 +54,8 @@
         m_szInputNames.push_back(item.c_str());
     for (auto& item : m_strOutputNames)
         m_szOutputNames.push_back(item.c_str());
-    vocab = new Vocab(config_path.c_str());
-    LoadCmvn(cmvn_path.c_str());
+    vocab = new Vocab(am_config.c_str());
+    LoadCmvn(am_cmvn.c_str());
 }
 
 Paraformer::~Paraformer()
@@ -84,17 +68,13 @@
 {
 }
 
-vector<std::vector<int>> Paraformer::VadSeg(std::vector<float>& pcm_data){
-    return vad_handle->Infer(pcm_data);
-}
-
-string Paraformer::AddPunc(const char* sz_input){
-    return punc_handle->AddPunc(sz_input);
-}
-
 vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
     knf::OnlineFbank fbank_(fbank_opts);
-    fbank_.AcceptWaveform(sample_rate, waves, len);
+    std::vector<float> buf(len);
+    for (int32_t i = 0; i != len; ++i) {
+        buf[i] = waves[i] * 32768;
+    }
+    fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
     //fbank_->InputFinished();
     int32_t frames = fbank_.NumFramesReady();
     int32_t feature_dim = fbank_opts.mel_opts.num_bins;
@@ -113,6 +93,10 @@
 void Paraformer::LoadCmvn(const char *filename)
 {
     ifstream cmvn_stream(filename);
+    if (!cmvn_stream.is_open()) {
+        LOG(ERROR) << "Failed to open file: " << filename;
+        exit(0);
+    }
     string line;
 
     while (getline(cmvn_stream, line)) {
@@ -143,14 +127,14 @@
     }
 }
 
-string Paraformer::GreedySearch(float * in, int n_len )
+string Paraformer::GreedySearch(float * in, int n_len,  int64_t token_nums)
 {
     vector<int> hyps;
     int Tmax = n_len;
     for (int i = 0; i < Tmax; i++) {
         int max_idx;
         float max_val;
-        FindMax(in + i * 8404, 8404, max_val, max_idx);
+        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
         hyps.push_back(max_idx);
     }
 
@@ -238,11 +222,11 @@
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         float* floatData = outputTensor[0].GetTensorMutableData<float>();
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
-        result = GreedySearch(floatData, *encoder_out_lens);
+        result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
     }
     catch (std::exception const &e)
     {
-        printf(e.what());
+        LOG(ERROR)<<e.what();
     }
 
     return result;
@@ -251,12 +235,13 @@
 string Paraformer::ForwardChunk(float* din, int len, int flag)
 {
 
-    printf("Not Imp!!!!!!\n");
-    return "Hello";
+    LOG(ERROR)<<"Not Imp!!!!!!";
+    return "";
 }
 
 string Paraformer::Rescoring()
 {
-    printf("Not Imp!!!!!!\n");
-    return "Hello";
+    LOG(ERROR)<<"Not Imp!!!!!!";
+    return "";
 }
+} // namespace funasr

--
Gitblit v1.9.1