From ade08818b7a579aac75182b906a5bd3b8126411c Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期一, 27 五月 2024 15:46:26 +0800
Subject: [PATCH] Merge branch 'dev_batch' into main

---
 runtime/onnxruntime/src/audio.cpp |  106 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 103 insertions(+), 3 deletions(-)

diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp
index 0135ab4..22a9ecd 100644
--- a/runtime/onnxruntime/src/audio.cpp
+++ b/runtime/onnxruntime/src/audio.cpp
@@ -1023,6 +1023,90 @@
     }
 }
 
+int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+    batch_in = std::min((int)frame_queue.size(), batch_size);
+    if (batch_in == 0){
+        return 0;
+    } else{
+        // init
+        dout = new float*[batch_in];
+        len = new int[batch_in];
+        flag = new int[batch_in];
+        start_time = new float[batch_in];
+
+        for(int idx=0; idx < batch_in; idx++){
+            AudioFrame *frame = frame_queue.front();
+            frame_queue.pop();
+
+            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+            dout[idx] = speech_data + frame->GetStart();
+            len[idx] = frame->GetLen();
+            delete frame;
+            flag[idx] = S_END;
+        }
+        return 1;
+    }
+}
+
+int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+    //compute batch size
+    queue<AudioFrame *> frame_batch;
+    int max_acc = 300*1000*seg_sample;
+    int max_sent = 60*1000*seg_sample;
+    int bs_acc = 0;
+    int max_len = 0;
+    int max_batch = 1;
+    #ifdef USE_GPU
+        max_batch = batch_size;
+    #endif
+    max_batch = std::min(max_batch, (int)frame_queue.size());
+
+    for(int idx=0; idx < max_batch; idx++){
+        AudioFrame *frame = frame_queue.front();
+        int length = frame->GetLen();
+        if(length >= max_sent){
+            if(bs_acc == 0){
+                bs_acc++;
+                frame_batch.push(frame);
+                frame_queue.pop();                
+            }
+            break;
+        }
+        max_len = std::max(max_len, frame->GetLen());
+        if(max_len*(bs_acc+1) > max_acc){
+            break;
+        }
+        bs_acc++;
+        frame_batch.push(frame);
+        frame_queue.pop();
+    }
+
+    batch_in = (int)frame_batch.size();
+    if (batch_in == 0){
+        return 0;
+    } else{
+        // init
+        dout = new float*[batch_in];
+        len = new int[batch_in];
+        flag = new int[batch_in];
+        start_time = new float[batch_in];
+
+        for(int idx=0; idx < batch_in; idx++){
+            AudioFrame *frame = frame_batch.front();
+            frame_batch.pop();
+
+            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+            dout[idx] = speech_data + frame->GetStart();
+            len[idx] = frame->GetLen();
+            delete frame;
+            flag[idx] = S_END;
+        }
+        return 1;
+    }
+}
+
 void Audio::Padding()
 {
     float num_samples = speech_len;
@@ -1085,7 +1169,7 @@
     }
 }
 
-void Audio::CutSplit(OfflineStream* offline_stream)
+void Audio::CutSplit(OfflineStream* offline_stream, std::vector<int> &index_vector)
 {
     std::unique_ptr<VadModel> vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(offline_stream->vad_handle).get());
     AudioFrame *frame;
@@ -1112,6 +1196,7 @@
     }    
 
     int speech_start_i = -1, speech_end_i =-1;
+    std::vector<AudioFrame*> vad_frames;
     for(vector<int> vad_segment:vad_segments)
     {
         if(vad_segment.size() != 2){
@@ -1126,16 +1211,31 @@
         }
 
         if(speech_start_i!=-1 && speech_end_i!=-1){
-            frame = new AudioFrame();
             int start = speech_start_i*seg_sample;
             int end = speech_end_i*seg_sample;
+            frame = new AudioFrame(end-start);
             frame->SetStart(start);
             frame->SetEnd(end);
-            frame_queue.push(frame);
+            vad_frames.push_back(frame);
             frame = nullptr;
             speech_start_i=-1;
             speech_end_i=-1;
         }
+
+    }
+    // sort
+    {
+        index_vector.clear();
+        index_vector.resize(vad_frames.size());
+        for (int i = 0; i < index_vector.size(); ++i) {
+            index_vector[i] = i;
+        }
+        std::sort(index_vector.begin(), index_vector.end(), [&vad_frames](const int a, const int b) {
+            return vad_frames[a]->len < vad_frames[b]->len;
+        });
+        for (int idx : index_vector) {
+            frame_queue.push(vad_frames[idx]);
+        }
     }
 }
 

--
Gitblit v1.9.1