old mode 100755
new mode 100644
| | |
| | | |
| | | |
| | | class E2EVadModel(nn.Module): |
| | | def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]): |
| | | def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None): |
| | | super(E2EVadModel, self).__init__() |
| | | self.vad_opts = VADXOptions(**vad_post_args) |
| | | self.windows_detector = WindowDetector(self.vad_opts.window_size_ms, |
| | |
| | | self.data_buf_all = None |
| | | self.waveform = None |
| | | self.ResetDetection() |
| | | self.frontend = frontend |
| | | |
| | | def AllResetDetection(self): |
| | | self.is_final = False |
| | |
| | | return segments, in_cache |
| | | |
| | | def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(), |
| | | is_final: bool = False |
| | | is_final: bool = False, max_end_sil: int = 800 |
| | | ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: |
| | | self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres |
| | | self.waveform = waveform # compute decibel for each frame |
| | | self.ComputeDecibel() |
| | | |
| | | self.ComputeScores(feats, in_cache) |
| | | self.ComputeDecibel() |
| | | if not is_final: |
| | | self.DetectCommonFrames() |
| | | else: |