游雁
2024-01-12 cf2f14345aa2c4f168ee51c200b8081c748980b8
funasr/models/fsmn_vad/model.py
@@ -333,8 +333,8 @@
                10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                0.000001))
    def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
        scores = self.encoder(feats, in_cache).to('cpu')  # return B * T * D
    def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
        scores = self.encoder(feats, cache).to('cpu')  # return B * T * D
        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        self.frm_cnt += scores.shape[1]  # count total frames
@@ -493,14 +493,14 @@
        return frame_state
    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
                is_final: bool = False
                ):
        if not in_cache:
        if not cache:
            self.AllResetDetection()
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
        self.ComputeScores(feats, in_cache)
        self.ComputeScores(feats, cache)
        if not is_final:
            self.DetectCommonFrames()
        else:
@@ -521,7 +521,7 @@
        if is_final:
            # reset class variables and clear the dict for the next query
            self.AllResetDetection()
        return segments, in_cache
        return segments, cache
    def generate(self,
                 data_in,
@@ -561,7 +561,7 @@
        feats = speech
        feats_len = speech_lengths.max().item()
        waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
        in_cache = kwargs.get("in_cache", {})
        cache = kwargs.get("cache", {})
        batch_size = kwargs.get("batch_size", 1)
        step = min(feats_len, 6000)
        segments = [[]] * batch_size
@@ -576,11 +576,11 @@
                "feats": feats[:, t_offset:t_offset + step, :],
                "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
                "is_final": is_final,
                "in_cache": in_cache
                "cache": cache
            }
            segments_part, in_cache = self.forward(**batch)
            segments_part, cache = self.forward(**batch)
            if segments_part:
                for batch_num in range(0, batch_size):
                    segments[batch_num] += segments_part[batch_num]
@@ -603,46 +603,6 @@
            results.append(result_i)
 
        return results, meta_data
    def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                       is_final: bool = False, max_end_sil: int = 800
                       ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
        if not in_cache:
            self.AllResetDetection()
        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeScores(feats, in_cache)
        self.ComputeDecibel()
        if not is_final:
            self.DetectCommonFrames()
        else:
            self.DetectLastFrames()
        segments = []
        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
                    if not self.output_data_buf[i].contain_seg_start_point:
                        continue
                    if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
                        continue
                    start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
                    if self.output_data_buf[i].contain_seg_end_point:
                        end_ms = self.output_data_buf[i].end_ms
                        self.next_seg = True
                        self.output_data_buf_offset += 1
                    else:
                        end_ms = -1
                        self.next_seg = False
                    segment = [start_ms, end_ms]
                    segment_batch.append(segment)
            if segment_batch:
                segments.append(segment_batch)
        if is_final:
            # reset class variables and clear the dict for the next query
            self.AllResetDetection()
        return segments, in_cache
    def DetectCommonFrames(self) -> int:
        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: