From 63a70a00f7c9f162e8d7b3e330516438fb8cd87b Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 19 四月 2024 16:08:05 +0800
Subject: [PATCH] add dynamic batch for fetch
---
runtime/onnxruntime/src/audio.cpp | 57 ++++++++++++++++++++++++++++
runtime/onnxruntime/src/funasrruntime.cpp | 4 +-
runtime/onnxruntime/include/audio.h | 1
3 files changed, 60 insertions(+), 2 deletions(-)
diff --git a/runtime/onnxruntime/include/audio.h b/runtime/onnxruntime/include/audio.h
index ba5f9e0..3011050 100644
--- a/runtime/onnxruntime/include/audio.h
+++ b/runtime/onnxruntime/include/audio.h
@@ -84,6 +84,7 @@
int Fetch(float *&dout, int &len, int &flag);
int Fetch(float *&dout, int &len, int &flag, float &start_time);
int Fetch(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
+ int FetchDynamic(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
void Padding();
void Split(OfflineStream* offline_streamj);
void CutSplit(OfflineStream* offline_streamj, std::vector<int> &index_vector);
diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp
index f1285af..a5a44ca 100644
--- a/runtime/onnxruntime/src/audio.cpp
+++ b/runtime/onnxruntime/src/audio.cpp
@@ -1049,6 +1049,63 @@
}
}
+int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+ //compute batch size
+ queue<AudioFrame *> frame_batch;
+ int max_acc = 300*1000*seg_sample;
+ int max_sent = 60*1000*seg_sample;
+ int bs_acc = 0;
+ int max_len = 0;
+ int max_batch = 1;
+ #ifdef USE_GPU
+ max_batch = batch_size;
+ #endif
+
+ for(int idx=0; idx < std::min(max_batch, (int)frame_queue.size()); idx++){
+ AudioFrame *frame = frame_queue.front();
+ int length = frame->GetLen();
+ if(length >= max_sent){
+ if(bs_acc == 0){
+ bs_acc++;
+ frame_batch.push(frame);
+ frame_queue.pop();
+ }
+ break;
+ }
+ max_len = std::max(max_len, frame->GetLen());
+ if(max_len*(bs_acc+1) > max_acc){
+ break;
+ }
+ bs_acc++;
+ frame_batch.push(frame);
+ frame_queue.pop();
+ }
+
+ batch_in = (int)frame_batch.size();
+ if (batch_in == 0){
+ return 0;
+ } else{
+ // init
+ dout = new float*[batch_in];
+ len = new int[batch_in];
+ flag = new int[batch_in];
+ start_time = new float[batch_in];
+
+ for(int idx=0; idx < batch_in; idx++){
+ AudioFrame *frame = frame_batch.front();
+ frame_batch.pop();
+
+ start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+ dout[idx] = speech_data + frame->GetStart();
+ len[idx] = frame->GetLen();
+ delete frame;
+ flag[idx] = S_END;
+ }
+ return 1;
+ }
+}
+
void Audio::Padding()
{
float num_samples = speech_len;
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 3323b10..d235e6f 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -250,7 +250,7 @@
std::string cur_stamp = "[";
std::string lang = (offline_stream->asr_handle)->GetLang();
- while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
+ while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
// dec reset
funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
if (wfst_decoder){
@@ -372,7 +372,7 @@
std::string cur_stamp = "[";
std::string lang = (offline_stream->asr_handle)->GetLang();
- while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
+ while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
// dec reset
funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
if (wfst_decoder){
--
Gitblit v1.9.1