雾聪
2024-03-29 cad1979179a110b154568dd6281035ece9aaf0b8
add batch for offline-stream
3个文件已修改
179 ■■■■■ 已修改文件
runtime/onnxruntime/src/audio.cpp 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/funasrruntime.cpp 146 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/offline-stream.cpp 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/audio.cpp
@@ -1023,6 +1023,32 @@
    }
}
int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
    batch_in = std::min((int)frame_queue.size(), batch_size);
    if (batch_in == 0){
        return 0;
    } else{
        // init
        dout = new float*[batch_in];
        len = new int[batch_in];
        flag = new int[batch_in];
        start_time = new float[batch_in];
        for(int idx=0; idx < batch_in; idx++){
            AudioFrame *frame = frame_queue.front();
            frame_queue.pop();
            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
            dout[idx] = speech_data + frame->GetStart();
            len[idx] = frame->GetLen();
            delete frame;
            flag[idx] = S_END;
        }
        return 1;
    }
}
void Audio::Padding()
{
    float num_samples = speech_len;
runtime/onnxruntime/src/funasrruntime.cpp
@@ -33,9 +33,9 @@
        return mm;
    }
    // Creates an offline stream through funasr::CreateOfflineStream and returns
    // the resulting funasr::OfflineStream* as the FUNASR_HANDLE.
    // NOTE(review): this text is a rendered diff without +/- markers. The
    // 3-argument signature and call appear to be the pre-change lines, and the
    // variants taking `batch_size` their replacements — confirm against the repository.
    _FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu)
    _FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
    {
        funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu);
        // Post-change call: forwards batch_size to the stream factory.
        funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu, batch_size);
        return mm;
    }
@@ -74,16 +74,11 @@
        if(p_result->snippet_time == 0){
            return p_result;
        }
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        while (audio.Fetch(buff, len, flag) > 0) {
            string msg = recog_obj->Forward(buff, len, input_finished);
            p_result->msg += msg;
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        return p_result;
    }
