aky15
2023-03-15 e33bb15d269bb3e2e41f7a3540d9b92703bb5c50
funasr/models/e2e_vad.py
@@ -215,6 +215,7 @@
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True
        self.output_data_buf = []
        self.output_data_buf_offset = 0
@@ -244,6 +245,7 @@
        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True
        self.output_data_buf = []
        self.output_data_buf_offset = 0
@@ -441,7 +443,7 @@
                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr
        return frame_state
    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                is_final: bool = False
                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -470,6 +472,42 @@
            self.AllResetDetection()
        return segments, in_cache
    def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                is_final: bool = False
                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
        self.ComputeScores(feats, in_cache)
        if not is_final:
            self.DetectCommonFrames()
        else:
            self.DetectLastFrames()
        segments = []
        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
                    if not self.output_data_buf[i].contain_seg_start_point:
                        continue
                    if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
                        continue
                    start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
                    if self.output_data_buf[i].contain_seg_end_point:
                        end_ms = self.output_data_buf[i].end_ms
                        self.next_seg = True
                        self.output_data_buf_offset += 1
                    else:
                        end_ms = -1
                        self.next_seg = False
                    segment = [start_ms, end_ms]
                    segment_batch.append(segment)
            if segment_batch:
                segments.append(segment_batch)
        if is_final:
            # reset class variables and clear the dict for the next query
            self.AllResetDetection()
        return segments, in_cache
    def DetectCommonFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0