From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 22:21:50 +0800
Subject: [PATCH] aishell example

---
 funasr/models/fsmn_vad_streaming/model.py |  116 +++++++++++++++++++++++++++++++++++++++------------------
 1 files changed, 79 insertions(+), 37 deletions(-)

diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 943cb47..4fd18c8 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -15,7 +15,7 @@
 from typing import List, Tuple, Dict, Any, Optional
 
 from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
 class VadStateMachine(Enum):
@@ -23,10 +23,12 @@
 	kVadInStateInSpeechSegment = 2
 	kVadInStateEndPointDetected = 3
 
+
 class FrameState(Enum):
 	kFrameStateInvalid = -1
 	kFrameStateSpeech = 1
 	kFrameStateSil = 0
+
 
 # final voice/unvoice state per frame
 class AudioChangeState(Enum):
@@ -37,9 +39,11 @@
 	kChangeStateNoBegin = 4
 	kChangeStateInvalid = 5
 
+
 class VadDetectMode(Enum):
 	kVadSingleUtteranceDetectMode = 0
 	kVadMutipleUtteranceDetectMode = 1
+
 
 class VADXOptions:
 	"""
@@ -47,6 +51,7 @@
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(
 		self,
 		sample_rate: int = 16000,
@@ -117,6 +122,7 @@
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self):
 		self.start_ms = 0
 		self.end_ms = 0
@@ -140,6 +146,7 @@
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self):
 		self.noise_prob = 0.0
 		self.speech_prob = 0.0
@@ -154,6 +161,7 @@
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self, window_size_ms: int,
 	             sil_to_speech_time: int,
 	             speech_to_sil_time: int,
@@ -190,7 +198,7 @@
 	def GetWinSize(self) -> int:
 		return int(self.win_size_frame)
 	
-	def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
+	def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
 		cur_frame_state = FrameState.kFrameStateSil
 		if frameState == FrameState.kFrameStateSpeech:
 			cur_frame_state = 1
@@ -220,13 +228,13 @@
 	def FrameSizeMs(self) -> int:
 		return int(self.frame_size_ms)
 
+
 class Stats(object):
 	def __init__(self,
 	             sil_pdf_ids,
 	             max_end_sil_frame_cnt_thresh,
 	             speech_noise_thres,
 	             ):
-		
 		self.data_buf_start_frame = 0
 		self.frm_cnt = 0
 		self.latest_confirmed_speech_frame = 0
@@ -263,6 +271,7 @@
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self,
 	             encoder: str = None,
 	             encoder_conf: Optional[Dict] = None,
@@ -275,7 +284,6 @@
 		encoder_class = tables.encoder_classes.get(encoder)
 		encoder = encoder_class(**encoder_conf)
 		self.encoder = encoder
-	
 	
 	def ResetDetection(self, cache: dict = {}):
 		cache["stats"].continous_silence_frame_count = 0
@@ -293,7 +301,8 @@
 			drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
 			real_drop_frames = drop_frames - cache["stats"].last_drop_frames
 			cache["stats"].last_drop_frames = drop_frames
-			cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+			cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(
+				self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
 			cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
 			cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
 	
@@ -301,7 +310,8 @@
 		frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
 		frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
 		if cache["stats"].data_buf_all is None:
-			cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
+			cache["stats"].data_buf_all = cache["stats"].waveform[
+				0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
 			cache["stats"].data_buf = cache["stats"].data_buf_all
 		else:
 			cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
@@ -320,15 +330,16 @@
 		else:
 			cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
 	
-	def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
+	def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
 		while cache["stats"].data_buf_start_frame < frame_idx:
 			if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
 				cache["stats"].data_buf_start_frame += 1
-				cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
-					self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+				cache["stats"].data_buf = cache["stats"].data_buf_all[
+				                          (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
+					                          self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
 	
 	def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
-	                       last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
+	                       last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
 		self.PopDataBufTillFrame(start_frm, cache=cache)
 		expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
 		if last_frm_is_end_point:
@@ -380,14 +391,15 @@
 		cache["stats"].lastest_confirmed_silence_frame = valid_frame
 		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
 			self.PopDataBufTillFrame(valid_frame, cache=cache)
-		# silence_detected_callback_
-		# pass
 	
-	def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
+	# silence_detected_callback_
+	# pass
+	
+	def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
 		cache["stats"].latest_confirmed_speech_frame = valid_frame
 		self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
 	
-	def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
+	def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
 		if self.vad_opts.do_start_point_detection:
 			pass
 		if cache["stats"].confirmed_start_frame != -1:
@@ -398,7 +410,7 @@
 		if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
 			self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
 	
-	def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
+	def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
 		for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
 			self.OnVoiceDetected(t, cache=cache)
 		if self.vad_opts.do_end_point_detection:
@@ -470,13 +482,17 @@
 		
 		return frame_state
 	
-	def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
-	            is_final: bool = False
+	def forward(self, feats: torch.Tensor,
+	            waveform: torch.tensor,
+	            cache: dict = {},
+	            is_final: bool = False,
+	            **kwargs,
 	            ):
 		# if len(cache) == 0:
 		#     self.AllResetDetection()
 		# self.waveform = waveform  # compute decibel for each frame
 		cache["stats"].waveform = waveform
+		is_streaming_input = kwargs.get("is_streaming_input", True)
 		self.ComputeDecibel(cache=cache)
 		self.ComputeScores(feats, cache=cache)
 		if not is_final:
@@ -488,12 +504,32 @@
 			segment_batch = []
 			if len(cache["stats"].output_data_buf) > 0:
 				for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
-					if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
-						i].contain_seg_end_point):
-						continue
-					segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
+					if is_streaming_input: # in this case, return [beg, -1], [], [-1, end], [beg, end]
+						if not cache["stats"].output_data_buf[i].contain_seg_start_point:
+							continue
+						if not cache["stats"].next_seg and not cache["stats"].output_data_buf[i].contain_seg_end_point:
+							continue
+						start_ms = cache["stats"].output_data_buf[i].start_ms if cache["stats"].next_seg else -1
+						if cache["stats"].output_data_buf[i].contain_seg_end_point:
+							end_ms = cache["stats"].output_data_buf[i].end_ms
+							cache["stats"].next_seg = True
+							cache["stats"].output_data_buf_offset += 1
+						else:
+							end_ms = -1
+							cache["stats"].next_seg = False
+						segment = [start_ms, end_ms]
+						
+					else: # in this case, return [beg, end]
+						
+						if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not
+						cache["stats"].output_data_buf[
+							i].contain_seg_end_point):
+							continue
+						segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
+						cache["stats"].output_data_buf_offset += 1  # need update this parameter
+					
 					segment_batch.append(segment)
-					cache["stats"].output_data_buf_offset += 1  # need update this parameter
+					
 			if segment_batch:
 				segments.append(segment_batch)
 		# if is_final:
@@ -502,6 +538,7 @@
 		return segments
 	
 	def init_cache(self, cache: dict = {}, **kwargs):
+		
 		cache["frontend"] = {}
 		cache["prev_samples"] = torch.empty(0)
 		cache["encoder"] = {}
@@ -533,11 +570,13 @@
 			self.init_cache(cache, **kwargs)
 		
 		meta_data = {}
-		chunk_size = kwargs.get("chunk_size", 60000) # 50ms
+		chunk_size = kwargs.get("chunk_size", 60000)  # 50ms
 		chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
 		
 		time1 = time.perf_counter()
-		cfg = {"is_final": kwargs.get("is_final", False)}
+		is_streaming_input = kwargs.get("is_streaming_input", False) if chunk_size >= 15000 else kwargs.get("is_streaming_input", True)
+		is_final = kwargs.get("is_final", False) if is_streaming_input else kwargs.get("is_final", True)
+		cfg = {"is_final": is_final, "is_streaming_input": is_streaming_input}
 		audio_sample_list = load_audio_text_image_video(data_in,
 		                                                fs=frontend.fs,
 		                                                audio_fs=kwargs.get("fs", 16000),
@@ -546,7 +585,7 @@
 		                                                cache=cfg,
 		                                                )
 		_is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
-		
+		is_streaming_input = cfg["is_streaming_input"]
 		time2 = time.perf_counter()
 		meta_data["load_data"] = f"{time2 - time1:0.3f}"
 		assert len(audio_sample_list) == 1, "batch_size must be set 1"
@@ -574,16 +613,16 @@
 				"feats": speech,
 				"waveform": cache["frontend"]["waveforms"],
 				"is_final": kwargs["is_final"],
-				"cache": cache
+				"cache": cache,
+				"is_streaming_input": is_streaming_input
 			}
 			segments_i = self.forward(**batch)
 			if len(segments_i) > 0:
 				segments.extend(*segments_i)
 		
-		
 		cache["prev_samples"] = audio_sample[:-m]
 		if _is_final:
-			cache = {}
+			self.init_cache(cache)
 		
 		ibest_writer = None
 		if ibest_writer is None and kwargs.get("output_dir") is not None:
@@ -600,16 +639,15 @@
 		if ibest_writer is not None:
 			ibest_writer["text"][key[0]] = segments
 		
-		
 		return results, meta_data
-	
 	
 	def DetectCommonFrames(self, cache: dict = {}) -> int:
 		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
 			return 0
 		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
 			frame_state = FrameState.kFrameStateInvalid
-			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
+			                                 cache=cache)
 			self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
 		
 		return 0
@@ -619,7 +657,8 @@
 			return 0
 		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
 			frame_state = FrameState.kFrameStateInvalid
-			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
+			                                 cache=cache)
 			if i != 0:
 				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
 			else:
@@ -627,7 +666,8 @@
 		
 		return 0
 	
-	def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
+	def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
+	                   cache: dict = {}) -> None:
 		tmp_cur_frm_state = FrameState.kFrameStateInvalid
 		if cur_frm_state == FrameState.kFrameStateSpeech:
 			if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@@ -644,7 +684,8 @@
 			cache["stats"].pre_end_silence_detected = False
 			start_frame = 0
 			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-				start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
+				start_frame = max(cache["stats"].data_buf_start_frame,
+				                  cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
 				self.OnVoiceStart(start_frame, cache=cache)
 				cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
 				for t in range(start_frame + 1, cur_frm_idx + 1):
@@ -696,7 +737,8 @@
 			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
 				# silence timeout, return zero length decision
 				if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
-					cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+					cache[
+						"stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
 					or (is_final_frame and cache["stats"].number_end_time_detected == 0):
 					for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
 						self.OnSilenceDetected(t, cache=cache)
@@ -707,7 +749,8 @@
 					if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
 						self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
 			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-				if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
+				if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[
+					"stats"].max_end_sil_frame_cnt_thresh:
 					lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
 					if self.vad_opts.do_extend:
 						lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
@@ -731,6 +774,5 @@
 		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
 			self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
 			self.ResetDetection(cache=cache)
-
 
 

--
Gitblit v1.9.1