From 2304327a3fde3aeca144fcc32bcd9e1905ade46a Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期三, 19 四月 2023 11:33:08 +0800
Subject: [PATCH] fix some variables

---
 funasr/runtime/onnxruntime/src/paraformer_onnx.cpp |  223 ++++++++++++++++++++++++++++++++++++-------------------
 1 files changed, 144 insertions(+), 79 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 695e0f7..a0dd6d4 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -18,10 +18,16 @@
     cmvn_path = pathAppend(path, "am.mvn");
     config_path = pathAppend(path, "config.yaml");
 
-    fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
-    fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
-    memset(fft_input, 0, sizeof(float) * fft_size);
-    plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
+    // knf options
+    fbank_opts.frame_opts.dither = 0;
+    fbank_opts.mel_opts.num_bins = 80;
+    fbank_opts.frame_opts.samp_freq = model_sample_rate;
+    fbank_opts.frame_opts.window_type = "hamming";
+    fbank_opts.frame_opts.frame_shift_ms = 10;
+    fbank_opts.frame_opts.frame_length_ms = 25;
+    fbank_opts.energy_floor = 0;
+    fbank_opts.mel_opts.debug_mel = false;
+    //fbank_ = std::make_unique<knf::OnlineFbank>(fbank_opts);
 
     //sessionOptions.SetInterOpNumThreads(1);
     sessionOptions.SetIntraOpNumThreads(nNumThread);
@@ -57,38 +63,28 @@
 {
     if(vocab)
         delete vocab;
-    fftwf_free(fft_input);
-    fftwf_free(fft_out);
-    fftwf_destroy_plan(plan);
-    fftwf_cleanup();
 }
 
 void ModelImp::reset()
 {
 }
 
-void ModelImp::apply_lfr(Tensor<float>*& din)
-{
-    int mm = din->size[2];
-    int ll = ceil(mm / 6.0);
-    Tensor<float>* tmp = new Tensor<float>(ll, 560);
-    int out_offset = 0;
-    for (int i = 0; i < ll; i++) {
-        for (int j = 0; j < 7; j++) {
-            int idx = i * 6 + j - 3;
-            if (idx < 0) {
-                idx = 0;
-            }
-            if (idx >= mm) {
-                idx = mm - 1;
-            }
-            memcpy(tmp->buff + out_offset, din->buff + idx * 80,
-                sizeof(float) * 80);
-            out_offset += 80;
-        }
+vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) {
+    knf::OnlineFbank fbank_(fbank_opts);
+    fbank_.AcceptWaveform(sample_rate, waves, len);
+    //fbank_->InputFinished();
+    int32_t frames = fbank_.NumFramesReady();
+    int32_t feature_dim = fbank_opts.mel_opts.num_bins;
+    vector<float> features(frames * feature_dim);
+    float *p = features.data();
+
+    for (int32_t i = 0; i != frames; ++i) {
+        const float *f = fbank_.GetFrame(i);
+        std::copy(f, f + feature_dim, p);
+        p += feature_dim;
     }
-    delete din;
-    din = tmp;
+
+    return features;
 }
 
 void ModelImp::load_cmvn(const char *filename)
@@ -124,24 +120,6 @@
     }
 }
 
-void ModelImp::apply_cmvn(Tensor<float>* din)
-{
-    const float* var;
-    const float* mean;
-    var = vars_list.data();
-    mean= means_list.data();
-
-    int m = din->size[2];
-    int n = din->size[3];
-
-    for (int i = 0; i < m; i++) {
-        for (int j = 0; j < n; j++) {
-            int idx = i * n + j;
-            din->buff[idx] = (din->buff[idx] + mean[j]) * var[j];
-        }
-    }
-}
-
 string ModelImp::greedy_search(float * in, int nLen )
 {
     vector<int> hyps;
@@ -156,16 +134,115 @@
     return vocab->vector2stringV2(hyps);
 }
 
