From b28f3c9da94ae72a3a0b7bb5982b587be7cf4cd6 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 18 一月 2024 22:00:58 +0800
Subject: [PATCH] fsmn-vad bugfix (#1270)

---
 funasr/models/fsmn_vad_streaming/model.py | 1374 +++++++++++++++++++++++++++++-----------------------------
 1 files changed, 689 insertions(+), 685 deletions(-)

diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 193feb0..943cb47 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -19,714 +19,718 @@
 
 
 class VadStateMachine(Enum):
-    kVadInStateStartPointNotDetected = 1
-    kVadInStateInSpeechSegment = 2
-    kVadInStateEndPointDetected = 3
+	kVadInStateStartPointNotDetected = 1
+	kVadInStateInSpeechSegment = 2
+	kVadInStateEndPointDetected = 3
 
 class FrameState(Enum):
-    kFrameStateInvalid = -1
-    kFrameStateSpeech = 1
-    kFrameStateSil = 0
+	kFrameStateInvalid = -1
+	kFrameStateSpeech = 1
+	kFrameStateSil = 0
 
 # final voice/unvoice state per frame
 class AudioChangeState(Enum):
-    kChangeStateSpeech2Speech = 0
-    kChangeStateSpeech2Sil = 1
-    kChangeStateSil2Sil = 2
-    kChangeStateSil2Speech = 3
-    kChangeStateNoBegin = 4
-    kChangeStateInvalid = 5
+	kChangeStateSpeech2Speech = 0
+	kChangeStateSpeech2Sil = 1
+	kChangeStateSil2Sil = 2
+	kChangeStateSil2Speech = 3
+	kChangeStateNoBegin = 4
+	kChangeStateInvalid = 5
 
 class VadDetectMode(Enum):
-    kVadSingleUtteranceDetectMode = 0
-    kVadMutipleUtteranceDetectMode = 1
+	kVadSingleUtteranceDetectMode = 0
+	kVadMutipleUtteranceDetectMode = 1
 
 class VADXOptions:
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(
-            self,
-            sample_rate: int = 16000,
-            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
-            snr_mode: int = 0,
-            max_end_silence_time: int = 800,
-            max_start_silence_time: int = 3000,
-            do_start_point_detection: bool = True,
-            do_end_point_detection: bool = True,
-            window_size_ms: int = 200,
-            sil_to_speech_time_thres: int = 150,
-            speech_to_sil_time_thres: int = 150,
-            speech_2_noise_ratio: float = 1.0,
-            do_extend: int = 1,
-            lookback_time_start_point: int = 200,
-            lookahead_time_end_point: int = 100,
-            max_single_segment_time: int = 60000,
-            nn_eval_block_size: int = 8,
-            dcd_block_size: int = 4,
-            snr_thres: int = -100.0,
-            noise_frame_num_used_for_snr: int = 100,
-            decibel_thres: int = -100.0,
-            speech_noise_thres: float = 0.6,
-            fe_prior_thres: float = 1e-4,
-            silence_pdf_num: int = 1,
-            sil_pdf_ids: List[int] = [0],
-            speech_noise_thresh_low: float = -0.1,
-            speech_noise_thresh_high: float = 0.3,
-            output_frame_probs: bool = False,
-            frame_in_ms: int = 10,
-            frame_length_ms: int = 25,
-            **kwargs,
-    ):
-        self.sample_rate = sample_rate
-        self.detect_mode = detect_mode
-        self.snr_mode = snr_mode
-        self.max_end_silence_time = max_end_silence_time
-        self.max_start_silence_time = max_start_silence_time
-        self.do_start_point_detection = do_start_point_detection
-        self.do_end_point_detection = do_end_point_detection
-        self.window_size_ms = window_size_ms
-        self.sil_to_speech_time_thres = sil_to_speech_time_thres
-        self.speech_to_sil_time_thres = speech_to_sil_time_thres
-        self.speech_2_noise_ratio = speech_2_noise_ratio
-        self.do_extend = do_extend
-        self.lookback_time_start_point = lookback_time_start_point
-        self.lookahead_time_end_point = lookahead_time_end_point
-        self.max_single_segment_time = max_single_segment_time
-        self.nn_eval_block_size = nn_eval_block_size
-        self.dcd_block_size = dcd_block_size
-        self.snr_thres = snr_thres
-        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
-        self.decibel_thres = decibel_thres
-        self.speech_noise_thres = speech_noise_thres
-        self.fe_prior_thres = fe_prior_thres
-        self.silence_pdf_num = silence_pdf_num
-        self.sil_pdf_ids = sil_pdf_ids
-        self.speech_noise_thresh_low = speech_noise_thresh_low
-        self.speech_noise_thresh_high = speech_noise_thresh_high
-        self.output_frame_probs = output_frame_probs
-        self.frame_in_ms = frame_in_ms
-        self.frame_length_ms = frame_length_ms
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(
+		self,
+		sample_rate: int = 16000,
+		detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+		snr_mode: int = 0,
+		max_end_silence_time: int = 800,
+		max_start_silence_time: int = 3000,
+		do_start_point_detection: bool = True,
+		do_end_point_detection: bool = True,
+		window_size_ms: int = 200,
+		sil_to_speech_time_thres: int = 150,
+		speech_to_sil_time_thres: int = 150,
+		speech_2_noise_ratio: float = 1.0,
+		do_extend: int = 1,
+		lookback_time_start_point: int = 200,
+		lookahead_time_end_point: int = 100,
+		max_single_segment_time: int = 60000,
+		nn_eval_block_size: int = 8,
+		dcd_block_size: int = 4,
+		snr_thres: int = -100.0,
+		noise_frame_num_used_for_snr: int = 100,
+		decibel_thres: int = -100.0,
+		speech_noise_thres: float = 0.6,
+		fe_prior_thres: float = 1e-4,
+		silence_pdf_num: int = 1,
+		sil_pdf_ids: List[int] = [0],
+		speech_noise_thresh_low: float = -0.1,
+		speech_noise_thresh_high: float = 0.3,
+		output_frame_probs: bool = False,
+		frame_in_ms: int = 10,
+		frame_length_ms: int = 25,
+		**kwargs,
+	):
+		self.sample_rate = sample_rate
+		self.detect_mode = detect_mode
+		self.snr_mode = snr_mode
+		self.max_end_silence_time = max_end_silence_time
+		self.max_start_silence_time = max_start_silence_time
+		self.do_start_point_detection = do_start_point_detection
+		self.do_end_point_detection = do_end_point_detection
+		self.window_size_ms = window_size_ms
+		self.sil_to_speech_time_thres = sil_to_speech_time_thres
+		self.speech_to_sil_time_thres = speech_to_sil_time_thres
+		self.speech_2_noise_ratio = speech_2_noise_ratio
+		self.do_extend = do_extend
+		self.lookback_time_start_point = lookback_time_start_point
+		self.lookahead_time_end_point = lookahead_time_end_point
+		self.max_single_segment_time = max_single_segment_time
+		self.nn_eval_block_size = nn_eval_block_size
+		self.dcd_block_size = dcd_block_size
+		self.snr_thres = snr_thres
+		self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+		self.decibel_thres = decibel_thres
+		self.speech_noise_thres = speech_noise_thres
+		self.fe_prior_thres = fe_prior_thres
+		self.silence_pdf_num = silence_pdf_num
+		self.sil_pdf_ids = sil_pdf_ids
+		self.speech_noise_thresh_low = speech_noise_thresh_low
+		self.speech_noise_thresh_high = speech_noise_thresh_high
+		self.output_frame_probs = output_frame_probs
+		self.frame_in_ms = frame_in_ms
+		self.frame_length_ms = frame_length_ms
 
 
 class E2EVadSpeechBufWithDoa(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self):
-        self.start_ms = 0
-        self.end_ms = 0
-        self.buffer = []
-        self.contain_seg_start_point = False
-        self.contain_seg_end_point = False
-        self.doa = 0
-
-    def Reset(self):
-        self.start_ms = 0
-        self.end_ms = 0
-        self.buffer = []
-        self.contain_seg_start_point = False
-        self.contain_seg_end_point = False
-        self.doa = 0
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(self):
+		self.start_ms = 0
+		self.end_ms = 0
+		self.buffer = []
+		self.contain_seg_start_point = False
+		self.contain_seg_end_point = False
+		self.doa = 0
+	
+	def Reset(self):
+		self.start_ms = 0
+		self.end_ms = 0
+		self.buffer = []
+		self.contain_seg_start_point = False
+		self.contain_seg_end_point = False
+		self.doa = 0
 
 
 class E2EVadFrameProb(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self):
-        self.noise_prob = 0.0
-        self.speech_prob = 0.0
-        self.score = 0.0
-        self.frame_id = 0
-        self.frm_state = 0
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(self):
+		self.noise_prob = 0.0
+		self.speech_prob = 0.0
+		self.score = 0.0
+		self.frame_id = 0
+		self.frm_state = 0
 
 
 class WindowDetector(object):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self, window_size_ms: int,
-                 sil_to_speech_time: int,
-                 speech_to_sil_time: int,
-                 frame_size_ms: int):
-        self.window_size_ms = window_size_ms
-        self.sil_to_speech_time = sil_to_speech_time
-        self.speech_to_sil_time = speech_to_sil_time
-        self.frame_size_ms = frame_size_ms
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(self, window_size_ms: int,
+	             sil_to_speech_time: int,
+	             speech_to_sil_time: int,
+	             frame_size_ms: int):
+		self.window_size_ms = window_size_ms
+		self.sil_to_speech_time = sil_to_speech_time
+		self.speech_to_sil_time = speech_to_sil_time
+		self.frame_size_ms = frame_size_ms
+		
+		self.win_size_frame = int(window_size_ms / frame_size_ms)
+		self.win_sum = 0
+		self.win_state = [0] * self.win_size_frame  # 鍒濆鍖栫獥
+		
+		self.cur_win_pos = 0
+		self.pre_frame_state = FrameState.kFrameStateSil
+		self.cur_frame_state = FrameState.kFrameStateSil
+		self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+		self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+		
+		self.voice_last_frame_count = 0
+		self.noise_last_frame_count = 0
+		self.hydre_frame_count = 0
+	
+	def Reset(self) -> None:
+		self.cur_win_pos = 0
+		self.win_sum = 0
+		self.win_state = [0] * self.win_size_frame
+		self.pre_frame_state = FrameState.kFrameStateSil
+		self.cur_frame_state = FrameState.kFrameStateSil
+		self.voice_last_frame_count = 0
+		self.noise_last_frame_count = 0
+		self.hydre_frame_count = 0
+	
+	def GetWinSize(self) -> int:
+		return int(self.win_size_frame)
+	
+	def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
+		cur_frame_state = FrameState.kFrameStateSil
+		if frameState == FrameState.kFrameStateSpeech:
+			cur_frame_state = 1
+		elif frameState == FrameState.kFrameStateSil:
+			cur_frame_state = 0
+		else:
+			return AudioChangeState.kChangeStateInvalid
+		self.win_sum -= self.win_state[self.cur_win_pos]
+		self.win_sum += cur_frame_state
+		self.win_state[self.cur_win_pos] = cur_frame_state
+		self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+		
+		if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+			self.pre_frame_state = FrameState.kFrameStateSpeech
+			return AudioChangeState.kChangeStateSil2Speech
+		
+		if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+			self.pre_frame_state = FrameState.kFrameStateSil
+			return AudioChangeState.kChangeStateSpeech2Sil
+		
+		if self.pre_frame_state == FrameState.kFrameStateSil:
+			return AudioChangeState.kChangeStateSil2Sil
+		if self.pre_frame_state == FrameState.kFrameStateSpeech:
+			return AudioChangeState.kChangeStateSpeech2Speech
+		return AudioChangeState.kChangeStateInvalid
+	
+	def FrameSizeMs(self) -> int:
+		return int(self.frame_size_ms)
 
