From 63a70a00f7c9f162e8d7b3e330516438fb8cd87b Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 19 四月 2024 16:08:05 +0800
Subject: [PATCH] add dynamic batch for fetch

---
 runtime/onnxruntime/src/audio.cpp         |   57 ++++++++++++++++++++++++++++
 runtime/onnxruntime/src/funasrruntime.cpp |    4 +-
 runtime/onnxruntime/include/audio.h       |    1 
 3 files changed, 60 insertions(+), 2 deletions(-)

diff --git a/runtime/onnxruntime/include/audio.h b/runtime/onnxruntime/include/audio.h
index ba5f9e0..3011050 100644
--- a/runtime/onnxruntime/include/audio.h
+++ b/runtime/onnxruntime/include/audio.h
@@ -84,6 +84,7 @@
     int Fetch(float *&dout, int &len, int &flag);
     int Fetch(float *&dout, int &len, int &flag, float &start_time);
     int Fetch(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
+    int FetchDynamic(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
     void Padding();
     void Split(OfflineStream* offline_streamj);
     void CutSplit(OfflineStream* offline_streamj, std::vector<int> &index_vector);
diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp
index f1285af..a5a44ca 100644
--- a/runtime/onnxruntime/src/audio.cpp
+++ b/runtime/onnxruntime/src/audio.cpp
@@ -1049,6 +1049,63 @@
     }
 }
 
+int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+    //compute batch size
+    queue<AudioFrame *> frame_batch;
+    int max_acc = 300*1000*seg_sample;
+    int max_sent = 60*1000*seg_sample;
+    int bs_acc = 0;
+    int max_len = 0;
+    int max_batch = 1;
+    #ifdef USE_GPU
+        max_batch = batch_size;
+    #endif
+
+    for(int idx=0; idx < std::min(max_batch, (int)frame_queue.size()); idx++){
+        AudioFrame *frame = frame_queue.front();
+        int length = frame->GetLen();
+        if(length >= max_sent){
+            if(bs_acc == 0){
+                bs_acc++;
+                frame_batch.push(frame);
+                frame_queue.pop();                
+            }
+            break;
+        }
+        max_len = std::max(max_len, frame->GetLen());
+        if(max_len*(bs_acc+1) > max_acc){
+            break;
+        }
+        bs_acc++;
+        frame_batch.push(frame);
+        frame_queue.pop();
+    }
+
+    batch_in = (int)frame_batch.size();
+    if (batch_in == 0){
+        return 0;
+    } else{
+        // init
+        dout = new float*[batch_in];
+        len = new int[batch_in];
+        flag = new int[batch_in];
+        start_time = new float[batch_in];
+
+        for(int idx=0; idx < batch_in; idx++){
+            AudioFrame *frame = frame_batch.front();
+            frame_batch.pop();
+
+            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+            dout[idx] = speech_data + frame->GetStart();
+            len[idx] = frame->GetLen();
+            delete frame;
+            flag[idx] = S_END;
+        }
+        return 1;
+    }
+}
+
 void Audio::Padding()
 {
     float num_samples = speech_len;
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 3323b10..d235e6f 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -250,7 +250,7 @@
 
 		std::string cur_stamp = "[";
 		std::string lang = (offline_stream->asr_handle)->GetLang();
-		while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
+		while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
 			// dec reset
 			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
 			if (wfst_decoder){
@@ -372,7 +372,7 @@
 
 		std::string cur_stamp = "[";
 		std::string lang = (offline_stream->asr_handle)->GetLang();
-		while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
+		while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
 			// dec reset
 			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
 			if (wfst_decoder){

--
Gitblit v1.9.1