Yabin Li
2023-12-13 a0048dc766f233827f2e7b5ebed0a0e22fae44b1
8k (#1174) (#1175)

* Adaptive 8K

* fix FfmpegLoad 8k

Co-authored-by: cdevelop <cdevelop@qq.com>
12个文件已修改
101 ■■■■■ 已修改文件
runtime/onnxruntime/include/audio.h 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/model.h 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/vad-model.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/audio.cpp 44 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/fsmn-vad-online.cpp 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/fsmn-vad-online.h 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/fsmn-vad.h 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/funasrruntime.cpp 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer-online.cpp 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer-online.h 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer.cpp 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer.h 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/audio.h
@@ -52,10 +52,11 @@
    queue<AudioFrame *> frame_queue;
    queue<AudioFrame *> asr_online_queue;
    queue<AudioFrame *> asr_offline_queue;
    int dest_sample_rate;
  public:
    Audio(int data_type);
    Audio(int data_type, int size);
    Audio(int model_sample_rate,int data_type);
    Audio(int model_sample_rate,int data_type, int size);
    ~Audio();
    void Disp();
    void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
runtime/onnxruntime/include/model.h
@@ -23,6 +23,8 @@
    virtual void InitSegDict(const std::string &seg_dict_model){};
    virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
    virtual std::string GetLang(){return "";};
    // Returns the sample rate of the ASR model; pure virtual, so every concrete
    // model must provide it. Callers pass the value to the Audio constructor as
    // the destination sample rate.
    virtual int GetAsrSampleRate() = 0;
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
runtime/onnxruntime/include/vad-model.h
@@ -12,6 +12,7 @@
    virtual ~VadModel(){};
    virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
    virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
    // Returns the sample rate of the VAD model; pure virtual, so every concrete
    // VAD model must provide it. Callers pass the value to the Audio constructor
    // as the destination sample rate.
    virtual int GetVadSampleRate() = 0;
};
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
runtime/onnxruntime/src/audio.cpp
@@ -193,18 +193,28 @@
    return 0;
}
Audio::Audio(int data_type) : data_type(data_type)
// Constructs an Audio that targets the default MODEL_SAMPLE_RATE.
// Buffers start out unset (NULL), align_size is fixed at 1360, and seg_sample
// is dest_sample_rate / 1000 using integer division.
Audio::Audio(int data_type) : dest_sample_rate(MODEL_SAMPLE_RATE), data_type(data_type)
{
    speech_buff = NULL;
    speech_data = NULL;
    align_size = 1360;
    seg_sample = dest_sample_rate / 1000;
}
Audio::Audio(int data_type, int size) : data_type(data_type)
// Constructs an Audio whose destination sample rate is model_sample_rate
// instead of the MODEL_SAMPLE_RATE default.
// Buffers start out unset (NULL), align_size is fixed at 1360, and seg_sample
// is dest_sample_rate / 1000 using integer division.
Audio::Audio(int model_sample_rate, int data_type) : dest_sample_rate(model_sample_rate), data_type(data_type)
{
    speech_buff = NULL;
    speech_data = NULL;
    align_size = 1360;
    seg_sample = dest_sample_rate / 1000;
}
// Constructs an Audio whose destination sample rate is model_sample_rate and
// whose align_size is taken from the caller-supplied size (converted through
// float before being stored) rather than the 1360 default.
// Buffers start out unset (NULL); seg_sample is dest_sample_rate / 1000 using
// integer division.
Audio::Audio(int model_sample_rate, int data_type, int size) : dest_sample_rate(model_sample_rate), data_type(data_type)
{
    speech_buff = NULL;
    speech_data = NULL;
    align_size = (float)size;
    seg_sample = dest_sample_rate / 1000;
}
Audio::~Audio()
@@ -222,12 +232,12 @@
// Logs the audio duration in seconds (speech_len divided by the sample rate)
// together with the raw sample count speech_len.
// NOTE(review): two log statements appear here, one dividing by
// MODEL_SAMPLE_RATE and one by dest_sample_rate; in this diff view the first is
// presumably the pre-change line -- confirm against the actual file.
void Audio::Disp()
{
    LOG(INFO) << "Audio time is " << (float)speech_len / MODEL_SAMPLE_RATE << " s. len is " << speech_len;
    LOG(INFO) << "Audio time is " << (float)speech_len / dest_sample_rate << " s. len is " << speech_len;
}
// Returns the audio duration in seconds: speech_len divided by the sample rate.
// NOTE(review): two return statements appear here, one dividing by
// MODEL_SAMPLE_RATE and one by dest_sample_rate; in this diff view the first is
// presumably the pre-change line -- confirm against the actual file. As written,
// the second return can never execute.
float Audio::GetTimeLen()
{
    return (float)speech_len / MODEL_SAMPLE_RATE;
    return (float)speech_len / dest_sample_rate;
}
void Audio::WavResample(int32_t sampling_rate, const float *waveform,
@@ -237,13 +247,13 @@
              << "   in_sample_rate: "<< sampling_rate << "\n"
              << "   output_sample_rate: " << static_cast<int32_t>(MODEL_SAMPLE_RATE);
    float min_freq =
        std::min<int32_t>(sampling_rate, MODEL_SAMPLE_RATE);
        std::min<int32_t>(sampling_rate, dest_sample_rate);
    float lowpass_cutoff = 0.99 * 0.5 * min_freq;
    int32_t lowpass_filter_width = 6;
    auto resampler = std::make_unique<LinearResample>(
          sampling_rate, MODEL_SAMPLE_RATE, lowpass_cutoff, lowpass_filter_width);
          sampling_rate, dest_sample_rate, lowpass_cutoff, lowpass_filter_width);
    std::vector<float> samples;
    resampler->Resample(waveform, n, true, &samples);
    //reset speech_data
