From ade08818b7a579aac75182b906a5bd3b8126411c Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Mon, 27 May 2024 15:46:26 +0800
Subject: [PATCH] Merge branch 'dev_batch' into main
---
runtime/onnxruntime/src/audio.cpp | 106 +++++++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 103 insertions(+), 3 deletions(-)
diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp
index 0135ab4..22a9ecd 100644
--- a/runtime/onnxruntime/src/audio.cpp
+++ b/runtime/onnxruntime/src/audio.cpp
@@ -1023,6 +1023,90 @@
}
}
+int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+ batch_in = std::min((int)frame_queue.size(), batch_size);
+ if (batch_in == 0){
+ return 0;
+ } else{
+ // init
+ dout = new float*[batch_in];
+ len = new int[batch_in];
+ flag = new int[batch_in];
+ start_time = new float[batch_in];
+
+ for(int idx=0; idx < batch_in; idx++){
+ AudioFrame *frame = frame_queue.front();
+ frame_queue.pop();
+
+ start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+ dout[idx] = speech_data + frame->GetStart();
+ len[idx] = frame->GetLen();
+ delete frame;
+ flag[idx] = S_END;
+ }
+ return 1;
+ }
+}
+
+int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+ //compute batch size
+ queue<AudioFrame *> frame_batch;
+ int max_acc = 300*1000*seg_sample;
+ int max_sent = 60*1000*seg_sample;
+ int bs_acc = 0;
+ int max_len = 0;
+ int max_batch = 1;
+ #ifdef USE_GPU
+ max_batch = batch_size;
+ #endif
+ max_batch = std::min(max_batch, (int)frame_queue.size());
+
+ for(int idx=0; idx < max_batch; idx++){
+ AudioFrame *frame = frame_queue.front();
+ int length = frame->GetLen();
+ if(length >= max_sent){
+ if(bs_acc == 0){
+ bs_acc++;
+ frame_batch.push(frame);
+ frame_queue.pop();
+ }
+ break;
+ }
+ max_len = std::max(max_len, frame->GetLen());
+ if(max_len*(bs_acc+1) > max_acc){
+ break;
+ }
+ bs_acc++;
+ frame_batch.push(frame);
+ frame_queue.pop();
+ }
+
+ batch_in = (int)frame_batch.size();
+ if (batch_in == 0){
+ return 0;
+ } else{
+ // init
+ dout = new float*[batch_in];
+ len = new int[batch_in];
+ flag = new int[batch_in];
+ start_time = new float[batch_in];
+
+ for(int idx=0; idx < batch_in; idx++){
+ AudioFrame *frame = frame_batch.front();
+ frame_batch.pop();
+
+ start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+ dout[idx] = speech_data + frame->GetStart();
+ len[idx] = frame->GetLen();
+ delete frame;
+ flag[idx] = S_END;
+ }
+ return 1;
+ }
+}
+
void Audio::Padding()
{
float num_samples = speech_len;
@@ -1085,7 +1169,7 @@
}
}
-void Audio::CutSplit(OfflineStream* offline_stream)
+void Audio::CutSplit(OfflineStream* offline_stream, std::vector<int> &index_vector)
{
std::unique_ptr<VadModel> vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(offline_stream->vad_handle).get());
AudioFrame *frame;
@@ -1112,6 +1196,7 @@
}
int speech_start_i = -1, speech_end_i =-1;
+ std::vector<AudioFrame*> vad_frames;
for(vector<int> vad_segment:vad_segments)
{
if(vad_segment.size() != 2){
@@ -1126,16 +1211,31 @@
}
if(speech_start_i!=-1 && speech_end_i!=-1){
- frame = new AudioFrame();
int start = speech_start_i*seg_sample;
int end = speech_end_i*seg_sample;
+ frame = new AudioFrame(end-start);
frame->SetStart(start);
frame->SetEnd(end);
- frame_queue.push(frame);
+ vad_frames.push_back(frame);
frame = nullptr;
speech_start_i=-1;
speech_end_i=-1;
}
+
+ }
+ // sort
+ {
+ index_vector.clear();
+ index_vector.resize(vad_frames.size());
+ for (int i = 0; i < index_vector.size(); ++i) {
+ index_vector[i] = i;
+ }
+ std::sort(index_vector.begin(), index_vector.end(), [&vad_frames](const int a, const int b) {
+ return vad_frames[a]->len < vad_frames[b]->len;
+ });
+ for (int idx : index_vector) {
+ frame_queue.push(vad_frames[idx]);
+ }
}
}
--
Gitblit v1.9.1