From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/onnxruntime/src/paraformer-online.cpp |   52 ++++++++++++++++++++++++++++------------------------
 1 files changed, 28 insertions(+), 24 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/paraformer-online.cpp b/funasr/runtime/onnxruntime/src/paraformer-online.cpp
index 1787f02..ed7a35a 100644
--- a/funasr/runtime/onnxruntime/src/paraformer-online.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer-online.cpp
@@ -101,35 +101,39 @@
         waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
         }
         if (lfr_splice_cache_.empty()) {
-        for (int i = 0; i < (lfr_m - 1) / 2; i++) {
-            lfr_splice_cache_.emplace_back(wav_feats[0]);
-        }
+            for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+                lfr_splice_cache_.emplace_back(wav_feats[0]);
+            }
         }
         if (wav_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
-        wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
-        int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
-        int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
-        int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
-        int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
-        reserve_waveforms_.clear();
-        reserve_waveforms_.insert(reserve_waveforms_.begin(),
-                                    waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
-                                    waves.begin() + frame_from_waves * frame_shift_sample_length_);
-        int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
-        waves.erase(waves.begin() + sample_length, waves.end());
+            wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
+            int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
+            int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
+            int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
+            int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
+            reserve_waveforms_.clear();
+            reserve_waveforms_.insert(reserve_waveforms_.begin(),
+                                        waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
+                                        waves.begin() + frame_from_waves * frame_shift_sample_length_);
+            int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
+            waves.erase(waves.begin() + sample_length, waves.end());
         } else {
-        reserve_waveforms_.clear();
-        reserve_waveforms_.insert(reserve_waveforms_.begin(),
-                                    waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
-        lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
+            reserve_waveforms_.clear();
+            reserve_waveforms_.insert(reserve_waveforms_.begin(),
+                                        waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
+            lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
         }
     } else {
         if (input_finished) {
-        if (!reserve_waveforms_.empty()) {
-            waves = reserve_waveforms_;
-        }
-        wav_feats = lfr_splice_cache_;
-        OnlineLfrCmvn(wav_feats, input_finished);
+            if (!reserve_waveforms_.empty()) {
+                waves = reserve_waveforms_;
+            }
+            wav_feats = lfr_splice_cache_;
+            if(wav_feats.size() == 0){
+                LOG(ERROR) << "wav_feats's size is 0";
+            }else{
+                OnlineLfrCmvn(wav_feats, input_finished);
+            }
         }
     }
     if(input_finished){
@@ -465,7 +469,7 @@
     return result;
 }
 
-string ParaformerOnline::Forward(float* din, int len, bool input_finished)
+string ParaformerOnline::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb)
 {
     std::vector<std::vector<float>> wav_feats;
     std::vector<float> waves(din, din+len);

--
Gitblit v1.9.1