zhifu gao
2023-02-17 00f5ea6244384b2338b99984a55bf3f9e08dcc9c
funasr/models/e2e_vad.py
@@ -192,7 +192,7 @@
class E2EVadModel(nn.Module):
    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False):
    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
        super(E2EVadModel, self).__init__()
        self.vad_opts = VADXOptions(**vad_post_args)
        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -227,7 +227,6 @@
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.streaming = streaming
        self.ResetDetection()
    def AllResetDetection(self):
@@ -451,11 +450,7 @@
        if not is_final_send:
            self.DetectCommonFrames()
        else:
            if self.streaming:
                self.DetectLastFrames()
            else:
                self.AllResetDetection()
                self.DetectAllFrames()  # offline decode and is_final_send == True
            self.DetectLastFrames()
        segments = []
        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
@@ -468,7 +463,8 @@
                        self.output_data_buf_offset += 1  # need update this parameter
            if segment_batch:
                segments.append(segment_batch)
        if is_final_send:
            self.AllResetDetection()
        return segments
    def DetectCommonFrames(self) -> int:
@@ -492,18 +488,6 @@
            else:
                self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
        return 0
    def DetectAllFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        if self.vad_opts.nn_eval_block_size != self.vad_opts.dcd_block_size:
            frame_state = FrameState.kFrameStateInvalid
            for t in range(0, self.frm_cnt):
                frame_state = self.GetFrameState(t)
                self.DetectOneFrame(frame_state, t, t == self.frm_cnt - 1)
        else:
            pass
        return 0
    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None: