From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fix bug: hotwords

---
 funasr/models/fsmn_vad_streaming/model.py | 1594 +++++++++++++++++++++++++++++++---------------------------
 1 file changed, 862 insertions(+), 732 deletions(-)

diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 7d71faa..04689be 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -19,768 +19,898 @@
 
 
 class VadStateMachine(Enum):
-	kVadInStateStartPointNotDetected = 1
-	kVadInStateInSpeechSegment = 2
-	kVadInStateEndPointDetected = 3
+    kVadInStateStartPointNotDetected = 1
+    kVadInStateInSpeechSegment = 2
+    kVadInStateEndPointDetected = 3
 
 
 class FrameState(Enum):
-	kFrameStateInvalid = -1
-	kFrameStateSpeech = 1
-	kFrameStateSil = 0
+    kFrameStateInvalid = -1
+    kFrameStateSpeech = 1
+    kFrameStateSil = 0
 
 
 # final voice/unvoice state per frame
 class AudioChangeState(Enum):
-	kChangeStateSpeech2Speech = 0
-	kChangeStateSpeech2Sil = 1
-	kChangeStateSil2Sil = 2
-	kChangeStateSil2Speech = 3
-	kChangeStateNoBegin = 4
-	kChangeStateInvalid = 5
+    kChangeStateSpeech2Speech = 0
+    kChangeStateSpeech2Sil = 1
+    kChangeStateSil2Sil = 2
+    kChangeStateSil2Speech = 3
+    kChangeStateNoBegin = 4
+    kChangeStateInvalid = 5
 
 
 class VadDetectMode(Enum):
-	kVadSingleUtteranceDetectMode = 0
-	kVadMutipleUtteranceDetectMode = 1
+    kVadSingleUtteranceDetectMode = 0
+    kVadMutipleUtteranceDetectMode = 1
 
 
 class VADXOptions:
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-	https://arxiv.org/abs/1803.05030
-	"""
-	
-	def __init__(
-		self,
-		sample_rate: int = 16000,
-		detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
-		snr_mode: int = 0,
-		max_end_silence_time: int = 800,
-		max_start_silence_time: int = 3000,
-		do_start_point_detection: bool = True,
-		do_end_point_detection: bool = True,
-		window_size_ms: int = 200,
-		sil_to_speech_time_thres: int = 150,
-		speech_to_sil_time_thres: int = 150,
-		speech_2_noise_ratio: float = 1.0,
-		do_extend: int = 1,
-		lookback_time_start_point: int = 200,
-		lookahead_time_end_point: int = 100,
-		max_single_segment_time: int = 60000,
-		nn_eval_block_size: int = 8,
-		dcd_block_size: int = 4,
-		snr_thres: int = -100.0,
-		noise_frame_num_used_for_snr: int = 100,
-		decibel_thres: int = -100.0,
-		speech_noise_thres: float = 0.6,
-		fe_prior_thres: float = 1e-4,
-		silence_pdf_num: int = 1,
-		sil_pdf_ids: List[int] = [0],
-		speech_noise_thresh_low: float = -0.1,
-		speech_noise_thresh_high: float = 0.3,
-		output_frame_probs: bool = False,
-		frame_in_ms: int = 10,
-		frame_length_ms: int = 25,
-		**kwargs,
-	):
-		self.sample_rate = sample_rate
-		self.detect_mode = detect_mode
-		self.snr_mode = snr_mode
-		self.max_end_silence_time = max_end_silence_time
-		self.max_start_silence_time = max_start_silence_time
-		self.do_start_point_detection = do_start_point_detection
-		self.do_end_point_detection = do_end_point_detection
-		self.window_size_ms = window_size_ms
-		self.sil_to_speech_time_thres = sil_to_speech_time_thres
-		self.speech_to_sil_time_thres = speech_to_sil_time_thres
-		self.speech_2_noise_ratio = speech_2_noise_ratio
-		self.do_extend = do_extend
-		self.lookback_time_start_point = lookback_time_start_point
-		self.lookahead_time_end_point = lookahead_time_end_point
-		self.max_single_segment_time = max_single_segment_time
-		self.nn_eval_block_size = nn_eval_block_size
-		self.dcd_block_size = dcd_block_size
-		self.snr_thres = snr_thres
-		self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
-		self.decibel_thres = decibel_thres
-		self.speech_noise_thres = speech_noise_thres
-		self.fe_prior_thres = fe_prior_thres
-		self.silence_pdf_num = silence_pdf_num
-		self.sil_pdf_ids = sil_pdf_ids
-		self.speech_noise_thresh_low = speech_noise_thresh_low
-		self.speech_noise_thresh_high = speech_noise_thresh_high
-		self.output_frame_probs = output_frame_probs
-		self.frame_in_ms = frame_in_ms
-		self.frame_length_ms = frame_length_ms
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
+    def __init__(
+        self,
+        sample_rate: int = 16000,
+        detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+        snr_mode: int = 0,
+        max_end_silence_time: int = 800,
+        max_start_silence_time: int = 3000,
+        do_start_point_detection: bool = True,
+        do_end_point_detection: bool = True,
+        window_size_ms: int = 200,
+        sil_to_speech_time_thres: int = 150,
+        speech_to_sil_time_thres: int = 150,
+        speech_2_noise_ratio: float = 1.0,
+        do_extend: int = 1,
+        lookback_time_start_point: int = 200,
+        lookahead_time_end_point: int = 100,
+        max_single_segment_time: int = 60000,
+        nn_eval_block_size: int = 8,
+        dcd_block_size: int = 4,
+        snr_thres: float = -100.0,
+        noise_frame_num_used_for_snr: int = 100,
+        decibel_thres: float = -100.0,
+        speech_noise_thres: float = 0.6,
+        fe_prior_thres: float = 1e-4,
+        silence_pdf_num: int = 1,
+        sil_pdf_ids: List[int] = [0],
+        speech_noise_thresh_low: float = -0.1,
+        speech_noise_thresh_high: float = 0.3,
+        output_frame_probs: bool = False,
+        frame_in_ms: int = 10,
+        frame_length_ms: int = 25,
+        **kwargs,
+    ):
+        self.sample_rate = sample_rate
+        self.detect_mode = detect_mode
+        self.snr_mode = snr_mode
+        self.max_end_silence_time = max_end_silence_time
+        self.max_start_silence_time = max_start_silence_time
+        self.do_start_point_detection = do_start_point_detection
+        self.do_end_point_detection = do_end_point_detection
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time_thres = sil_to_speech_time_thres
+        self.speech_to_sil_time_thres = speech_to_sil_time_thres
+        self.speech_2_noise_ratio = speech_2_noise_ratio
+        self.do_extend = do_extend
+        self.lookback_time_start_point = lookback_time_start_point
+        self.lookahead_time_end_point = lookahead_time_end_point
+        self.max_single_segment_time = max_single_segment_time
+        self.nn_eval_block_size = nn_eval_block_size
+        self.dcd_block_size = dcd_block_size
+        self.snr_thres = snr_thres
+        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+        self.decibel_thres = decibel_thres
+        self.speech_noise_thres = speech_noise_thres
+        self.fe_prior_thres = fe_prior_thres
+        self.silence_pdf_num = silence_pdf_num
+        self.sil_pdf_ids = sil_pdf_ids
+        self.speech_noise_thresh_low = speech_noise_thresh_low
+        self.speech_noise_thresh_high = speech_noise_thresh_high
+        self.output_frame_probs = output_frame_probs
+        self.frame_in_ms = frame_in_ms
+        self.frame_length_ms = frame_length_ms
 
 
 class E2EVadSpeechBufWithDoa(object):
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-	https://arxiv.org/abs/1803.05030
-	"""
-	
-	def __init__(self):
-		self.start_ms = 0
-		self.end_ms = 0
-		self.buffer = []
-		self.contain_seg_start_point = False
-		self.contain_seg_end_point = False
-		self.doa = 0
-	
-	def Reset(self):
-		self.start_ms = 0
-		self.end_ms = 0
-		self.buffer = []
-		self.contain_seg_start_point = False
-		self.contain_seg_end_point = False
-		self.doa = 0
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
+    def __init__(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+    def Reset(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
 
 
 class E2EVadFrameProb(object):
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-	https://arxiv.org/abs/1803.05030
-	"""
-	
-	def __init__(self):
-		self.noise_prob = 0.0
-		self.speech_prob = 0.0
-		self.score = 0.0
-		self.frame_id = 0
-		self.frm_state = 0
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
+    def __init__(self):
+        self.noise_prob = 0.0
+        self.speech_prob = 0.0
+        self.score = 0.0
+        self.frame_id = 0
+        self.frm_state = 0
 
 
 class WindowDetector(object):
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-	https://arxiv.org/abs/1803.05030
-	"""
-	
-	def __init__(self, window_size_ms: int,
-	             sil_to_speech_time: int,
-	             speech_to_sil_time: int,
-	             frame_size_ms: int):
-		self.window_size_ms = window_size_ms
-		self.sil_to_speech_time = sil_to_speech_time
-		self.speech_to_sil_time = speech_to_sil_time
-		self.frame_size_ms = frame_size_ms
-		
-		self.win_size_frame = int(window_size_ms / frame_size_ms)
-		self.win_sum = 0
-		self.win_state = [0] * self.win_size_frame  # initialize the window
-		
-		self.cur_win_pos = 0
-		self.pre_frame_state = FrameState.kFrameStateSil
-		self.cur_frame_state = FrameState.kFrameStateSil
-		self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
-		self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
-		
-		self.voice_last_frame_count = 0
-		self.noise_last_frame_count = 0
-		self.hydre_frame_count = 0
-	
-	def Reset(self) -> None:
-		self.cur_win_pos = 0
-		self.win_sum = 0
-		self.win_state = [0] * self.win_size_frame
-		self.pre_frame_state = FrameState.kFrameStateSil
-		self.cur_frame_state = FrameState.kFrameStateSil
-		self.voice_last_frame_count = 0
-		self.noise_last_frame_count = 0
-		self.hydre_frame_count = 0
-	
-	def GetWinSize(self) -> int:
-		return int(self.win_size_frame)
-	
-	def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
-		cur_frame_state = FrameState.kFrameStateSil
-		if frameState == FrameState.kFrameStateSpeech:
-			cur_frame_state = 1
-		elif frameState == FrameState.kFrameStateSil:
-			cur_frame_state = 0
-		else:
-			return AudioChangeState.kChangeStateInvalid
-		self.win_sum -= self.win_state[self.cur_win_pos]
-		self.win_sum += cur_frame_state
-		self.win_state[self.cur_win_pos] = cur_frame_state
-		self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
-		
-		if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
-			self.pre_frame_state = FrameState.kFrameStateSpeech
-			return AudioChangeState.kChangeStateSil2Speech
-		
-		if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
-			self.pre_frame_state = FrameState.kFrameStateSil
-			return AudioChangeState.kChangeStateSpeech2Sil
-		
-		if self.pre_frame_state == FrameState.kFrameStateSil:
-			return AudioChangeState.kChangeStateSil2Sil
-		if self.pre_frame_state == FrameState.kFrameStateSpeech:
-			return AudioChangeState.kChangeStateSpeech2Speech
-		return AudioChangeState.kChangeStateInvalid
-	
-	def FrameSizeMs(self) -> int:
-		return int(self.frame_size_ms)
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
+    def __init__(
+        self,
+        window_size_ms: int,
+        sil_to_speech_time: int,
+        speech_to_sil_time: int,
+        frame_size_ms: int,
+    ):
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time = sil_to_speech_time
+        self.speech_to_sil_time = speech_to_sil_time
+        self.frame_size_ms = frame_size_ms
+
+        self.win_size_frame = int(window_size_ms / frame_size_ms)
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame  # initialize the window
+
+        self.cur_win_pos = 0
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def Reset(self) -> None:
+        self.cur_win_pos = 0
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def GetWinSize(self) -> int:
+        return int(self.win_size_frame)
+
+    def DetectOneFrame(
+        self, frameState: FrameState, frame_count: int, cache: dict = {}
+    ) -> AudioChangeState:
+        cur_frame_state = FrameState.kFrameStateSil
+        if frameState == FrameState.kFrameStateSpeech:
+            cur_frame_state = 1
+        elif frameState == FrameState.kFrameStateSil:
+            cur_frame_state = 0
+        else:
+            return AudioChangeState.kChangeStateInvalid
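+        # sliding-window smoothing: win_state is a ring buffer of per-frame
+        # speech/silence flags and win_sum counts the speech frames inside the
+        # window, so a state change is only reported once the count crosses the
+        # sil-to-speech or speech-to-sil frame-count threshold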
+        self.win_sum -= self.win_state[self.cur_win_pos]
+        self.win_sum += cur_frame_state
+        self.win_state[self.cur_win_pos] = cur_frame_state
+        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+
+        if (
+            self.pre_frame_state == FrameState.kFrameStateSil
+            and self.win_sum >= self.sil_to_speech_frmcnt_thres
+        ):
+            self.pre_frame_state = FrameState.kFrameStateSpeech
+            return AudioChangeState.kChangeStateSil2Speech
+
+        if (
+            self.pre_frame_state == FrameState.kFrameStateSpeech
+            and self.win_sum <= self.speech_to_sil_frmcnt_thres
+        ):
+            self.pre_frame_state = FrameState.kFrameStateSil
+            return AudioChangeState.kChangeStateSpeech2Sil
+
+        if self.pre_frame_state == FrameState.kFrameStateSil:
+            return AudioChangeState.kChangeStateSil2Sil
+        if self.pre_frame_state == FrameState.kFrameStateSpeech:
+            return AudioChangeState.kChangeStateSpeech2Speech
+        return AudioChangeState.kChangeStateInvalid
+
+    def FrameSizeMs(self) -> int:
+        return int(self.frame_size_ms)
 
 
 class Stats(object):
-	def __init__(self,
-	             sil_pdf_ids,
-	             max_end_sil_frame_cnt_thresh,
-	             speech_noise_thres,
-	             ):
-		self.data_buf_start_frame = 0
-		self.frm_cnt = 0
-		self.latest_confirmed_speech_frame = 0
-		self.lastest_confirmed_silence_frame = -1
-		self.continous_silence_frame_count = 0
-		self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
-		self.confirmed_start_frame = -1
-		self.confirmed_end_frame = -1
-		self.number_end_time_detected = 0
-		self.sil_frame = 0
-		self.sil_pdf_ids = sil_pdf_ids
-		self.noise_average_decibel = -100.0
-		self.pre_end_silence_detected = False
-		self.next_seg = True
-		
-		self.output_data_buf = []
-		self.output_data_buf_offset = 0
-		self.frame_probs = []
-		self.max_end_sil_frame_cnt_thresh = max_end_sil_frame_cnt_thresh
-		self.speech_noise_thres = speech_noise_thres
-		self.scores = None
-		self.max_time_out = False
-		self.decibel = []
-		self.data_buf = None
-		self.data_buf_all = None
-		self.waveform = None
-		self.last_drop_frames = 0
+    def __init__(
+        self,
+        sil_pdf_ids,
+        max_end_sil_frame_cnt_thresh,
+        speech_noise_thres,
+    ):
+        self.data_buf_start_frame = 0
+        self.frm_cnt = 0
+        self.latest_confirmed_speech_frame = 0
+        self.lastest_confirmed_silence_frame = -1
+        self.continous_silence_frame_count = 0
+        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+        self.confirmed_start_frame = -1
+        self.confirmed_end_frame = -1
+        self.number_end_time_detected = 0
+        self.sil_frame = 0
+        self.sil_pdf_ids = sil_pdf_ids
+        self.noise_average_decibel = -100.0
+        self.pre_end_silence_detected = False
+        self.next_seg = True
+
+        self.output_data_buf = []
+        self.output_data_buf_offset = 0
+        self.frame_probs = []
+        self.max_end_sil_frame_cnt_thresh = max_end_sil_frame_cnt_thresh
+        self.speech_noise_thres = speech_noise_thres
+        self.scores = None
+        self.max_time_out = False
+        self.decibel = []
+        self.data_buf = None
+        self.data_buf_all = None
+        self.waveform = None
+        self.last_drop_frames = 0
 
 
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
-	https://arxiv.org/abs/1803.05030
-	"""
-	
-	def __init__(self,
-	             encoder: str = None,
-	             encoder_conf: Optional[Dict] = None,
-	             vad_post_args: Dict[str, Any] = None,
-	             **kwargs,
-	             ):
-		super().__init__()
-		self.vad_opts = VADXOptions(**kwargs)
-		
-		encoder_class = tables.encoder_classes.get(encoder)
-		encoder = encoder_class(**encoder_conf)
-		self.encoder = encoder
-		self.encoder_conf = encoder_conf
-	
-	def ResetDetection(self, cache: dict = {}):
-		cache["stats"].continous_silence_frame_count = 0
-		cache["stats"].latest_confirmed_speech_frame = 0
-		cache["stats"].lastest_confirmed_silence_frame = -1
-		cache["stats"].confirmed_start_frame = -1
-		cache["stats"].confirmed_end_frame = -1
-		cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
-		cache["windows_detector"].Reset()
-		cache["stats"].sil_frame = 0
-		cache["stats"].frame_probs = []
-		
-		if cache["stats"].output_data_buf:
-			assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
-			drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
-			real_drop_frames = drop_frames - cache["stats"].last_drop_frames
-			cache["stats"].last_drop_frames = drop_frames
-			cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(
-				self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
-			cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
-			cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
-	
-	def ComputeDecibel(self, cache: dict = {}) -> None:
-		frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
-		frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
-		if cache["stats"].data_buf_all is None:
-			cache["stats"].data_buf_all = cache["stats"].waveform[
-				0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
-			cache["stats"].data_buf = cache["stats"].data_buf_all
-		else:
-			cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
-		for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
-			cache["stats"].decibel.append(
-				10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
-				                0.000001))
-	
-	def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
-		scores = self.encoder(feats, cache=cache["encoder"]).to('cpu')  # return B * T * D
-		assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
-		self.vad_opts.nn_eval_block_size = scores.shape[1]
-		cache["stats"].frm_cnt += scores.shape[1]  # count total frames
-		if cache["stats"].scores is None:
-			cache["stats"].scores = scores  # the first calculation
-		else:
-			cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
-	
-	def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
-		while cache["stats"].data_buf_start_frame < frame_idx:
-			if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
-				cache["stats"].data_buf_start_frame += 1
-				cache["stats"].data_buf = cache["stats"].data_buf_all[
-				                          (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
-					                          self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
-	
-	def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
-	                       last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
-		self.PopDataBufTillFrame(start_frm, cache=cache)
-		expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
-		if last_frm_is_end_point:
-			extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
-			                          self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
-			expected_sample_number += int(extra_sample)
-		if end_point_is_sent_end:
-			expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
-		if len(cache["stats"].data_buf) < expected_sample_number:
-			print('error in calling pop data_buf\n')
-		
-		if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
-			cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
-			cache["stats"].output_data_buf[-1].Reset()
-			cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
-			cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
-			cache["stats"].output_data_buf[-1].doa = 0
-		cur_seg = cache["stats"].output_data_buf[-1]
-		if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
-			print('warning\n')
-		out_pos = len(cur_seg.buffer)  # nothing is done to cur_seg.buffer here yet
-		data_to_pop = 0
-		if end_point_is_sent_end:
-			data_to_pop = expected_sample_number
-		else:
-			data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
-		if data_to_pop > len(cache["stats"].data_buf):
-			print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
-			data_to_pop = len(cache["stats"].data_buf)
-			expected_sample_number = len(cache["stats"].data_buf)
-		
-		cur_seg.doa = 0
-		for sample_cpy_out in range(0, data_to_pop):
-			# cur_seg.buffer[out_pos ++] = data_buf_.back();
-			out_pos += 1
-		for sample_cpy_out in range(data_to_pop, expected_sample_number):
-			# cur_seg.buffer[out_pos++] = data_buf_.back()
-			out_pos += 1
-		if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
-			print('Something wrong with the VAD algorithm\n')
-		cache["stats"].data_buf_start_frame += frm_cnt
-		cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
-		if first_frm_is_start_point:
-			cur_seg.contain_seg_start_point = True
-		if last_frm_is_end_point:
-			cur_seg.contain_seg_end_point = True
-	
-	def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
-		cache["stats"].lastest_confirmed_silence_frame = valid_frame
-		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-			self.PopDataBufTillFrame(valid_frame, cache=cache)
-	
-	# silence_detected_callback_
-	# pass
-	
-	def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
-		cache["stats"].latest_confirmed_speech_frame = valid_frame
-		self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
-	
-	def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
-		if self.vad_opts.do_start_point_detection:
-			pass
-		if cache["stats"].confirmed_start_frame != -1:
-			print('not reset vad properly\n')
-		else:
-			cache["stats"].confirmed_start_frame = start_frame
-		
-		if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-			self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
-	
-	def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
-		for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
-			self.OnVoiceDetected(t, cache=cache)
-		if self.vad_opts.do_end_point_detection:
-			pass
-		if cache["stats"].confirmed_end_frame != -1:
-			print('not reset vad properly\n')
-		else:
-			cache["stats"].confirmed_end_frame = end_frame
-		if not fake_result:
-			cache["stats"].sil_frame = 0
-			self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
-		cache["stats"].number_end_time_detected += 1
-	
-	def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
-		if is_final_frame:
-			self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
-			cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-	
-	def GetLatency(self, cache: dict = {}) -> int:
-		return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
-	
-	def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
-		vad_latency = cache["windows_detector"].GetWinSize()
-		if self.vad_opts.do_extend:
-			vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
-		return vad_latency
-	
-	def GetFrameState(self, t: int, cache: dict = {}):
-		frame_state = FrameState.kFrameStateInvalid
-		cur_decibel = cache["stats"].decibel[t]
-		cur_snr = cur_decibel - cache["stats"].noise_average_decibel
-		# for each frame, calc log posterior probability of each state
-		if cur_decibel < self.vad_opts.decibel_thres:
-			frame_state = FrameState.kFrameStateSil
-			self.DetectOneFrame(frame_state, t, False, cache=cache)
-			return frame_state
-		
-		sum_score = 0.0
-		noise_prob = 0.0
-		assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
-		if len(cache["stats"].sil_pdf_ids) > 0:
-			assert len(cache["stats"].scores) == 1  # only batch_size = 1 is supported
-			sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
-			sum_score = sum(sil_pdf_scores)
-			noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
-			total_score = 1.0
-			sum_score = total_score - sum_score
-		speech_prob = math.log(sum_score)
-		if self.vad_opts.output_frame_probs:
-			frame_prob = E2EVadFrameProb()
-			frame_prob.noise_prob = noise_prob
-			frame_prob.speech_prob = speech_prob
-			frame_prob.score = sum_score
-			frame_prob.frame_id = t
-			cache["stats"].frame_probs.append(frame_prob)
-		if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
-			if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
-				frame_state = FrameState.kFrameStateSpeech
-			else:
-				frame_state = FrameState.kFrameStateSil
-		else:
-			frame_state = FrameState.kFrameStateSil
-			if cache["stats"].noise_average_decibel < -99.9:
-				cache["stats"].noise_average_decibel = cur_decibel
-			else:
-				cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
-					self.vad_opts.noise_frame_num_used_for_snr
-					- 1)) / self.vad_opts.noise_frame_num_used_for_snr
-		
-		return frame_state
-	
-	def forward(self, feats: torch.Tensor,
-	            waveform: torch.tensor,
-	            cache: dict = {},
-	            is_final: bool = False,
-	            **kwargs,
-	            ):
-		# if len(cache) == 0:
-		#     self.AllResetDetection()
-		# self.waveform = waveform  # compute decibel for each frame
-		cache["stats"].waveform = waveform
-		is_streaming_input = kwargs.get("is_streaming_input", True)
-		self.ComputeDecibel(cache=cache)
-		self.ComputeScores(feats, cache=cache)
-		if not is_final:
-			self.DetectCommonFrames(cache=cache)
-		else:
-			self.DetectLastFrames(cache=cache)
-		segments = []
-		for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
-			segment_batch = []
-			if len(cache["stats"].output_data_buf) > 0:
-				for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
-					if is_streaming_input: # in this case, return [beg, -1], [], [-1, end], [beg, end]
-						if not cache["stats"].output_data_buf[i].contain_seg_start_point:
-							continue
-						if not cache["stats"].next_seg and not cache["stats"].output_data_buf[i].contain_seg_end_point:
-							continue
-						start_ms = cache["stats"].output_data_buf[i].start_ms if cache["stats"].next_seg else -1
-						if cache["stats"].output_data_buf[i].contain_seg_end_point:
-							end_ms = cache["stats"].output_data_buf[i].end_ms
-							cache["stats"].next_seg = True
-							cache["stats"].output_data_buf_offset += 1
-						else:
-							end_ms = -1
-							cache["stats"].next_seg = False
-						segment = [start_ms, end_ms]
-						
-					else: # in this case, return [beg, end]
-						
-						if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not
-						cache["stats"].output_data_buf[
-							i].contain_seg_end_point):
-							continue
-						segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
-						cache["stats"].output_data_buf_offset += 1  # need update this parameter
-					
-					segment_batch.append(segment)
-					
-			if segment_batch:
-				segments.append(segment_batch)
-		# if is_final:
-		#     # reset class variables and clear the dict for the next query
-		#     self.AllResetDetection()
-		return segments
-	
-	def init_cache(self, cache: dict = {}, **kwargs):
-		
-		cache["frontend"] = {}
-		cache["prev_samples"] = torch.empty(0)
-		cache["encoder"] = {}
-		windows_detector = WindowDetector(self.vad_opts.window_size_ms,
-		                                  self.vad_opts.sil_to_speech_time_thres,
-		                                  self.vad_opts.speech_to_sil_time_thres,
-		                                  self.vad_opts.frame_in_ms)
-		windows_detector.Reset()
-		
-		stats = Stats(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
-		              max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres,
-		              speech_noise_thres=self.vad_opts.speech_noise_thres
-		              )
-		cache["windows_detector"] = windows_detector
-		cache["stats"] = stats
-		return cache
-	
-	def inference(self,
-	              data_in,
-	              data_lengths=None,
-	              key: list = None,
-	              tokenizer=None,
-	              frontend=None,
-	              cache: dict = {},
-	              **kwargs,
-	              ):
-		
-		if len(cache) == 0:
-			self.init_cache(cache, **kwargs)
-		
-		meta_data = {}
-		chunk_size = kwargs.get("chunk_size", 60000)  # 50ms
-		chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
-		
-		time1 = time.perf_counter()
-		is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True)
-		is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
-		cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
-		audio_sample_list = load_audio_text_image_video(data_in,
-		                                                fs=frontend.fs,
-		                                                audio_fs=kwargs.get("fs", 16000),
-		                                                data_type=kwargs.get("data_type", "sound"),
-		                                                tokenizer=tokenizer,
-		                                                cache=cfg,
-		                                                )
-		_is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
-		is_streaming_input = cfg["is_streaming_input"]
-		time2 = time.perf_counter()
-		meta_data["load_data"] = f"{time2 - time1:0.3f}"
-		assert len(audio_sample_list) == 1, "batch_size must be set 1"
-		
-		audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
-		
-		n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
-		m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
-		segments = []
-		for i in range(n):
-			kwargs["is_final"] = _is_final and i == n - 1
-			audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
-			
-			# extract fbank feats
-			speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
-			                                       frontend=frontend, cache=cache["frontend"],
-			                                       is_final=kwargs["is_final"])
-			time3 = time.perf_counter()
-			meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-			meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-			speech = speech.to(device=kwargs["device"])
-			speech_lengths = speech_lengths.to(device=kwargs["device"])
-			
-			batch = {
-				"feats": speech,
-				"waveform": cache["frontend"]["waveforms"],
-				"is_final": kwargs["is_final"],
-				"cache": cache,
-				"is_streaming_input": is_streaming_input
-			}
-			segments_i = self.forward(**batch)
-			if len(segments_i) > 0:
-				segments.extend(*segments_i)
-		
-		cache["prev_samples"] = audio_sample[:-m]
-		if _is_final:
-			self.init_cache(cache)
-		
-		ibest_writer = None
-		if kwargs.get("output_dir") is not None:
-			if not hasattr(self, "writer"):
-				self.writer = DatadirWriter(kwargs.get("output_dir"))
-			ibest_writer = self.writer[f"{1}best_recog"]
-		
-		results = []
-		result_i = {"key": key[0], "value": segments}
-		# if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
-		# 	result_i = json.dumps(result_i)
-		
-		results.append(result_i)
-		
-		if ibest_writer is not None:
-			ibest_writer["text"][key[0]] = segments
-		
-		return results, meta_data
-	
-	def export(self, **kwargs):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
 
-		from .export_meta import export_rebuild_model
-		models = export_rebuild_model(model=self, **kwargs)
-		return models
+    def __init__(
+        self,
+        encoder: str = None,
+        encoder_conf: Optional[Dict] = None,
+        vad_post_args: Dict[str, Any] = None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.vad_opts = VADXOptions(**kwargs)
 
-	def DetectCommonFrames(self, cache: dict = {}) -> int:
-		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
-			return 0
-		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
-			frame_state = FrameState.kFrameStateInvalid
-			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
-			                                 cache=cache)
-			self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
-		
-		return 0
-	
-	def DetectLastFrames(self, cache: dict = {}) -> int:
-		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
-			return 0
-		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
-			frame_state = FrameState.kFrameStateInvalid
-			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
-			                                 cache=cache)
-			if i != 0:
-				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
-			else:
-				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
-		
-		return 0
-	
-	def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
-	                   cache: dict = {}) -> None:
-		tmp_cur_frm_state = FrameState.kFrameStateInvalid
-		if cur_frm_state == FrameState.kFrameStateSpeech:
-			if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
-				tmp_cur_frm_state = FrameState.kFrameStateSpeech
-			else:
-				tmp_cur_frm_state = FrameState.kFrameStateSil
-		elif cur_frm_state == FrameState.kFrameStateSil:
-			tmp_cur_frm_state = FrameState.kFrameStateSil
-		state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
-		frm_shift_in_ms = self.vad_opts.frame_in_ms
-		if AudioChangeState.kChangeStateSil2Speech == state_change:
-			silence_frame_count = cache["stats"].continous_silence_frame_count
-			cache["stats"].continous_silence_frame_count = 0
-			cache["stats"].pre_end_silence_detected = False
-			start_frame = 0
-			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-				start_frame = max(cache["stats"].data_buf_start_frame,
-				                  cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
-				self.OnVoiceStart(start_frame, cache=cache)
-				cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
-				for t in range(start_frame + 1, cur_frm_idx + 1):
-					self.OnVoiceDetected(t, cache=cache)
-			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-				for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
-					self.OnVoiceDetected(t, cache=cache)
-				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-				elif not is_final_frame:
-					self.OnVoiceDetected(cur_frm_idx, cache=cache)
-				else:
-					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-			else:
-				pass
-		elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
-			cache["stats"].continous_silence_frame_count = 0
-			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-				pass
-			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-				elif not is_final_frame:
-					self.OnVoiceDetected(cur_frm_idx, cache=cache)
-				else:
-					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-			else:
-				pass
-		elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
-			cache["stats"].continous_silence_frame_count = 0
-			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-				if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-					cache["stats"].max_time_out = True
-					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-				elif not is_final_frame:
-					self.OnVoiceDetected(cur_frm_idx, cache=cache)
-				else:
-					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-			else:
-				pass
-		elif AudioChangeState.kChangeStateSil2Sil == state_change:
-			cache["stats"].continous_silence_frame_count += 1
-			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-				# silence timeout, return zero length decision
-				if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
-					cache[
-						"stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
-					or (is_final_frame and cache["stats"].number_end_time_detected == 0):
-					for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
-						self.OnSilenceDetected(t, cache=cache)
-					self.OnVoiceStart(0, True, cache=cache)
-					self.OnVoiceEnd(0, True, False, cache=cache)
-					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-				else:
-					if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
-						self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
-			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-				if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[
-					"stats"].max_end_sil_frame_cnt_thresh:
-					lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
-					if self.vad_opts.do_extend:
-						lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
-						lookback_frame -= 1
-						lookback_frame = max(0, lookback_frame)
-					self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
-					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-				elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
-					self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-					self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
-					cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-				elif self.vad_opts.do_extend and not is_final_frame:
-					if cache["stats"].continous_silence_frame_count <= int(
-						self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
-						self.OnVoiceDetected(cur_frm_idx, cache=cache)
-				else:
-					self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
-			else:
-				pass
-		
-		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
-			self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
-			self.ResetDetection(cache=cache)
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(**encoder_conf)
+        self.encoder = encoder
+        self.encoder_conf = encoder_conf
 
+    def ResetDetection(self, cache: dict = {}):
+        cache["stats"].continous_silence_frame_count = 0
+        cache["stats"].latest_confirmed_speech_frame = 0
+        cache["stats"].lastest_confirmed_silence_frame = -1
+        cache["stats"].confirmed_start_frame = -1
+        cache["stats"].confirmed_end_frame = -1
+        cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+        cache["windows_detector"].Reset()
+        cache["stats"].sil_frame = 0
+        cache["stats"].frame_probs = []
 
+        if cache["stats"].output_data_buf:
+            assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
+            drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
+            real_drop_frames = drop_frames - cache["stats"].last_drop_frames
+            cache["stats"].last_drop_frames = drop_frames
+            cache["stats"].data_buf_all = cache["stats"].data_buf_all[
+                real_drop_frames
+                * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) :
+            ]
+            cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
+            cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
+
+    def ComputeDecibel(self, cache: dict = {}) -> None:
+        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
+        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
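+        # with the 16 kHz / 25 ms / 10 ms defaults this is a 400-sample frame
+        # advanced by a 160-sample shift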
+        if cache["stats"].data_buf_all is None:
+            cache["stats"].data_buf_all = cache["stats"].waveform[
+                0
+            ]  # cache["stats"].data_buf points to cache["stats"].waveform[0]
+            cache["stats"].data_buf = cache["stats"].data_buf_all
+        else:
+            cache["stats"].data_buf_all = torch.cat(
+                (cache["stats"].data_buf_all, cache["stats"].waveform[0])
+            )
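+        # per-frame energy in dB: 10 * log10(sum of squared samples), with a
+        # small constant added so silent frames do not hit log10(0)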
+        for offset in range(
+            0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length
+        ):
+            cache["stats"].decibel.append(
+                10
+                * math.log10(
+                    (cache["stats"].waveform[0][offset : offset + frame_sample_length])
+                    .square()
+                    .sum()
+                    + 0.000001
+                )
+            )
+
+    def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
+        scores = self.encoder(feats, cache=cache["encoder"]).to("cpu")  # return B * T * D
+        assert (
+            scores.shape[1] == feats.shape[1]
+        ), "The shape between feats and scores does not match"
+        self.vad_opts.nn_eval_block_size = scores.shape[1]
+        cache["stats"].frm_cnt += scores.shape[1]  # count total frames
+        if cache["stats"].scores is None:
+            cache["stats"].scores = scores  # the first calculation
+        else:
+            cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
+
+    def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # TODO: re-check this logic
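+        # advance data_buf_start_frame up to frame_idx and re-slice data_buf
+        # from data_buf_all, correcting for the frames already dropped in
+        # ResetDetection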
+        while cache["stats"].data_buf_start_frame < frame_idx:
+            if len(cache["stats"].data_buf) >= int(
+                self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
+            ):
+                cache["stats"].data_buf_start_frame += 1
+                cache["stats"].data_buf = cache["stats"].data_buf_all[
+                    (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames)
+                    * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) :
+                ]
+
+    def PopDataToOutputBuf(
+        self,
+        start_frm: int,
+        frm_cnt: int,
+        first_frm_is_start_point: bool,
+        last_frm_is_end_point: bool,
+        end_point_is_sent_end: bool,
+        cache: dict = {},
+    ) -> None:
+        self.PopDataBufTillFrame(start_frm, cache=cache)
+        expected_sample_number = int(
+            frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
+        )
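+        # when the segment ends here, also keep the tail by which an analysis
+        # frame (frame_length_ms) overhangs the frame shift (frame_in_ms):
+        # 240 samples with the 16 kHz / 25 ms / 10 ms defaults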
+        if last_frm_is_end_point:
+            extra_sample = max(
+                0,
+                int(
+                    self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
+                    - self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
+                ),
+            )
+            expected_sample_number += int(extra_sample)
+        if end_point_is_sent_end:
+            expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
+        if len(cache["stats"].data_buf) < expected_sample_number:
+            print("error in calling pop data_buf\n")
+
+        if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
+            cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
+            cache["stats"].output_data_buf[-1].Reset()
+            cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
+            cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
+            cache["stats"].output_data_buf[-1].doa = 0
+        cur_seg = cache["stats"].output_data_buf[-1]
+        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+            print("warning\n")
+        out_pos = len(cur_seg.buffer)  # nothing is done to cur_seg.buffer here yet
+        data_to_pop = 0
+        if end_point_is_sent_end:
+            data_to_pop = expected_sample_number
+        else:
+            data_to_pop = int(
+                frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
+            )
+        if data_to_pop > len(cache["stats"].data_buf):
+            print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
+            data_to_pop = len(cache["stats"].data_buf)
+            expected_sample_number = len(cache["stats"].data_buf)
+
+        cur_seg.doa = 0
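+        # the two loops below only advance out_pos; the per-sample copies of
+        # the original C++ implementation are left commented out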
+        for sample_cpy_out in range(0, data_to_pop):
+            # cur_seg.buffer[out_pos ++] = data_buf_.back();
+            out_pos += 1
+        for sample_cpy_out in range(data_to_pop, expected_sample_number):
+            # cur_seg.buffer[out_pos++] = data_buf_.back()
+            out_pos += 1
+        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+            print("Something wrong with the VAD algorithm\n")
+        cache["stats"].data_buf_start_frame += frm_cnt
+        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
+        if first_frm_is_start_point:
+            cur_seg.contain_seg_start_point = True
+        if last_frm_is_end_point:
+            cur_seg.contain_seg_end_point = True
+
+    def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
+        cache["stats"].lastest_confirmed_silence_frame = valid_frame
+        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            self.PopDataBufTillFrame(valid_frame, cache=cache)
+
+    # silence_detected_callback_
+    # pass
+
+    def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
+        cache["stats"].latest_confirmed_speech_frame = valid_frame
+        self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
+
+    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
+        if self.vad_opts.do_start_point_detection:
+            pass
+        if cache["stats"].confirmed_start_frame != -1:
+            print("not reset vad properly\n")
+        else:
+            cache["stats"].confirmed_start_frame = start_frame
+
+        if (
+            not fake_result
+            and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected
+        ):
+            self.PopDataToOutputBuf(
+                cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache
+            )
+
+    def OnVoiceEnd(
+        self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}
+    ) -> None:
+        for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
+            self.OnVoiceDetected(t, cache=cache)
+        if self.vad_opts.do_end_point_detection:
+            pass
+        if cache["stats"].confirmed_end_frame != -1:
+            print("not reset vad properly\n")
+        else:
+            cache["stats"].confirmed_end_frame = end_frame
+        if not fake_result:
+            cache["stats"].sil_frame = 0
+            self.PopDataToOutputBuf(
+                cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache
+            )
+        cache["stats"].number_end_time_detected += 1
+
+    def MaybeOnVoiceEndIfLastFrame(
+        self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}
+    ) -> None:
+        if is_final_frame:
+            self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
+            cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+
+    def GetLatency(self, cache: dict = {}) -> int:
+        return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
+
+    def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
+        vad_latency = cache["windows_detector"].GetWinSize()
+        if self.vad_opts.do_extend:
+            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
+        return vad_latency
+
+    def GetFrameState(self, t: int, cache: dict = {}):
+        frame_state = FrameState.kFrameStateInvalid
+        cur_decibel = cache["stats"].decibel[t]
+        cur_snr = cur_decibel - cache["stats"].noise_average_decibel
+        # for each frame, calc log posterior probability of each state
+        if cur_decibel < self.vad_opts.decibel_thres:
+            frame_state = FrameState.kFrameStateSil
+            self.DetectOneFrame(frame_state, t, False, cache=cache)
+            return frame_state
+
+        sum_score = 0.0
+        noise_prob = 0.0
+        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
+        if len(cache["stats"].sil_pdf_ids) > 0:
+            assert len(cache["stats"].scores) == 1  # only batch_size = 1 is supported
+            sil_pdf_scores = [
+                cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids
+            ]
+            sum_score = sum(sil_pdf_scores)
+            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
+            total_score = 1.0
+            sum_score = total_score - sum_score
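+        # sum_score is now the speech posterior (1 - summed silence-pdf
+        # scores); the frame is labelled speech below only when it exceeds the
+        # noise probability by at least speech_noise_thres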
+        speech_prob = math.log(sum_score)
+        if self.vad_opts.output_frame_probs:
+            frame_prob = E2EVadFrameProb()
+            frame_prob.noise_prob = noise_prob
+            frame_prob.speech_prob = speech_prob
+            frame_prob.score = sum_score
+            frame_prob.frame_id = t
+            cache["stats"].frame_probs.append(frame_prob)
+        if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
+            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
+                frame_state = FrameState.kFrameStateSpeech
+            else:
+                frame_state = FrameState.kFrameStateSil
+        else:
+            frame_state = FrameState.kFrameStateSil
+            if cache["stats"].noise_average_decibel < -99.9:
+                cache["stats"].noise_average_decibel = cur_decibel
+            else:
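+                # update the running noise-floor estimate as a moving average
+                # over the last noise_frame_num_used_for_snr frames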
+                cache["stats"].noise_average_decibel = (
+                    cur_decibel
+                    + cache["stats"].noise_average_decibel
+                    * (self.vad_opts.noise_frame_num_used_for_snr - 1)
+                ) / self.vad_opts.noise_frame_num_used_for_snr
+
+        return frame_state
+
+    def forward(
+        self,
+        feats: torch.Tensor,
+        waveform: torch.Tensor,
+        cache: dict = {},
+        is_final: bool = False,
+        **kwargs,
+    ):
+        # if len(cache) == 0:
+        #     self.AllResetDetection()
+        # self.waveform = waveform  # compute decibel for each frame
+        cache["stats"].waveform = waveform
+        is_streaming_input = kwargs.get("is_streaming_input", True)
+        self.ComputeDecibel(cache=cache)
+        self.ComputeScores(feats, cache=cache)
+        if not is_final:
+            self.DetectCommonFrames(cache=cache)
+        else:
+            self.DetectLastFrames(cache=cache)
+        segments = []
+        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
+            segment_batch = []
+            if len(cache["stats"].output_data_buf) > 0:
+                for i in range(
+                    cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)
+                ):
+                    if (
+                        is_streaming_input
+                    ):  # in this case, return [beg, -1], [], [-1, end], [beg, end]
+                        if not cache["stats"].output_data_buf[i].contain_seg_start_point:
+                            continue
+                        if (
+                            not cache["stats"].next_seg
+                            and not cache["stats"].output_data_buf[i].contain_seg_end_point
+                        ):
+                            continue
+                        start_ms = (
+                            cache["stats"].output_data_buf[i].start_ms
+                            if cache["stats"].next_seg
+                            else -1
+                        )
+                        if cache["stats"].output_data_buf[i].contain_seg_end_point:
+                            end_ms = cache["stats"].output_data_buf[i].end_ms
+                            cache["stats"].next_seg = True
+                            cache["stats"].output_data_buf_offset += 1
+                        else:
+                            end_ms = -1
+                            cache["stats"].next_seg = False
+                        segment = [start_ms, end_ms]
+
+                    else:  # in this case, return [beg, end]
+
+                        if not is_final and (
+                            not cache["stats"].output_data_buf[i].contain_seg_start_point
+                            or not cache["stats"].output_data_buf[i].contain_seg_end_point
+                        ):
+                            continue
+                        segment = [
+                            cache["stats"].output_data_buf[i].start_ms,
+                            cache["stats"].output_data_buf[i].end_ms,
+                        ]
+                        cache["stats"].output_data_buf_offset += 1  # need update this parameter
+
+                    segment_batch.append(segment)
+
+            if segment_batch:
+                segments.append(segment_batch)
+        # if is_final:
+        #     # reset class variables and clear the dict for the next query
+        #     self.AllResetDetection()
+        return segments
+
+    def init_cache(self, cache: dict = {}, **kwargs):
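+        """Build (or rebuild) the per-stream cache: frontend state, leftover raw
+        samples, the sliding-window detector, and the running VAD statistics."""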
+
+        cache["frontend"] = {}
+        cache["prev_samples"] = torch.empty(0)
+        cache["encoder"] = {}
+
+        if kwargs.get("max_end_silence_time") is not None:
+            # update the max_end_silence_time
+            self.vad_opts.max_end_silence_time = kwargs.get("max_end_silence_time")
+
+        windows_detector = WindowDetector(
+            self.vad_opts.window_size_ms,
+            self.vad_opts.sil_to_speech_time_thres,
+            self.vad_opts.speech_to_sil_time_thres,
+            self.vad_opts.frame_in_ms,
+        )
+        windows_detector.Reset()
+
+        stats = Stats(
+            sil_pdf_ids=self.vad_opts.sil_pdf_ids,
+            max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time
+            - self.vad_opts.speech_to_sil_time_thres,
+            speech_noise_thres=self.vad_opts.speech_noise_thres,
+        )
+        cache["windows_detector"] = windows_detector
+        cache["stats"] = stats
+        return cache
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        cache: dict = {},
+        **kwargs,
+    ):
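+        """Split the incoming audio into chunks, run the VAD forward pass on each
+        chunk, and return the accumulated segments together with timing metadata."""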
+
+        if len(cache) == 0:
+            self.init_cache(cache, **kwargs)
+
+        meta_data = {}
+        chunk_size = kwargs.get("chunk_size", 60000)  # in ms; the default is 60000 ms (60 s)
+        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
+
+        time1 = time.perf_counter()
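+        # chunks of 15 s or longer default to offline (non-streaming) input; shorter
+        # chunks default to streaming, unless the caller sets is_streaming_input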
+        is_streaming_input = (
+            kwargs.get("is_streaming_input", False)
+            if chunk_size >= 15000
+            else kwargs.get("is_streaming_input", True)
+        )
+        is_final = (
+            kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
+        )
+        cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
+        audio_sample_list = load_audio_text_image_video(
+            data_in,
+            fs=frontend.fs,
+            audio_fs=kwargs.get("fs", 16000),
+            data_type=kwargs.get("data_type", "sound"),
+            tokenizer=tokenizer,
+            cache=cfg,
+        )
+        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
+        is_streaming_input = cfg["is_streaming_input"]
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        assert len(audio_sample_list) == 1, "batch_size must be 1"
+
+        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+
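+        # n: number of chunks to process (when _is_final, the remainder is flushed as a
+        # final chunk); m: leftover samples to carry into the next call (0 when final)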
+        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
+        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
+        segments = []
+        for i in range(n):
+            kwargs["is_final"] = _is_final and i == n - 1
+            audio_sample_i = audio_sample[i * chunk_stride_samples : (i + 1) * chunk_stride_samples]
+
+            # extract fbank feats
+            speech, speech_lengths = extract_fbank(
+                [audio_sample_i],
+                data_type=kwargs.get("data_type", "sound"),
+                frontend=frontend,
+                cache=cache["frontend"],
+                is_final=kwargs["is_final"],
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = (
+                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+            )
+            speech = speech.to(device=kwargs["device"])
+            speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+            batch = {
+                "feats": speech,
+                "waveform": cache["frontend"]["waveforms"],
+                "is_final": kwargs["is_final"],
+                "cache": cache,
+                "is_streaming_input": is_streaming_input,
+            }
+            segments_i = self.forward(**batch)
+            if len(segments_i) > 0:
+                segments.extend(*segments_i)  # batch_size is 1, so unpack the single batch entry
+
+        cache["prev_samples"] = audio_sample[:-m]
+        if _is_final:
+            self.init_cache(cache)
+
+        ibest_writer = None
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = self.writer["1best_recog"]
+
+        results = []
+        result_i = {"key": key[0], "value": segments}
+        # if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+        # 	result_i = json.dumps(result_i)
+
+        results.append(result_i)
+
+        if ibest_writer is not None:
+            ibest_writer["text"][key[0]] = segments
+
+        return results, meta_data
+
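+    # A minimal streaming-usage sketch (illustrative only: the frontend object, chunk
+    # list, and 200 ms chunk_size are assumptions; in practice this model is normally
+    # driven through funasr.AutoModel rather than called directly):
+    #
+    #     cache = {}
+    #     for idx, chunk in enumerate(audio_chunks):
+    #         results, _ = model.inference(
+    #             chunk, key=["utt1"], frontend=frontend, cache=cache,
+    #             chunk_size=200, is_final=(idx == len(audio_chunks) - 1), device="cpu",
+    #         )
+    #         # results[0]["value"] holds [start_ms, end_ms] pairs; -1 marks a boundary
+    #         # that streaming mode has not decided yet
+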
+    def export(self, **kwargs):
+
+        from .export_meta import export_rebuild_model
+
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
+
+    def DetectCommonFrames(self, cache: dict = {}) -> int:
+        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+            return 0
+        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+            frame_state = self.GetFrameState(
+                cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache
+            )
+            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
+
+        return 0
+
+    def DetectLastFrames(self, cache: dict = {}) -> int:
+        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+            return 0
+        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+            frame_state = self.GetFrameState(
+                cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache
+            )
+            if i != 0:
+                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
+            else:
+                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
+
+        return 0
+
+    def DetectOneFrame(
+        self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}
+    ) -> None:
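+        # Smooth the raw frame decision through the sliding-window detector, then drive
+        # the start/end-point state machine on the resulting transition
+        # (Sil2Speech / Speech2Sil / Speech2Speech / Sil2Sil).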
+        tmp_cur_frm_state = FrameState.kFrameStateInvalid
+        if cur_frm_state == FrameState.kFrameStateSpeech:
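+            # math.fabs(1.0) is a constant, so for any fe_prior_thres < 1.0 this check
+            # always passes; it looks like a frame-prior gate reduced to a constant in
+            # this port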
+            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
+                tmp_cur_frm_state = FrameState.kFrameStateSpeech
+            else:
+                tmp_cur_frm_state = FrameState.kFrameStateSil
+        elif cur_frm_state == FrameState.kFrameStateSil:
+            tmp_cur_frm_state = FrameState.kFrameStateSil
+        state_change = cache["windows_detector"].DetectOneFrame(
+            tmp_cur_frm_state, cur_frm_idx, cache=cache
+        )
+        frm_shift_in_ms = self.vad_opts.frame_in_ms
+        if AudioChangeState.kChangeStateSil2Speech == state_change:
+            cache["stats"].continous_silence_frame_count = 0
+            cache["stats"].pre_end_silence_detected = False
+            start_frame = 0
+            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                start_frame = max(
+                    cache["stats"].data_buf_start_frame,
+                    cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache),
+                )
+                self.OnVoiceStart(start_frame, cache=cache)
+                cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
+                for t in range(start_frame + 1, cur_frm_idx + 1):
+                    self.OnVoiceDetected(t, cache=cache)
+            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
+                    self.OnVoiceDetected(t, cache=cache)
+                if (
+                    cur_frm_idx - cache["stats"].confirmed_start_frame + 1
+                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
+                ):
+                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
+            cache["stats"].continous_silence_frame_count = 0
+            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                pass
+            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if (
+                    cur_frm_idx - cache["stats"].confirmed_start_frame + 1
+                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
+                ):
+                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
+            cache["stats"].continous_silence_frame_count = 0
+            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if (
+                    cur_frm_idx - cache["stats"].confirmed_start_frame + 1
+                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
+                ):
+                    cache["stats"].max_time_out = True
+                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSil2Sil == state_change:
+            cache["stats"].continous_silence_frame_count += 1
+            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                # silence timeout, return zero length decision
+                if (
+                    (self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value)
+                    and (
+                        cache["stats"].continous_silence_frame_count * frm_shift_in_ms
+                        > self.vad_opts.max_start_silence_time
+                    )
+                ) or (is_final_frame and cache["stats"].number_end_time_detected == 0):
+                    for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
+                        self.OnSilenceDetected(t, cache=cache)
+                    self.OnVoiceStart(0, True, cache=cache)
+                    self.OnVoiceEnd(0, True, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                else:
+                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
+                        self.OnSilenceDetected(
+                            cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache
+                        )
+            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if (
+                    cache["stats"].continous_silence_frame_count * frm_shift_in_ms
+                    >= cache["stats"].max_end_sil_frame_cnt_thresh
+                ):
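+                    # trim the trailing silence from the reported end point (minus the
+                    # lookahead extension when do_extend is enabled)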
+                    lookback_frame = int(
+                        cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms
+                    )
+                    if self.vad_opts.do_extend:
+                        lookback_frame -= int(
+                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
+                        )
+                        lookback_frame -= 1
+                        lookback_frame = max(0, lookback_frame)
+                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif (
+                    cur_frm_idx - cache["stats"].confirmed_start_frame + 1
+                    > self.vad_opts.max_single_segment_time / frm_shift_in_ms
+                ):
+                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif self.vad_opts.do_extend and not is_final_frame:
+                    if cache["stats"].continous_silence_frame_count <= int(
+                        self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
+                    ):
+                        self.OnVoiceDetected(cur_frm_idx, cache=cache)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
+            else:
+                pass
+
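+        # in multiple-utterance mode, re-arm detection after each end point so the next
+        # utterance in the stream can be segmented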
+        if (
+            cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected
+            and self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value
+        ):
+            self.ResetDetection(cache=cache)

--
Gitblit v1.9.1