游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/models/e2e_vad.py
@@ -5,6 +5,7 @@
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.models.base_model import FunASRModel
class VadStateMachine(Enum):
@@ -211,7 +212,7 @@
        return int(self.frame_size_ms)
class E2EVadModel(nn.Module):
class E2EVadModel(FunASRModel):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
@@ -226,7 +227,6 @@
                                               self.vad_opts.frame_in_ms)
        self.encoder = encoder
        # init variables
        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
@@ -253,11 +253,10 @@
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.ResetDetection()
        self.frontend = frontend
        self.last_drop_frames = 0
    def AllResetDetection(self):
        self.is_final = False
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
@@ -284,7 +283,8 @@
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.ResetDetection()
        self.last_drop_frames = 0
        self.windows_detector.Reset()
    def ResetDetection(self):
        self.continous_silence_frame_count = 0
@@ -296,6 +296,15 @@
        self.windows_detector.Reset()
        self.sil_frame = 0
        self.frame_probs = []
        if self.output_data_buf:
            assert self.output_data_buf[-1].contain_seg_end_point == True
            drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
            real_drop_frames = drop_frames - self.last_drop_frames
            self.last_drop_frames = drop_frames
            self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
            self.decibel = self.decibel[real_drop_frames:]
            self.scores = self.scores[:, real_drop_frames:, :]
    def ComputeDecibel(self) -> None:
        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
@@ -324,7 +333,7 @@
        while self.data_buf_start_frame < frame_idx:
            if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                self.data_buf_start_frame += 1
                self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
                self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int(
                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
@@ -473,6 +482,8 @@
    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                is_final: bool = False
                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
        if not in_cache:
            self.AllResetDetection()
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()
        self.ComputeScores(feats, in_cache)
@@ -501,6 +512,8 @@
    def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                       is_final: bool = False, max_end_sil: int = 800
                       ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
        if not in_cache:
            self.AllResetDetection()
        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
        self.waveform = waveform  # compute decibel for each frame
@@ -541,7 +554,7 @@
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
            self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
        return 0
@@ -551,7 +564,7 @@
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
            frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
            if i != 0:
                self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
            else: