zhifu gao
2023-05-04 5d38777dc8442c5fc7d27168c505c9f99479d67c
Merge pull request #447 from zhuzizyf/main

Update e2e-vad.h
1个文件已修改
52 ■■■■■ 已修改文件
funasr/runtime/onnxruntime/src/e2e-vad.h 52 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/e2e-vad.h
@@ -1,6 +1,7 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
 * Collaborators: zhuzizyf(China Telecom Shanghai)
*/
#include <utility>
@@ -381,10 +382,11 @@
    int max_end_sil_frame_cnt_thresh;
    float speech_noise_thres;
    std::vector<std::vector<float>> scores;
    int idx_pre_chunk = 0;
    bool max_time_out;
    std::vector<float> decibel;
    std::vector<float> data_buf;
    std::vector<float> data_buf_all;
    int data_buf_size = 0;
    int data_buf_all_size = 0;
    std::vector<float> waveform;
    void AllResetDetection() {
@@ -409,10 +411,11 @@
        max_end_sil_frame_cnt_thresh = vad_opts.max_end_silence_time - vad_opts.speech_to_sil_time_thres;
        speech_noise_thres = vad_opts.speech_noise_thres;
        scores.clear();
        idx_pre_chunk = 0;
        max_time_out = false;
        decibel.clear();
        data_buf.clear();
        data_buf_all.clear();
        int data_buf_size = 0;
        int data_buf_all_size = 0;
        waveform.clear();
        ResetDetection();
    }
@@ -432,18 +435,17 @@
    void ComputeDecibel() {
        int frame_sample_length = int(vad_opts.frame_length_ms * vad_opts.sample_rate / 1000);
        int frame_shift_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
        if (data_buf_all.empty()) {
            data_buf_all = waveform;
            data_buf = data_buf_all;
        if (data_buf_all_size == 0) {
          data_buf_all_size = waveform.size();
          data_buf_size = data_buf_all_size;
        } else {
            data_buf_all.insert(data_buf_all.end(), waveform.begin(), waveform.end());
          data_buf_all_size += waveform.size();
        }
        for (int offset = 0; offset < waveform.size() - frame_sample_length + 1; offset += frame_shift_length) {
            float sum = 0.0;
            for (int i = 0; i < frame_sample_length; i++) {
                sum += waveform[offset + i] * waveform[offset + i];
            }
//      float decibel = 10 * log10(sum + 0.000001);
            this->decibel.push_back(10 * log10(sum + 0.000001));
        }
    }
@@ -451,29 +453,16 @@
    void ComputeScores(const std::vector<std::vector<float>> &scores) {
        vad_opts.nn_eval_block_size = scores.size();
        frm_cnt += scores.size();
        if (this->scores.empty()) {
            this->scores = scores;  // the first calculation
        } else {
            this->scores.insert(this->scores.end(), scores.begin(), scores.end());
        }
        this->scores = scores;
    }
    void PopDataBufTillFrame(int frame_idx) {
      int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
      int start_pos=-1;
      int data_length= data_buf.size();
      while (data_buf_start_frame < frame_idx) {
        if (data_length >= frame_sample_length) {
        if (data_buf_size >= frame_sample_length) {
          data_buf_start_frame += 1;
          start_pos= data_buf_start_frame* frame_sample_length;
          data_length=data_buf_all.size()-start_pos;
        } else {
          break;
          data_buf_size = data_buf_all_size - data_buf_start_frame * frame_sample_length;
        }
      }
      if (start_pos!=-1){
        data_buf.resize(data_length);
        std::copy(data_buf_all.begin() + start_pos, data_buf_all.end(), data_buf.begin());
      }
    }
@@ -487,9 +476,9 @@
            expected_sample_number += int(extra_sample);
        }
        if (end_point_is_sent_end) {
            expected_sample_number = std::max(expected_sample_number, int(data_buf.size()));
            expected_sample_number = std::max(expected_sample_number, data_buf_size);
        }
        if (data_buf.size() < expected_sample_number) {
        if (data_buf_size < expected_sample_number) {
            std::cout << "error in calling pop data_buf\n";
        }
        if (output_data_buf.size() == 0 || first_frm_is_start_point) {
@@ -510,10 +499,10 @@
        } else {
            data_to_pop = int(frm_cnt * vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
        }
        if (data_to_pop > int(data_buf.size())) {
        if (data_to_pop > data_buf_size) {
            std::cout << "VAD data_to_pop is bigger than data_buf.size()!!!\n";
            data_to_pop = (int) data_buf.size();
            expected_sample_number = (int) data_buf.size();
            data_to_pop = data_buf_size;
            expected_sample_number = data_buf_size;
        }
        cur_seg.doa = 0;
        for (int sample_cpy_out = 0; sample_cpy_out < data_to_pop; sample_cpy_out++) {
@@ -619,7 +608,7 @@
        if (sil_pdf_ids.size() > 0) {
            std::vector<float> sil_pdf_scores;
            for (auto sil_pdf_id: sil_pdf_ids) {
                sil_pdf_scores.push_back(scores[t][sil_pdf_id]);
                sil_pdf_scores.push_back(scores[t - idx_pre_chunk][sil_pdf_id]);
            }
            sum_score = accumulate(sil_pdf_scores.begin(), sil_pdf_scores.end(), 0.0);
            noise_prob = log(sum_score) * vad_opts.speech_2_noise_ratio;
@@ -663,6 +652,7 @@
            frame_state = GetFrameState(frm_cnt - 1 - i);
            DetectOneFrame(frame_state, frm_cnt - 1 - i, false);
        }
        idx_pre_chunk += scores.size();
        return 0;
    }