@@ -109,8 +104,6 @@
        float* buff;
        int len;
        int flag = 0;
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
        p_result->snippet_time = audio.GetTimeLen();
        if(p_result->snippet_time == 0){
@@ -119,11 +112,7 @@
        while (audio.Fetch(buff, len, flag) > 0) {
            string msg = recog_obj->Forward(buff, len, true);
            p_result->msg += msg;
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        return p_result;
    }
@@ -248,37 +237,51 @@
            audio.CutSplit(offline_stream);
        }
        float* buff;
        int len;
        int flag = 0;
        float** buff;
        int* len;
        int* flag;
        float* start_time;
        int batch_size = offline_stream->asr_handle->GetBatchSize();
        int batch_in = 0;
        float start_time = 0.0;
        std::string cur_stamp = "[";
        std::string lang = (offline_stream->asr_handle)->GetLang();
        while (audio.Fetch(buff, len, flag, start_time) > 0) {
        while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
            // dec reset
            funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
            string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
            std::vector<std::string> msg_vec = funasr::split(msg, '|');
            if(msg_vec.size()==0){
                continue;
            }
            if(lang == "en-bpe" && p_result->msg != ""){
                p_result->msg += " ";
            }
            p_result->msg += msg_vec[0];
            //timestamp
            if(msg_vec.size() > 1){
                std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
                for(int i=0; i<msg_stamp.size()-1; i+=2){
                    float begin = std::stof(msg_stamp[i])+start_time;
                    float end = std::stof(msg_stamp[i+1])+start_time;
                    cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
            vector<string> msgs = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
            for(int idx=0; idx<batch_in; idx++){
                string msg = msgs[idx];
                std::vector<std::string> msg_vec = funasr::split(msg, '|');
                if(msg_vec.size()==0){
                    continue;
                }
                if(lang == "en-bpe" && p_result->msg != ""){
                    p_result->msg += " ";
                }
                p_result->msg += msg_vec[0];
                //timestamp
                if(msg_vec.size() > 1){
                    std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
                    for(int i=0; i<msg_stamp.size()-1; i+=2){
                        float begin = std::stof(msg_stamp[i])+start_time[idx];
                        float end = std::stof(msg_stamp[i+1])+start_time[idx];
                        cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
                    }
                }
            }
            // release
            delete[] buff;
            buff = nullptr;
            delete[] len;
            len = nullptr;
            delete[] flag;
            flag = nullptr;
            delete[] start_time;
            start_time = nullptr;
        }
        if(cur_stamp != "["){
            cur_stamp.erase(cur_stamp.length() - 1);
@@ -341,42 +344,51 @@
            audio.CutSplit(offline_stream);
        }
        float* buff;
        int len;
        int flag = 0;
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        float start_time = 0.0;
        float** buff;
        int* len;
        int* flag;
        float* start_time;
        int batch_size = offline_stream->asr_handle->GetBatchSize();
        int batch_in = 0;
        std::string cur_stamp = "[";
        std::string lang = (offline_stream->asr_handle)->GetLang();
        while (audio.Fetch(buff, len, flag, start_time) > 0) {
        while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
            // dec reset
            funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
            string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
            std::vector<std::string> msg_vec = funasr::split(msg, '|');
            if(msg_vec.size()==0){
                continue;
            }
            if(lang == "en-bpe" && p_result->msg != ""){
                p_result->msg += " ";
            }
            p_result->msg += msg_vec[0];
            //timestamp
            if(msg_vec.size() > 1){
                std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
                for(int i=0; i<msg_stamp.size()-1; i+=2){
                    float begin = std::stof(msg_stamp[i])+start_time;
                    float end = std::stof(msg_stamp[i+1])+start_time;
                    cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
            vector<string> msgs = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
            for(int idx=0; idx<batch_in; idx++){
                string msg = msgs[idx];
                std::vector<std::string> msg_vec = funasr::split(msg, '|');
                if(msg_vec.size()==0){
                    continue;
                }
                if(lang == "en-bpe" && p_result->msg != ""){
                    p_result->msg += " ";
                }
                p_result->msg += msg_vec[0];
                //timestamp
                if(msg_vec.size() > 1){
                    std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
                    for(int i=0; i<msg_stamp.size()-1; i+=2){
                        float begin = std::stof(msg_stamp[i])+start_time[idx];
                        float end = std::stof(msg_stamp[i+1])+start_time[idx];
                        cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
                    }
                }
            }
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
            // release
            delete[] buff;
            buff = nullptr;
            delete[] len;
            len = nullptr;
            delete[] flag;
            flag = nullptr;
            delete[] start_time;
            start_time = nullptr;
        }
        if(cur_stamp != "["){
            cur_stamp.erase(cur_stamp.length() - 1);
@@ -513,8 +525,14 @@
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
            string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
            float** buff;
            int* len;
            buff = new float*[1];
            len = new int[1];
            buff[0] = frame->data;
            len[0] = frame->len;
            vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
            string msg = msgs.size()>0?msgs[0]:"";
            std::vector<std::string> msg_vec = funasr::split(msg, '|');  // split with timestamp
            if(msg_vec.size()==0){
                continue;
runtime/onnxruntime/src/offline-stream.cpp
@@ -1,7 +1,7 @@
#include "precomp.h"
namespace funasr {
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu)
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
    // VAD model
    if(model_path.find(VAD_DIR) != model_path.end()){
@@ -38,6 +38,7 @@
        if(use_gpu){
            #ifdef USE_GPU
            asr_handle = make_unique<ParaformerTorch>();
            asr_handle->SetBatchSize(batch_size);
            #else
            LOG(ERROR) <<"GPU is not supported! CPU will be used! If you want to use GPU, please add -DGPU=ON when cmake";
            asr_handle = make_unique<Paraformer>();
@@ -135,10 +136,10 @@
#endif
}
// Factory: allocates an OfflineStream with `new` and returns the raw pointer.
// No matching delete is visible in this excerpt.
// NOTE(review): this text is a rendered diff without +/- markers. The
// 3-argument signature and constructor call appear to be the pre-change lines,
// and the variants taking `batch_size` their replacements — confirm against the repository.
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu)
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
    OfflineStream *mm;
    mm = new OfflineStream(model_path, thread_num, use_gpu);
    // Post-change construction: batch_size is passed through to the constructor.
    mm = new OfflineStream(model_path, thread_num, use_gpu, batch_size);
    return mm;
}