zhifu gao
2024-06-04 3b0526e7be3565c42007313b90a018a2f8c8dff1
runtime/onnxruntime/src/audio.cpp
@@ -1023,6 +1023,90 @@
    }
}
int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
    batch_in = std::min((int)frame_queue.size(), batch_size);
    if (batch_in == 0){
        return 0;
    } else{
        // init
        dout = new float*[batch_in];
        len = new int[batch_in];
        flag = new int[batch_in];
        start_time = new float[batch_in];
        for(int idx=0; idx < batch_in; idx++){
            AudioFrame *frame = frame_queue.front();
            frame_queue.pop();
            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
            dout[idx] = speech_data + frame->GetStart();
            len[idx] = frame->GetLen();
            delete frame;
            flag[idx] = S_END;
        }
        return 1;
    }
}
int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
    //compute batch size
    queue<AudioFrame *> frame_batch;
    int max_acc = 300*1000*seg_sample;
    int max_sent = 60*1000*seg_sample;
    int bs_acc = 0;
    int max_len = 0;
    int max_batch = 1;
    #ifdef USE_GPU
        max_batch = batch_size;
    #endif
    max_batch = std::min(max_batch, (int)frame_queue.size());
    for(int idx=0; idx < max_batch; idx++){
        AudioFrame *frame = frame_queue.front();
        int length = frame->GetLen();
        if(length >= max_sent){
            if(bs_acc == 0){
                bs_acc++;
                frame_batch.push(frame);
                frame_queue.pop();
            }
            break;
        }
        max_len = std::max(max_len, frame->GetLen());
        if(max_len*(bs_acc+1) > max_acc){
            break;
        }
        bs_acc++;
        frame_batch.push(frame);
        frame_queue.pop();
    }
    batch_in = (int)frame_batch.size();
    if (batch_in == 0){
        return 0;
    } else{
        // init
        dout = new float*[batch_in];
        len = new int[batch_in];
        flag = new int[batch_in];
        start_time = new float[batch_in];
        for(int idx=0; idx < batch_in; idx++){
            AudioFrame *frame = frame_batch.front();
            frame_batch.pop();
            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
            dout[idx] = speech_data + frame->GetStart();
            len[idx] = frame->GetLen();
            delete frame;
            flag[idx] = S_END;
        }
        return 1;
    }
}
void Audio::Padding()
{
    float num_samples = speech_len;
@@ -1085,7 +1169,7 @@
    }
}
void Audio::CutSplit(OfflineStream* offline_stream)
void Audio::CutSplit(OfflineStream* offline_stream, std::vector<int> &index_vector)
{
    std::unique_ptr<VadModel> vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(offline_stream->vad_handle).get());
    AudioFrame *frame;
@@ -1112,6 +1196,7 @@
    }    
    int speech_start_i = -1, speech_end_i =-1;
    std::vector<AudioFrame*> vad_frames;
    for(vector<int> vad_segment:vad_segments)
    {
        if(vad_segment.size() != 2){
@@ -1126,16 +1211,31 @@
        }
        if(speech_start_i!=-1 && speech_end_i!=-1){
            frame = new AudioFrame();
            int start = speech_start_i*seg_sample;
            int end = speech_end_i*seg_sample;
            frame = new AudioFrame(end-start);
            frame->SetStart(start);
            frame->SetEnd(end);
            frame_queue.push(frame);
            vad_frames.push_back(frame);
            frame = nullptr;
            speech_start_i=-1;
            speech_end_i=-1;
        }
    }
    // sort
    {
        index_vector.clear();
        index_vector.resize(vad_frames.size());
        for (int i = 0; i < index_vector.size(); ++i) {
            index_vector[i] = i;
        }
        std::sort(index_vector.begin(), index_vector.end(), [&vad_frames](const int a, const int b) {
            return vad_frames[a]->len < vad_frames[b]->len;
        });
        for (int idx : index_vector) {
            frame_queue.push(vad_frames[idx]);
        }
    }
}