-        self.win_size_frame = int(window_size_ms / frame_size_ms)
-        self.win_sum = 0
-        self.win_state = [0] * self.win_size_frame  # 鍒濆鍖栫獥
-
-        self.cur_win_pos = 0
-        self.pre_frame_state = FrameState.kFrameStateSil
-        self.cur_frame_state = FrameState.kFrameStateSil
-        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
-        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
-
-        self.voice_last_frame_count = 0
-        self.noise_last_frame_count = 0
-        self.hydre_frame_count = 0
-
-    def Reset(self) -> None:
-        self.cur_win_pos = 0
-        self.win_sum = 0
-        self.win_state = [0] * self.win_size_frame
-        self.pre_frame_state = FrameState.kFrameStateSil
-        self.cur_frame_state = FrameState.kFrameStateSil
-        self.voice_last_frame_count = 0
-        self.noise_last_frame_count = 0
-        self.hydre_frame_count = 0
-
-    def GetWinSize(self) -> int:
-        return int(self.win_size_frame)
-
-    def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
-        cur_frame_state = FrameState.kFrameStateSil
-        if frameState == FrameState.kFrameStateSpeech:
-            cur_frame_state = 1
-        elif frameState == FrameState.kFrameStateSil:
-            cur_frame_state = 0
-        else:
-            return AudioChangeState.kChangeStateInvalid
-        self.win_sum -= self.win_state[self.cur_win_pos]
-        self.win_sum += cur_frame_state
-        self.win_state[self.cur_win_pos] = cur_frame_state
-        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
-
-        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
-            self.pre_frame_state = FrameState.kFrameStateSpeech
-            return AudioChangeState.kChangeStateSil2Speech
-
-        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
-            self.pre_frame_state = FrameState.kFrameStateSil
-            return AudioChangeState.kChangeStateSpeech2Sil
-
-        if self.pre_frame_state == FrameState.kFrameStateSil:
-            return AudioChangeState.kChangeStateSil2Sil
-        if self.pre_frame_state == FrameState.kFrameStateSpeech:
-            return AudioChangeState.kChangeStateSpeech2Speech
-        return AudioChangeState.kChangeStateInvalid
-
-    def FrameSizeMs(self) -> int:
-        return int(self.frame_size_ms)
+class Stats(object):
+	def __init__(self,
+	             sil_pdf_ids,
+	             max_end_sil_frame_cnt_thresh,
+	             speech_noise_thres,
+	             ):
+		
+		self.data_buf_start_frame = 0
+		self.frm_cnt = 0
+		self.latest_confirmed_speech_frame = 0
+		self.lastest_confirmed_silence_frame = -1
+		self.continous_silence_frame_count = 0
+		self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+		self.confirmed_start_frame = -1
+		self.confirmed_end_frame = -1
+		self.number_end_time_detected = 0
+		self.sil_frame = 0
+		self.sil_pdf_ids = sil_pdf_ids
+		self.noise_average_decibel = -100.0
+		self.pre_end_silence_detected = False
+		self.next_seg = True
+		
+		self.output_data_buf = []
+		self.output_data_buf_offset = 0
+		self.frame_probs = []
+		self.max_end_sil_frame_cnt_thresh = max_end_sil_frame_cnt_thresh
+		self.speech_noise_thres = speech_noise_thres
+		self.scores = None
+		self.max_time_out = False
+		self.decibel = []
+		self.data_buf = None
+		self.data_buf_all = None
+		self.waveform = None
+		self.last_drop_frames = 0
 
 
-@dataclass
-class StatsItem:
-    
-    # init variables
-    data_buf_start_frame = 0
-    frm_cnt = 0
-    latest_confirmed_speech_frame = 0
-    lastest_confirmed_silence_frame = -1
-    continous_silence_frame_count = 0
-    vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
-    confirmed_start_frame = -1
-    confirmed_end_frame = -1
-    number_end_time_detected = 0
-    sil_frame = 0
-    sil_pdf_ids: list
-    noise_average_decibel = -100.0
-    pre_end_silence_detected = False
-    next_seg = True # unused
-    
-    output_data_buf = []
-    output_data_buf_offset = 0
-    frame_probs = [] # unused
-    max_end_sil_frame_cnt_thresh: int
-    speech_noise_thres: float
-    scores = None
-    max_time_out = False #unused
-    decibel = []
-    data_buf = None
-    data_buf_all = None
-    waveform = None
-    last_drop_frames = 0
-    
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-    https://arxiv.org/abs/1803.05030
-    """
-    def __init__(self,
-                 encoder: str = None,
-                 encoder_conf: Optional[Dict] = None,
-                 vad_post_args: Dict[str, Any] = None,
-                 **kwargs,
-                 ):
-        super().__init__()
-        self.vad_opts = VADXOptions(**kwargs)
-
-        encoder_class = tables.encoder_classes.get(encoder)
-        encoder = encoder_class(**encoder_conf)
-        self.encoder = encoder
-
-
-    def ResetDetection(self, cache: dict = {}):
-        cache["stats"].continous_silence_frame_count = 0
-        cache["stats"].latest_confirmed_speech_frame = 0
-        cache["stats"].lastest_confirmed_silence_frame = -1
-        cache["stats"].confirmed_start_frame = -1
-        cache["stats"].confirmed_end_frame = -1
-        cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
-        cache["windows_detector"].Reset()
-        cache["stats"].sil_frame = 0
-        cache["stats"].frame_probs = []
-
-        if cache["stats"].output_data_buf:
-            assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
-            drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
-            real_drop_frames = drop_frames - cache["stats"].last_drop_frames
-            cache["stats"].last_drop_frames = drop_frames
-            cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
-            cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
-            cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
-
-    def ComputeDecibel(self, cache: dict = {}) -> None:
-        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
-        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
-        if cache["stats"].data_buf_all is None:
-            cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
-            cache["stats"].data_buf = cache["stats"].data_buf_all
-        else:
-            cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
-        for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
-            cache["stats"].decibel.append(
-                10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
-                                0.000001))
-
-    def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
-        scores = self.encoder(feats, cache=cache["encoder"]).to('cpu')  # return B * T * D
-        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
-        self.vad_opts.nn_eval_block_size = scores.shape[1]
-        cache["stats"].frm_cnt += scores.shape[1]  # count total frames
-        if cache["stats"].scores is None:
-            cache["stats"].scores = scores  # the first calculation
-        else:
-            cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
-
-    def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
-        while cache["stats"].data_buf_start_frame < frame_idx:
-            if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
-                cache["stats"].data_buf_start_frame += 1
-                cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
-                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
-
-    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
-                           last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
-        self.PopDataBufTillFrame(start_frm, cache=cache)
-        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
-        if last_frm_is_end_point:
-            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
-                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
-            expected_sample_number += int(extra_sample)
-        if end_point_is_sent_end:
-            expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
-        if len(cache["stats"].data_buf) < expected_sample_number:
-            print('error in calling pop data_buf\n')
-
-        if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
-            cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
-            cache["stats"].output_data_buf[-1].Reset()
-            cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
-            cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
-            cache["stats"].output_data_buf[-1].doa = 0
-        cur_seg = cache["stats"].output_data_buf[-1]
-        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
-            print('warning\n')
-        out_pos = len(cur_seg.buffer)  # cur_seg.buff鐜板湪娌″仛浠讳綍鎿嶄綔
-        data_to_pop = 0
-        if end_point_is_sent_end:
-            data_to_pop = expected_sample_number
-        else:
-            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
-        if data_to_pop > len(cache["stats"].data_buf):
-            print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
-            data_to_pop = len(cache["stats"].data_buf)
-            expected_sample_number = len(cache["stats"].data_buf)
-
-        cur_seg.doa = 0
-        for sample_cpy_out in range(0, data_to_pop):
-            # cur_seg.buffer[out_pos ++] = data_buf_.back();
-            out_pos += 1
-        for sample_cpy_out in range(data_to_pop, expected_sample_number):
-            # cur_seg.buffer[out_pos++] = data_buf_.back()
-            out_pos += 1
-        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
-            print('Something wrong with the VAD algorithm\n')
-        cache["stats"].data_buf_start_frame += frm_cnt
-        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
-        if first_frm_is_start_point:
-            cur_seg.contain_seg_start_point = True
-        if last_frm_is_end_point:
-            cur_seg.contain_seg_end_point = True
-
-    def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
-        cache["stats"].lastest_confirmed_silence_frame = valid_frame
-        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-            self.PopDataBufTillFrame(valid_frame, cache=cache)
-        # silence_detected_callback_
-        # pass
-
-    def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
-        cache["stats"].latest_confirmed_speech_frame = valid_frame
-        self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
-
-    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
-        if self.vad_opts.do_start_point_detection:
-            pass
-        if cache["stats"].confirmed_start_frame != -1:
-            print('not reset vad properly\n')
-        else:
-            cache["stats"].confirmed_start_frame = start_frame
-
-        if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-            self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
-
-    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
-        for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
-            self.OnVoiceDetected(t, cache=cache)
-        if self.vad_opts.do_end_point_detection:
-            pass
-        if cache["stats"].confirmed_end_frame != -1:
-            print('not reset vad properly\n')
-        else:
-            cache["stats"].confirmed_end_frame = end_frame
-        if not fake_result:
-            cache["stats"].sil_frame = 0
-            self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
-        cache["stats"].number_end_time_detected += 1
-
-    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
-        if is_final_frame:
-            self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
-            cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-
-    def GetLatency(self, cache: dict = {}) -> int:
-        return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
-
-    def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
-        vad_latency = cache["windows_detector"].GetWinSize()
-        if self.vad_opts.do_extend:
-            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
-        return vad_latency
-
-    def GetFrameState(self, t: int, cache: dict = {}):
-        frame_state = FrameState.kFrameStateInvalid
-        cur_decibel = cache["stats"].decibel[t]
-        cur_snr = cur_decibel - cache["stats"].noise_average_decibel
-        # for each frame, calc log posterior probability of each state
-        if cur_decibel < self.vad_opts.decibel_thres:
-            frame_state = FrameState.kFrameStateSil
-            self.DetectOneFrame(frame_state, t, False, cache=cache)
-            return frame_state
-
-        sum_score = 0.0
-        noise_prob = 0.0
-        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
-        if len(cache["stats"].sil_pdf_ids) > 0:
-            assert len(cache["stats"].scores) == 1  # 鍙敮鎸乥atch_size = 1鐨勬祴璇�
-            sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
-            sum_score = sum(sil_pdf_scores)
-            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
-            total_score = 1.0
-            sum_score = total_score - sum_score
-        speech_prob = math.log(sum_score)
-        if self.vad_opts.output_frame_probs:
-            frame_prob = E2EVadFrameProb()
-            frame_prob.noise_prob = noise_prob
-            frame_prob.speech_prob = speech_prob
-            frame_prob.score = sum_score
-            frame_prob.frame_id = t
-            cache["stats"].frame_probs.append(frame_prob)
-        if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
-            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
-                frame_state = FrameState.kFrameStateSpeech
-            else:
-                frame_state = FrameState.kFrameStateSil
-        else:
-            frame_state = FrameState.kFrameStateSil
-            if cache["stats"].noise_average_decibel < -99.9:
-                cache["stats"].noise_average_decibel = cur_decibel
-            else:
-                cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
-                        self.vad_opts.noise_frame_num_used_for_snr
-                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr
-
-        return frame_state
-
-    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
-                is_final: bool = False
-                ):
-        # if len(cache) == 0:
-        #     self.AllResetDetection()
-        # self.waveform = waveform  # compute decibel for each frame
-        cache["stats"].waveform = waveform
-        self.ComputeDecibel(cache=cache)
-        self.ComputeScores(feats, cache=cache)
-        if not is_final:
-            self.DetectCommonFrames(cache=cache)
-        else:
-            self.DetectLastFrames(cache=cache)
-        segments = []
-        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
-            segment_batch = []
-            if len(cache["stats"].output_data_buf) > 0:
-                for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
-                    if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
-                        i].contain_seg_end_point):
-                        continue
-                    segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
-                    segment_batch.append(segment)
-                    cache["stats"].output_data_buf_offset += 1  # need update this parameter
-            if segment_batch:
-                segments.append(segment_batch)
-        # if is_final:
-        #     # reset class variables and clear the dict for the next query
-        #     self.AllResetDetection()
-        return segments
-
-    def init_cache(self, cache: dict = {}, **kwargs):
-        cache["frontend"] = {}
-        cache["prev_samples"] = torch.empty(0)
-        cache["encoder"] = {}
-        windows_detector = WindowDetector(self.vad_opts.window_size_ms,
-                                          self.vad_opts.sil_to_speech_time_thres,
-                                          self.vad_opts.speech_to_sil_time_thres,
-                                          self.vad_opts.frame_in_ms)
-
-        stats = StatsItem(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
-                          max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres,
-                          speech_noise_thres=self.vad_opts.speech_noise_thres,
-                      )
-        cache["windows_detector"] = windows_detector
-        cache["stats"] = stats
-        return cache
-    
-    def inference(self,
-                 data_in,
-                 data_lengths=None,
-                 key: list = None,
-                 tokenizer=None,
-                 frontend=None,
-                 cache: dict = {},
-                 **kwargs,
-                 ):
-    
-        if len(cache) == 0:
-            self.init_cache(cache, **kwargs)
-
-        meta_data = {}
-        chunk_size = kwargs.get("chunk_size", 60000) # 50ms
-        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
-
-        time1 = time.perf_counter()
-        cfg = {"is_final": kwargs.get("is_final", False)}
-        audio_sample_list = load_audio_text_image_video(data_in,
-                                                        fs=frontend.fs,
-                                                        audio_fs=kwargs.get("fs", 16000),
-                                                        data_type=kwargs.get("data_type", "sound"),
-                                                        tokenizer=tokenizer,
-                                                        cache=cfg,
-                                                        )
-        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
-
-        time2 = time.perf_counter()
-        meta_data["load_data"] = f"{time2 - time1:0.3f}"
-        assert len(audio_sample_list) == 1, "batch_size must be set 1"
-
-        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
-
-        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
-        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
-        segments = []
-        for i in range(n):
-            kwargs["is_final"] = _is_final and i == n - 1
-            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
-    
-            # extract fbank feats
-            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend, cache=cache["frontend"],
-                                                   is_final=kwargs["is_final"])
-            time3 = time.perf_counter()
-            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-            speech = speech.to(device=kwargs["device"])
-            speech_lengths = speech_lengths.to(device=kwargs["device"])
-            
-            batch = {
-                "feats": speech,
-                "waveform": cache["frontend"]["waveforms"],
-                "is_final": kwargs["is_final"],
-                "cache": cache
-            }
-            segments_i = self.forward(**batch)
-            if len(segments_i) > 0:
-                segments.extend(*segments_i)
-
-
-        cache["prev_samples"] = audio_sample[:-m]
-        if _is_final:
-            self.init_cache(cache, **kwargs)
-
-        ibest_writer = None
-        if ibest_writer is None and kwargs.get("output_dir") is not None:
-            writer = DatadirWriter(kwargs.get("output_dir"))
-            ibest_writer = writer[f"{1}best_recog"]
-
-        results = []
-        result_i = {"key": key[0], "value": segments}
-        if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
-            result_i = json.dumps(result_i)
-
-        results.append(result_i)
-            
-        if ibest_writer is not None:
-            ibest_writer["text"][key[0]] = segments
-
- 
-        return results, meta_data
-
-
-    def DetectCommonFrames(self, cache: dict = {}) -> int:
-        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
-            return 0
-        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
-            frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
-            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
-
-        return 0
-
-    def DetectLastFrames(self, cache: dict = {}) -> int:
-        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
-            return 0
-        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
-            frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
-            if i != 0:
-                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
-            else:
-                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
-
-        return 0
-
-    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
-        tmp_cur_frm_state = FrameState.kFrameStateInvalid
-        if cur_frm_state == FrameState.kFrameStateSpeech:
-            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
-                tmp_cur_frm_state = FrameState.kFrameStateSpeech
-            else:
-                tmp_cur_frm_state = FrameState.kFrameStateSil
-        elif cur_frm_state == FrameState.kFrameStateSil:
-            tmp_cur_frm_state = FrameState.kFrameStateSil
-        state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
-        frm_shift_in_ms = self.vad_opts.frame_in_ms
-        if AudioChangeState.kChangeStateSil2Speech == state_change:
-            silence_frame_count = cache["stats"].continous_silence_frame_count
-            cache["stats"].continous_silence_frame_count = 0
-            cache["stats"].pre_end_silence_detected = False
-            start_frame = 0
-            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-                start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
-                self.OnVoiceStart(start_frame, cache=cache)
-                cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
-                for t in range(start_frame + 1, cur_frm_idx + 1):
-                    self.OnVoiceDetected(t, cache=cache)
-            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
-                    self.OnVoiceDetected(t, cache=cache)
-                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif not is_final_frame:
-                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
-                else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-            else:
-                pass
-        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
-            cache["stats"].continous_silence_frame_count = 0
-            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-                pass
-            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif not is_final_frame:
-                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
-                else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-            else:
-                pass
-        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
-            cache["stats"].continous_silence_frame_count = 0
-            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    cache["stats"].max_time_out = True
-                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif not is_final_frame:
-                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
-                else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-            else:
-                pass
-        elif AudioChangeState.kChangeStateSil2Sil == state_change:
-            cache["stats"].continous_silence_frame_count += 1
-            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-                # silence timeout, return zero length decision
-                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
-                        cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
-                        or (is_final_frame and cache["stats"].number_end_time_detected == 0):
-                    for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
-                        self.OnSilenceDetected(t, cache=cache)
-                    self.OnVoiceStart(0, True, cache=cache)
-                    self.OnVoiceEnd(0, True, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                else:
-                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
-                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
-            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
-                    lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
-                    if self.vad_opts.do_extend:
-                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
-                        lookback_frame -= 1
-                        lookback_frame = max(0, lookback_frame)
-                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif self.vad_opts.do_extend and not is_final_frame:
-                    if cache["stats"].continous_silence_frame_count <= int(
-                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
-                        self.OnVoiceDetected(cur_frm_idx, cache=cache)
-                else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-            else:
-                pass
-
-        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
-                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
-            self.ResetDetection(cache=cache)
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+	https://arxiv.org/abs/1803.05030
+	"""
+	def __init__(self,
+	             encoder: str = None,
+	             encoder_conf: Optional[Dict] = None,
+	             vad_post_args: Dict[str, Any] = None,
+	             **kwargs,
+	             ):
+		super().__init__()
+		self.vad_opts = VADXOptions(**kwargs)
+		
+		encoder_class = tables.encoder_classes.get(encoder)
+		encoder = encoder_class(**encoder_conf)
+		self.encoder = encoder
+	
+	
+	def ResetDetection(self, cache: dict = {}):
+		cache["stats"].continous_silence_frame_count = 0
+		cache["stats"].latest_confirmed_speech_frame = 0
+		cache["stats"].lastest_confirmed_silence_frame = -1
+		cache["stats"].confirmed_start_frame = -1
+		cache["stats"].confirmed_end_frame = -1
+		cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+		cache["windows_detector"].Reset()
+		cache["stats"].sil_frame = 0
+		cache["stats"].frame_probs = []
+		
+		if cache["stats"].output_data_buf:
+			assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
+			drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
+			real_drop_frames = drop_frames - cache["stats"].last_drop_frames
+			cache["stats"].last_drop_frames = drop_frames
+			cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+			cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
+			cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
+	
+	def ComputeDecibel(self, cache: dict = {}) -> None:
+		frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
+		frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+		if cache["stats"].data_buf_all is None:
+			cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
+			cache["stats"].data_buf = cache["stats"].data_buf_all
+		else:
+			cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
+		for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
+			cache["stats"].decibel.append(
+				10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
+				                0.000001))
+	
+	def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
+		scores = self.encoder(feats, cache=cache["encoder"]).to('cpu')  # return B * T * D
+		assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
+		self.vad_opts.nn_eval_block_size = scores.shape[1]
+		cache["stats"].frm_cnt += scores.shape[1]  # count total frames
+		if cache["stats"].scores is None:
+			cache["stats"].scores = scores  # the first calculation
+		else:
+			cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
+	
+	def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
+		while cache["stats"].data_buf_start_frame < frame_idx:
+			if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
+				cache["stats"].data_buf_start_frame += 1
+				cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
+					self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+	
+	def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
+	                       last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
+		self.PopDataBufTillFrame(start_frm, cache=cache)
+		expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
+		if last_frm_is_end_point:
+			extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
+			                          self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
+			expected_sample_number += int(extra_sample)
+		if end_point_is_sent_end:
+			expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
+		if len(cache["stats"].data_buf) < expected_sample_number:
+			print('error in calling pop data_buf\n')
+		
+		if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
+			cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
+			cache["stats"].output_data_buf[-1].Reset()
+			cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
+			cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
+			cache["stats"].output_data_buf[-1].doa = 0
+		cur_seg = cache["stats"].output_data_buf[-1]
+		if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+			print('warning\n')
+		out_pos = len(cur_seg.buffer)  # cur_seg.buff鐜板湪娌″仛浠讳綍鎿嶄綔
+		data_to_pop = 0
+		if end_point_is_sent_end:
+			data_to_pop = expected_sample_number
+		else:
+			data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+		if data_to_pop > len(cache["stats"].data_buf):
+			print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
+			data_to_pop = len(cache["stats"].data_buf)
+			expected_sample_number = len(cache["stats"].data_buf)
+		
+		cur_seg.doa = 0
+		for sample_cpy_out in range(0, data_to_pop):
+			# cur_seg.buffer[out_pos ++] = data_buf_.back();
+			out_pos += 1
+		for sample_cpy_out in range(data_to_pop, expected_sample_number):
+			# cur_seg.buffer[out_pos++] = data_buf_.back()
+			out_pos += 1
+		if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+			print('Something wrong with the VAD algorithm\n')
+		cache["stats"].data_buf_start_frame += frm_cnt
+		cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
+		if first_frm_is_start_point:
+			cur_seg.contain_seg_start_point = True
+		if last_frm_is_end_point:
+			cur_seg.contain_seg_end_point = True
+	
+	def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
+		cache["stats"].lastest_confirmed_silence_frame = valid_frame
+		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+			self.PopDataBufTillFrame(valid_frame, cache=cache)
+		# silence_detected_callback_
+		# pass
+	
+	def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
+		cache["stats"].latest_confirmed_speech_frame = valid_frame
+		self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
+	
+	def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
+		if self.vad_opts.do_start_point_detection:
+			pass
+		if cache["stats"].confirmed_start_frame != -1:
+			print('not reset vad properly\n')
+		else:
+			cache["stats"].confirmed_start_frame = start_frame
+		
+		if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+			self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
+	
+	def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
+		for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
+			self.OnVoiceDetected(t, cache=cache)
+		if self.vad_opts.do_end_point_detection:
+			pass
+		if cache["stats"].confirmed_end_frame != -1:
+			print('not reset vad properly\n')
+		else:
+			cache["stats"].confirmed_end_frame = end_frame
+		if not fake_result:
+			cache["stats"].sil_frame = 0
+			self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
+		cache["stats"].number_end_time_detected += 1
+	
+	def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
+		if is_final_frame:
+			self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
+			cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+	
+	def GetLatency(self, cache: dict = {}) -> int:
+		return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
+	
+	def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
+		vad_latency = cache["windows_detector"].GetWinSize()
+		if self.vad_opts.do_extend:
+			vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
+		return vad_latency
+	
+	def GetFrameState(self, t: int, cache: dict = {}):
+		frame_state = FrameState.kFrameStateInvalid
+		cur_decibel = cache["stats"].decibel[t]
+		cur_snr = cur_decibel - cache["stats"].noise_average_decibel
+		# for each frame, calc log posterior probability of each state
+		if cur_decibel < self.vad_opts.decibel_thres:
+			frame_state = FrameState.kFrameStateSil
+			self.DetectOneFrame(frame_state, t, False, cache=cache)
+			return frame_state
+		
+		sum_score = 0.0
+		noise_prob = 0.0
+		assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
+		if len(cache["stats"].sil_pdf_ids) > 0:
+			assert len(cache["stats"].scores) == 1  # 鍙敮鎸乥atch_size = 1鐨勬祴璇�
+			sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
+			sum_score = sum(sil_pdf_scores)
+			noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
+			total_score = 1.0
+			sum_score = total_score - sum_score
+		speech_prob = math.log(sum_score)
+		if self.vad_opts.output_frame_probs:
+			frame_prob = E2EVadFrameProb()
+			frame_prob.noise_prob = noise_prob
+			frame_prob.speech_prob = speech_prob
+			frame_prob.score = sum_score
+			frame_prob.frame_id = t
+			cache["stats"].frame_probs.append(frame_prob)
+		if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
+			if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
+				frame_state = FrameState.kFrameStateSpeech
+			else:
+				frame_state = FrameState.kFrameStateSil
+		else:
+			frame_state = FrameState.kFrameStateSil
+			if cache["stats"].noise_average_decibel < -99.9:
+				cache["stats"].noise_average_decibel = cur_decibel
+			else:
+				cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
+					self.vad_opts.noise_frame_num_used_for_snr
+					- 1)) / self.vad_opts.noise_frame_num_used_for_snr
+		
+		return frame_state
+	
+	def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
+	            is_final: bool = False
+	            ):
+		# if len(cache) == 0:
+		#     self.AllResetDetection()
+		# self.waveform = waveform  # compute decibel for each frame
+		cache["stats"].waveform = waveform
+		self.ComputeDecibel(cache=cache)
+		self.ComputeScores(feats, cache=cache)
+		if not is_final:
+			self.DetectCommonFrames(cache=cache)
+		else:
+			self.DetectLastFrames(cache=cache)
+		segments = []
+		for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
+			segment_batch = []
+			if len(cache["stats"].output_data_buf) > 0:
+				for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
+					if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
+						i].contain_seg_end_point):
+						continue
+					segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
+					segment_batch.append(segment)
+					cache["stats"].output_data_buf_offset += 1  # need update this parameter
+			if segment_batch:
+				segments.append(segment_batch)
+		# if is_final:
+		#     # reset class variables and clear the dict for the next query
+		#     self.AllResetDetection()
+		return segments
+	
+	def init_cache(self, cache: dict = {}, **kwargs):
+		cache["frontend"] = {}
+		cache["prev_samples"] = torch.empty(0)
+		cache["encoder"] = {}
+		windows_detector = WindowDetector(self.vad_opts.window_size_ms,
+		                                  self.vad_opts.sil_to_speech_time_thres,
+		                                  self.vad_opts.speech_to_sil_time_thres,
+		                                  self.vad_opts.frame_in_ms)
+		windows_detector.Reset()
+		
+		stats = Stats(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
+		              max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres,
+		              speech_noise_thres=self.vad_opts.speech_noise_thres
+		              )
+		cache["windows_detector"] = windows_detector
+		cache["stats"] = stats
+		return cache
+	
+	def inference(self,
+	              data_in,
+	              data_lengths=None,
+	              key: list = None,
+	              tokenizer=None,
+	              frontend=None,
+	              cache: dict = {},
+	              **kwargs,
+	              ):
+		
+		if len(cache) == 0:
+			self.init_cache(cache, **kwargs)
+		
+		meta_data = {}
+		chunk_size = kwargs.get("chunk_size", 60000) # 50ms
+		chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
+		
+		time1 = time.perf_counter()
+		cfg = {"is_final": kwargs.get("is_final", False)}
+		audio_sample_list = load_audio_text_image_video(data_in,
+		                                                fs=frontend.fs,
+		                                                audio_fs=kwargs.get("fs", 16000),
+		                                                data_type=kwargs.get("data_type", "sound"),
+		                                                tokenizer=tokenizer,
+		                                                cache=cfg,
+		                                                )
+		_is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
+		
+		time2 = time.perf_counter()
+		meta_data["load_data"] = f"{time2 - time1:0.3f}"
+		assert len(audio_sample_list) == 1, "batch_size must be set 1"
+		
+		audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+		
+		n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
+		m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
+		segments = []
+		for i in range(n):
+			kwargs["is_final"] = _is_final and i == n - 1
+			audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
+			
+			# extract fbank feats
+			speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+			                                       frontend=frontend, cache=cache["frontend"],
+			                                       is_final=kwargs["is_final"])
+			time3 = time.perf_counter()
+			meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+			meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+			speech = speech.to(device=kwargs["device"])
+			speech_lengths = speech_lengths.to(device=kwargs["device"])
+			
+			batch = {
+				"feats": speech,
+				"waveform": cache["frontend"]["waveforms"],
+				"is_final": kwargs["is_final"],
+				"cache": cache
+			}
+			segments_i = self.forward(**batch)
+			if len(segments_i) > 0:
+				segments.extend(*segments_i)
+		
+		
+		cache["prev_samples"] = audio_sample[:-m]
+		if _is_final:
+			cache = {}
+		
+		ibest_writer = None
+		if ibest_writer is None and kwargs.get("output_dir") is not None:
+			writer = DatadirWriter(kwargs.get("output_dir"))
+			ibest_writer = writer[f"{1}best_recog"]
+		
+		results = []
+		result_i = {"key": key[0], "value": segments}
+		if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+			result_i = json.dumps(result_i)
+		
+		results.append(result_i)
+		
+		if ibest_writer is not None:
+			ibest_writer["text"][key[0]] = segments
+		
+		
+		return results, meta_data
+	
+	
+	def DetectCommonFrames(self, cache: dict = {}) -> int:
+		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+			return 0
+		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+			frame_state = FrameState.kFrameStateInvalid
+			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+			self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
+		
+		return 0
+	
+	def DetectLastFrames(self, cache: dict = {}) -> int:
+		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+			return 0
+		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+			frame_state = FrameState.kFrameStateInvalid
+			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+			if i != 0:
+				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
+			else:
+				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
+		
+		return 0
+	
+	def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
+		tmp_cur_frm_state = FrameState.kFrameStateInvalid
+		if cur_frm_state == FrameState.kFrameStateSpeech:
+			if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
+				tmp_cur_frm_state = FrameState.kFrameStateSpeech
+			else:
+				tmp_cur_frm_state = FrameState.kFrameStateSil
+		elif cur_frm_state == FrameState.kFrameStateSil:
+			tmp_cur_frm_state = FrameState.kFrameStateSil
+		state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
+		frm_shift_in_ms = self.vad_opts.frame_in_ms
+		if AudioChangeState.kChangeStateSil2Speech == state_change:
+			silence_frame_count = cache["stats"].continous_silence_frame_count
+			cache["stats"].continous_silence_frame_count = 0
+			cache["stats"].pre_end_silence_detected = False
+			start_frame = 0
+			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+				start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
+				self.OnVoiceStart(start_frame, cache=cache)
+				cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
+				for t in range(start_frame + 1, cur_frm_idx + 1):
+					self.OnVoiceDetected(t, cache=cache)
+			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+				for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
+					self.OnVoiceDetected(t, cache=cache)
+				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
+					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				elif not is_final_frame:
+					self.OnVoiceDetected(cur_frm_idx, cache=cache)
+				else:
+					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+			else:
+				pass
+		elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
+			cache["stats"].continous_silence_frame_count = 0
+			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+				pass
+			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
+					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				elif not is_final_frame:
+					self.OnVoiceDetected(cur_frm_idx, cache=cache)
+				else:
+					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+			else:
+				pass
+		elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
+			cache["stats"].continous_silence_frame_count = 0
+			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
+					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+					cache["stats"].max_time_out = True
+					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				elif not is_final_frame:
+					self.OnVoiceDetected(cur_frm_idx, cache=cache)
+				else:
+					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+			else:
+				pass
+		elif AudioChangeState.kChangeStateSil2Sil == state_change:
+			cache["stats"].continous_silence_frame_count += 1
+			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+				# silence timeout, return zero length decision
+				if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
+					cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+					or (is_final_frame and cache["stats"].number_end_time_detected == 0):
+					for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
+						self.OnSilenceDetected(t, cache=cache)
+					self.OnVoiceStart(0, True, cache=cache)
+					self.OnVoiceEnd(0, True, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				else:
+					if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
+						self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
+			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+				if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
+					lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
+					if self.vad_opts.do_extend:
+						lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
+						lookback_frame -= 1
+						lookback_frame = max(0, lookback_frame)
+					self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
+					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+				elif self.vad_opts.do_extend and not is_final_frame:
+					if cache["stats"].continous_silence_frame_count <= int(
+						self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
+						self.OnVoiceDetected(cur_frm_idx, cache=cache)
+				else:
+					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+			else:
+				pass
+		
+		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
+			self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
+			self.ResetDetection(cache=cache)
 
 
 

--
Gitblit v1.9.1