| | |
| | | |
| | | |
| | | class E2EVadModel(nn.Module): |
| | | def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False): |
| | | def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]): |
| | | super(E2EVadModel, self).__init__() |
| | | self.vad_opts = VADXOptions(**vad_post_args) |
| | | self.windows_detector = WindowDetector(self.vad_opts.window_size_ms, |
| | |
| | | self.data_buf = None |
| | | self.data_buf_all = None |
| | | self.waveform = None |
| | | self.streaming = streaming |
| | | self.ResetDetection() |
| | | |
| | | def AllResetDetection(self): |
| | |
| | | if not is_final_send: |
| | | self.DetectCommonFrames() |
| | | else: |
| | | if self.streaming: |
| | | self.DetectLastFrames() |
| | | else: |
| | | self.AllResetDetection() |
| | | self.DetectAllFrames() # offline decode and is_final_send == True |
| | | self.DetectLastFrames() |
| | | segments = [] |
| | | for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now |
| | | segment_batch = [] |
| | |
| | | self.output_data_buf_offset += 1 # need update this parameter |
| | | if segment_batch: |
| | | segments.append(segment_batch) |
| | | |
| | | if is_final_send: |
| | | self.AllResetDetection() |
| | | return segments |
| | | |
| | | def DetectCommonFrames(self) -> int: |
| | |
| | | else: |
| | | self.DetectOneFrame(frame_state, self.frm_cnt - 1, True) |
| | | |
| | | return 0 |
| | | |
def DetectAllFrames(self) -> int:
    """Run per-frame detection over every buffered frame (offline decode path).

    Does nothing once ``self.vad_state_machine`` has reached
    ``VadStateMachine.kVadInStateEndPointDetected``. Otherwise, when
    ``self.vad_opts.nn_eval_block_size`` differs from
    ``self.vad_opts.dcd_block_size``, each frame index ``t`` in
    ``0 .. self.frm_cnt - 1`` is classified with ``self.GetFrameState(t)``
    and passed to ``self.DetectOneFrame``. Only the last index is flagged
    as the final frame.

    Returns:
        Always ``0``.
    """
    if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
        return 0
    # NOTE(review): when the two block sizes are equal, no frames are processed
    # here at all (the original code had an empty `else: pass` branch).
    # Confirm this is intended for the offline path.
    if self.vad_opts.nn_eval_block_size == self.vad_opts.dcd_block_size:
        return 0
    for t in range(self.frm_cnt):
        frame_state = self.GetFrameState(t)
        # `self.frm_cnt` is re-read on every iteration, as in the original,
        # in case DetectOneFrame mutates it; only the last frame is final.
        self.DetectOneFrame(frame_state, t, t == self.frm_cnt - 1)
    return 0
| | | |
| | | def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None: |