Yabin Li
2023-04-21 9ddfac27d2a9ec6b136ab92539f5e786647def8f
Merge pull request #397 from zhuzizyf/patch-1

Update e2e_vad.h
1个文件已修改
38 ■■■■ 已修改文件
funasr/runtime/onnxruntime/src/e2e_vad.h 38 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/e2e_vad.h
@@ -294,8 +294,8 @@
    std::vector<std::vector<int>>
    operator()(const std::vector<std::vector<float>> &score, const std::vector<float> &waveform, bool is_final = false,
               int max_end_sil = 800, int max_single_segment_time = 15000, float speech_noise_thres = 0.9,
               int sample_rate = 16000) {
               bool online = false, int max_end_sil = 800, int max_single_segment_time = 15000,
               float speech_noise_thres = 0.8, int sample_rate = 16000) {
        max_end_sil_frame_cnt_thresh = max_end_sil - vad_opts.speech_to_sil_time_thres;
        this->waveform = waveform;
        this->vad_opts.max_single_segment_time = max_single_segment_time;
@@ -309,19 +309,22 @@
        } else {
            DetectLastFrames();
        }
        //    std::vector<std::vector<int>> segments;
        //    for (size_t batch_num = 0; batch_num < score.size(); batch_num++) {
        std::vector<std::vector<int>> segment_batch;
        if (output_data_buf.size() > 0) {
            for (size_t i = output_data_buf_offset; i < output_data_buf.size(); i++) {
              int start_ms;
              int end_ms;
              if (online) {
                if (!output_data_buf[i].contain_seg_start_point) {
                    continue;
                }
                if (!next_seg && !output_data_buf[i].contain_seg_end_point) {
                    continue;
                }
                int start_ms = next_seg ? output_data_buf[i].start_ms : -1;
                int end_ms;
                start_ms = next_seg ? output_data_buf[i].start_ms : -1;
                if (output_data_buf[i].contain_seg_end_point) {
                    end_ms = output_data_buf[i].end_ms;
                    next_seg = true;
@@ -330,12 +333,20 @@
                    end_ms = -1;
                    next_seg = false;
                }
              } else {
                if (!is_final &&
                    (!output_data_buf[i].contain_seg_start_point || !output_data_buf[i].contain_seg_end_point)) {
                  continue;
                }
                start_ms = output_data_buf[i].start_ms;
                end_ms = output_data_buf[i].end_ms;
                output_data_buf_offset += 1;
              }
                std::vector<int> segment = {start_ms, end_ms};
                segment_batch.push_back(segment);
            }
        }
        //    }
        if (is_final) {
            AllResetDetection();
        }
@@ -444,15 +455,22 @@
    }
    void PopDataBufTillFrame(int frame_idx) {
        while (data_buf_start_frame < frame_idx) {
            int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
            if (data_buf.size() >= frame_sample_length) {
      int start_pos=-1;
      int data_length= data_buf.size();
      while (data_buf_start_frame < frame_idx) {
        if (data_length >= frame_sample_length) {
                data_buf_start_frame += 1;
                data_buf.erase(data_buf.begin(), data_buf.begin() + frame_sample_length);
          start_pos= data_buf_start_frame* frame_sample_length;
          data_length=data_buf_all.size()-start_pos;
            } else {
                break;
            }
        }
      if (start_pos!=-1){
        data_buf.resize(data_length);
        std::copy(data_buf_all.begin() + start_pos, data_buf_all.end(), data_buf.begin());
      }
    }
    void PopDataToOutputBuf(int start_frm, int frm_cnt, bool first_frm_is_start_point, bool last_frm_is_end_point,