+vector<float> ModelImp::ApplyLFR(const std::vector<float> &in) 
+{
+    int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
+    int32_t in_num_frames = in.size() / in_feat_dim;
+    int32_t out_num_frames =
+        (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
+    int32_t out_feat_dim = in_feat_dim * lfr_window_size;
+
+    std::vector<float> out(out_num_frames * out_feat_dim);
+
+    const float *p_in = in.data();
+    float *p_out = out.data();
+
+    for (int32_t i = 0; i != out_num_frames; ++i) {
+      std::copy(p_in, p_in + out_feat_dim, p_out);
+
+      p_out += out_feat_dim;
+      p_in += lfr_window_shift * in_feat_dim;
+    }
+
+    return out;
+  }
+
+  void ModelImp::ApplyCMVN(std::vector<float> *v)
+  {
+    int32_t dim = means_list.size();
+    int32_t num_frames = v->size() / dim;
+
+    float *p = v->data();
+
+    for (int32_t i = 0; i != num_frames; ++i) {
+      for (int32_t k = 0; k != dim; ++k) {
+        p[k] = (p[k] + means_list[k]) * vars_list[k];
+      }
+
+      p += dim;
+    }
+  }
+
+//   void ParaformerOnnxAsrModel::ForwardFunc(
+//     const std::vector<std::vector<float>>& chunk_feats,
+//     std::vector<std::vector<float>>* out_prob) {
+//   Ort::MemoryInfo memory_info =
+//       Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+//   // 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
+//   // chunk
+// //  int num_frames = cached_feature_.size() + chunk_feats.size();
+//   int num_frames = chunk_feats.size();
+//   const int feature_dim = chunk_feats[0].size();
+
+//   //  2. Generate 2 input nodes tensor
+//   // speech node { batch,frame number,feature dim }
+//   const int64_t paraformer_feats_shape[3] = {1, num_frames, feature_dim};
+//   std::vector<float> paraformer_feats;
+//   for (const auto & chunk_feat : chunk_feats) {
+//     paraformer_feats.insert(paraformer_feats.end(), chunk_feat.begin(), chunk_feat.end());
+//   }
+//   Ort::Value paraformer_feats_ort = Ort::Value::CreateTensor<float>(
+//           memory_info, paraformer_feats.data(), paraformer_feats.size(), paraformer_feats_shape, 3);
+//   // speech_lengths node {speech length,}
+//   const int64_t paraformer_length_shape[1] = {1};
+//   std::vector<int32_t> paraformer_length;
+//   paraformer_length.emplace_back(num_frames);
+//   Ort::Value paraformer_length_ort = Ort::Value::CreateTensor<int32_t>(
+//           memory_info, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);
+
+//   // 3. Put nodes into onnx input vector
+//   std::vector<Ort::Value> paraformer_inputs;
+//   paraformer_inputs.emplace_back(std::move(paraformer_feats_ort));
+//   paraformer_inputs.emplace_back(std::move(paraformer_length_ort));
+
+//   // 4. Onnx infer
+//   std::vector<Ort::Value> paraformer_ort_outputs;
+//   try{
+//     VLOG(3) << "Start infer";
+//     paraformer_ort_outputs = paraformer_session_->Run(
+//             Ort::RunOptions{nullptr}, paraformer_in_names_.data(), paraformer_inputs.data(),
+//             paraformer_inputs.size(), paraformer_out_names_.data(), paraformer_out_names_.size());
+//   }catch (std::exception const& e) {
+//     //  Catch "Non-zero status code returned error",usually because there is no asr result.
+//     // Need funasr to resolve.
+//     LOG(ERROR) << e.what();
+//     return;
+//   }
+
+//   // 5. Change infer result to output shapes
+//   float* logp_data = paraformer_ort_outputs[0].GetTensorMutableData<float>();
+//   auto type_info = paraformer_ort_outputs[0].GetTensorTypeAndShapeInfo();
+
+//   int num_outputs = type_info.GetShape()[1];
+//   int output_dim = type_info.GetShape()[2];
+//   out_prob->resize(num_outputs);
+//   for (int i = 0; i < num_outputs; i++) {
+//     (*out_prob)[i].resize(output_dim);
+//     memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
+//            sizeof(float) * output_dim);
+//   }
+// }
+
 string ModelImp::forward(float* din, int len, int flag)
 {
-    Tensor<float>* in;
-    FeatureExtract* fe = new FeatureExtract(3);
-    fe->reset();
-    fe->insert(plan, din, len, flag);
-    fe->fetch(in);
-    apply_lfr(in);
-    apply_cmvn(in);
-    Ort::RunOptions run_option;
+
+    int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
+    std::vector<float> wav_feats = FbankKaldi(model_sample_rate, din, len);
+    wav_feats = ApplyLFR(wav_feats);
+    ApplyCMVN(&wav_feats);
+
+    int32_t feat_dim = lfr_window_size*in_feat_dim;
+    int32_t num_frames = wav_feats.size() / feat_dim;
 
 #ifdef _WIN_X86
         Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
@@ -173,29 +250,26 @@
         Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
 #endif
 
-    std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
+    const int64_t input_shape_[3] = {1, num_frames, feat_dim};
     Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
-        in->buff,
-        in->buff_size,
-        input_shape_.data(),
-        input_shape_.size());
+        wav_feats.data(),
+        wav_feats.size(),
+        input_shape_,
+        3);
 
-    std::vector<int32_t> feats_len{ in->size[2] };
-    std::vector<int64_t> feats_len_dim{ 1 };
-    Ort::Value onnx_feats_len = Ort::Value::CreateTensor(
-        m_memoryInfo,
-        feats_len.data(),
-        feats_len.size() * sizeof(int32_t),
-        feats_len_dim.data(),
-        feats_len_dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
+    const int64_t paraformer_length_shape[1] = {1};
+    std::vector<int32_t> paraformer_length;
+    paraformer_length.emplace_back(num_frames);
+    Ort::Value onnx_feats_len = Ort::Value::CreateTensor<int32_t>(
+          m_memoryInfo, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);
+    
     std::vector<Ort::Value> input_onnx;
     input_onnx.emplace_back(std::move(onnx_feats));
     input_onnx.emplace_back(std::move(onnx_feats_len));
 
     string result;
     try {
-
-        auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
+        auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
         std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
 
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
@@ -206,15 +280,6 @@
     catch (...)
     {
         result = "";
-    }
-
-    if(in){
-        delete in;
-        in = nullptr;
-    }
-    if(fe){
-        delete fe;
-        fe = nullptr;
     }
 
     return result;

--
Gitblit v1.9.1