@@ -311,7 +321,7 @@
        nullptr, // allocate a new context
        AV_CH_LAYOUT_MONO, // output channel layout (mono)
        AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
        16000, // output sample rate (fixed 16 kHz)
        dest_sample_rate, // output sample rate (this Audio object's destination rate)
        av_get_default_channel_layout(codecContext->channels), // input channel layout
        codecContext->sample_fmt, // input sample format
        codecContext->sample_rate, // input sample rate
@@ -347,7 +357,7 @@
                    int in_samples = frame->nb_samples;
                    uint8_t **in_data = frame->extended_data;
                    int out_samples = av_rescale_rnd(in_samples,
                                                    16000,
                                                    dest_sample_rate,
                                                    codecContext->sample_rate,
                                                    AV_ROUND_DOWN);
                    
@@ -494,7 +504,7 @@
        nullptr, // allocate a new context
        AV_CH_LAYOUT_MONO, // output channel layout (mono)
        AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
        16000, // output sample rate (fixed 16 kHz)
        dest_sample_rate, // output sample rate (this Audio object's destination rate)
        av_get_default_channel_layout(codecContext->channels), // input channel layout
        codecContext->sample_fmt, // input sample format
        codecContext->sample_rate, // input sample rate
@@ -532,7 +542,7 @@
                    int in_samples = frame->nb_samples;
                    uint8_t **in_data = frame->extended_data;
                    int out_samples = av_rescale_rnd(in_samples,
                                                    16000,
                                                    dest_sample_rate,
                                                    codecContext->sample_rate,
                                                    AV_ROUND_DOWN);
                    
