雾聪
2024-04-19 63a70a00f7c9f162e8d7b3e330516438fb8cd87b
add dynamic batch for fetch
3个文件已修改
62 ■■■■■ 已修改文件
runtime/onnxruntime/include/audio.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/audio.cpp 57 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/funasrruntime.cpp 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/audio.h
@@ -84,6 +84,7 @@
    int Fetch(float *&dout, int &len, int &flag);
    int Fetch(float *&dout, int &len, int &flag, float &start_time);
    int Fetch(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
    int FetchDynamic(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
    void Padding();
    void Split(OfflineStream* offline_streamj);
    void CutSplit(OfflineStream* offline_streamj, std::vector<int> &index_vector);
runtime/onnxruntime/src/audio.cpp
@@ -1049,6 +1049,63 @@
    }
}
int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
    //compute batch size
    queue<AudioFrame *> frame_batch;
    int max_acc = 300*1000*seg_sample;
    int max_sent = 60*1000*seg_sample;
    int bs_acc = 0;
    int max_len = 0;
    int max_batch = 1;
    #ifdef USE_GPU
        max_batch = batch_size;
    #endif
    for(int idx=0; idx < std::min(max_batch, (int)frame_queue.size()); idx++){
        AudioFrame *frame = frame_queue.front();
        int length = frame->GetLen();
        if(length >= max_sent){
            if(bs_acc == 0){
                bs_acc++;
                frame_batch.push(frame);
                frame_queue.pop();
            }
            break;
        }
        max_len = std::max(max_len, frame->GetLen());
        if(max_len*(bs_acc+1) > max_acc){
            break;
        }
        bs_acc++;
        frame_batch.push(frame);
        frame_queue.pop();
    }
    batch_in = (int)frame_batch.size();
    if (batch_in == 0){
        return 0;
    } else{
        // init
        dout = new float*[batch_in];
        len = new int[batch_in];
        flag = new int[batch_in];
        start_time = new float[batch_in];
        for(int idx=0; idx < batch_in; idx++){
            AudioFrame *frame = frame_batch.front();
            frame_batch.pop();
            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
            dout[idx] = speech_data + frame->GetStart();
            len[idx] = frame->GetLen();
            delete frame;
            flag[idx] = S_END;
        }
        return 1;
    }
}
void Audio::Padding()
{
    float num_samples = speech_len;
runtime/onnxruntime/src/funasrruntime.cpp
@@ -250,7 +250,7 @@
        std::string cur_stamp = "[";
        std::string lang = (offline_stream->asr_handle)->GetLang();
        while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
        while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
            // dec reset
            funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
            if (wfst_decoder){
@@ -372,7 +372,7 @@
        std::string cur_stamp = "[";
        std::string lang = (offline_stream->asr_handle)->GetLang();
        while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
        while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
            // dec reset
            funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
            if (wfst_decoder){