@@ -666,7 +676,7 @@
        }
        //resample
        if(*sampling_rate != MODEL_SAMPLE_RATE){
        if(*sampling_rate != dest_sample_rate){
            WavResample(*sampling_rate, speech_data, speech_len);
        }
@@ -752,7 +762,7 @@
        }
        
        //resample
        if(*sampling_rate != MODEL_SAMPLE_RATE){
        if(*sampling_rate != dest_sample_rate){
            WavResample(*sampling_rate, speech_data, speech_len);
        }
@@ -795,7 +805,7 @@
        }
        
        //resample
        if(*sampling_rate != MODEL_SAMPLE_RATE){
        if(*sampling_rate != dest_sample_rate){
            WavResample(*sampling_rate, speech_data, speech_len);
        }
@@ -840,7 +850,7 @@
        }
        
        //resample
        if(*sampling_rate != MODEL_SAMPLE_RATE){
        if(*sampling_rate != dest_sample_rate){
            WavResample(*sampling_rate, speech_data, speech_len);
        }
@@ -898,7 +908,7 @@
        }
        //resample
        if(*sampling_rate != MODEL_SAMPLE_RATE){
        if(*sampling_rate != dest_sample_rate){
            WavResample(*sampling_rate, speech_data, speech_len);
        }
@@ -1009,7 +1019,7 @@
        AudioFrame *frame = frame_queue.front();
        frame_queue.pop();
        start_time = (float)(frame->GetStart())/MODEL_SAMPLE_RATE;
        start_time = (float)(frame->GetStart())/ dest_sample_rate;
        dout = speech_data + frame->GetStart();
        len = frame->GetLen();
        delete frame;
@@ -1248,7 +1258,7 @@
    }
    // erase all_samples
    int vector_cache = MODEL_SAMPLE_RATE*2;
    int vector_cache = dest_sample_rate*2;
    if(speech_offline_start == -1){
        if(all_samples.size() > vector_cache){
            int erase_num = all_samples.size() - vector_cache;
runtime/onnxruntime/src/fsmn-vad-online.cpp
@@ -187,8 +187,11 @@
    vad_max_len_ = vad_max_len;
    vad_speech_noise_thres_ = vad_speech_noise_thres;
    frame_sample_length_ = vad_sample_rate_ / 1000 * 25;;
    frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;
    // 2pass
    audio_handle = make_unique<Audio>(1);
    audio_handle = make_unique<Audio>(vad_sample_rate,1);
}
FsmnVadOnline::~FsmnVadOnline() {
runtime/onnxruntime/src/fsmn-vad-online.h
@@ -21,6 +21,8 @@
    std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
    void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
    void Reset();
    // Returns the stored vad_sample_rate_ member.
    int GetVadSampleRate() { return vad_sample_rate_; };
    // 2pass
    std::unique_ptr<Audio> audio_handle = nullptr;
runtime/onnxruntime/src/fsmn-vad.h
@@ -28,6 +28,8 @@
        std::vector<std::vector<float>> *in_cache,
        bool is_final);
    void Reset();
    // Returns the stored vad_sample_rate_ member.
    int GetVadSampleRate() { return vad_sample_rate_; };
    
    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
    Ort::Env env_;
runtime/onnxruntime/src/funasrruntime.cpp
@@ -57,7 +57,7 @@
        if (!recog_obj)
            return nullptr;
        funasr::Audio audio(1);
        funasr::Audio audio(recog_obj->GetAsrSampleRate(),1);
        if(wav_format == "pcm" || wav_format == "PCM"){
            if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
                return nullptr;
@@ -93,7 +93,7 @@
        if (!recog_obj)
            return nullptr;
        funasr::Audio audio(1);
        funasr::Audio audio(recog_obj->GetAsrSampleRate(),1);
        if(funasr::is_target_file(sz_filename, "wav")){
            int32_t sampling_rate_ = -1;
            if(!audio.LoadWav(sz_filename, &sampling_rate_))
@@ -134,7 +134,7 @@
        if (!vad_obj)
            return nullptr;
        funasr::Audio audio(1);
        funasr::Audio audio(vad_obj->GetVadSampleRate(),1);
        if(wav_format == "pcm" || wav_format == "PCM"){
            if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
                return nullptr;
@@ -162,7 +162,7 @@
        if (!vad_obj)
            return nullptr;
        funasr::Audio audio(1);
        funasr::Audio audio(vad_obj->GetVadSampleRate(),1);
        if(funasr::is_target_file(sz_filename, "wav")){
            int32_t sampling_rate_ = -1;
            if(!audio.LoadWav(sz_filename, &sampling_rate_))
@@ -222,7 +222,7 @@
        if (!offline_stream)
            return nullptr;
        funasr::Audio audio(1);
        funasr::Audio audio(offline_stream->asr_handle->GetAsrSampleRate(),1);
        try{
            if(wav_format == "pcm" || wav_format == "PCM"){
                if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
@@ -314,7 +314,7 @@
        if (!offline_stream)
            return nullptr;
        
        funasr::Audio audio(1);
        funasr::Audio audio((offline_stream->asr_handle)->GetAsrSampleRate(),1);
        try{
            if(funasr::is_target_file(sz_filename, "wav")){
                int32_t sampling_rate_ = -1;
runtime/onnxruntime/src/paraformer-online.cpp
@@ -61,7 +61,11 @@
    for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
        fsmn_init_cache_.emplace_back(0);
    }
    chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
    chunk_len = chunk_size[1]*frame_shift*lfr_n*para_handle_->asr_sample_rate/1000;
    frame_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_length;
    frame_shift_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_shift;
}
void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
@@ -489,7 +493,7 @@
        if(is_first_chunk){
            is_first_chunk = false;
        }
        ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
        ExtractFeats(para_handle_->asr_sample_rate, wav_feats, waves, input_finished);
        if(wav_feats.size() == 0){
            return result;
        }
runtime/onnxruntime/src/paraformer-online.h
@@ -111,6 +111,9 @@
        string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
        string Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
        string Rescoring();
        // Returns the asr_sample_rate held by the object para_handle_ points to;
        // para_handle_ is dereferenced without a null check.
        int GetAsrSampleRate() { return para_handle_->asr_sample_rate; };
        // 2pass
        std::string online_res;
        int chunk_len;
runtime/onnxruntime/src/paraformer.cpp
@@ -19,10 +19,11 @@
// offline
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
    LoadConfigFromYaml(am_config.c_str());
    // knf options
    fbank_opts_.frame_opts.dither = 0;
    fbank_opts_.mel_opts.num_bins = n_mels;
    fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
    fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
    fbank_opts_.frame_opts.window_type = window_type;
    fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
    fbank_opts_.frame_opts.frame_length_ms = frame_length;
@@ -65,7 +66,6 @@
    for (auto& item : m_strOutputNames)
        m_szOutputNames.push_back(item.c_str());
    vocab = new Vocab(am_config.c_str());
    LoadConfigFromYaml(am_config.c_str());
    phone_set_ = new PhoneSet(am_config.c_str());
    LoadCmvn(am_cmvn.c_str());
}
@@ -77,7 +77,7 @@
    // knf options
    fbank_opts_.frame_opts.dither = 0;
    fbank_opts_.mel_opts.num_bins = n_mels;
    fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
    fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
    fbank_opts_.frame_opts.window_type = window_type;
    fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
    fbank_opts_.frame_opts.frame_length_ms = frame_length;
@@ -216,6 +216,9 @@
    }
    try{
        YAML::Node frontend_conf = config["frontend_conf"];
        this->asr_sample_rate = frontend_conf["fs"].as<int>();
        YAML::Node lang_conf = config["lang"];
        if (lang_conf.IsDefined()){
            language = lang_conf.as<string>();
@@ -258,6 +261,9 @@
        this->cif_threshold = predictor_conf["threshold"].as<double>();
        this->tail_alphas = predictor_conf["tail_threshold"].as<double>();
        this->asr_sample_rate = frontend_conf["fs"].as<int>();
    }catch(exception const &e){
        LOG(ERROR) << "Error when load argument from vad config YAML.";
        exit(-1);
runtime/onnxruntime/src/paraformer.h
@@ -57,7 +57,7 @@
        string Rescoring();
        string GetLang(){return language;};
        // Returns the asr_sample_rate member.
        int GetAsrSampleRate() { return asr_sample_rate; };
        void StartUtterance();
        void EndUtterance();
        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@@ -107,8 +107,7 @@
        int fsmn_dims = 512;
        float cif_threshold = 1.0;
        float tail_alphas = 0.45;
        int asr_sample_rate = MODEL_SAMPLE_RATE;
    };
} // namespace funasr