From 37d7764ecf0e8cc1a14f59b8b9cd1c914da8b005 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 21 一月 2024 21:06:52 +0800
Subject: [PATCH] Funasr1.0 (#1277)

---
 funasr/models/fsmn_vad_streaming/model.py |   76 +++++++++++++++++++++++---------------
 1 files changed, 46 insertions(+), 30 deletions(-)

diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index becfd56..76eee81 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -15,7 +15,7 @@
 from typing import List, Tuple, Dict, Any, Optional
 
 from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
 class VadStateMachine(Enum):
@@ -23,10 +23,12 @@
 	kVadInStateInSpeechSegment = 2
 	kVadInStateEndPointDetected = 3
 
+
 class FrameState(Enum):
 	kFrameStateInvalid = -1
 	kFrameStateSpeech = 1
 	kFrameStateSil = 0
+
 
 # final voice/unvoice state per frame
 class AudioChangeState(Enum):
@@ -37,9 +39,11 @@
 	kChangeStateNoBegin = 4
 	kChangeStateInvalid = 5
 
+
 class VadDetectMode(Enum):
 	kVadSingleUtteranceDetectMode = 0
 	kVadMutipleUtteranceDetectMode = 1
+
 
 class VADXOptions:
 	"""
@@ -47,6 +51,7 @@
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(
 		self,
 		sample_rate: int = 16000,
@@ -117,6 +122,7 @@
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self):
 		self.start_ms = 0
 		self.end_ms = 0
@@ -140,6 +146,7 @@
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self):
 		self.noise_prob = 0.0
 		self.speech_prob = 0.0
@@ -154,6 +161,7 @@
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self, window_size_ms: int,
 	             sil_to_speech_time: int,
 	             speech_to_sil_time: int,
@@ -190,7 +198,7 @@
 	def GetWinSize(self) -> int:
 		return int(self.win_size_frame)
 	
-	def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
+	def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
 		cur_frame_state = FrameState.kFrameStateSil
 		if frameState == FrameState.kFrameStateSpeech:
 			cur_frame_state = 1
@@ -220,13 +228,13 @@
 	def FrameSizeMs(self) -> int:
 		return int(self.frame_size_ms)
 
+
 class Stats(object):
 	def __init__(self,
 	             sil_pdf_ids,
 	             max_end_sil_frame_cnt_thresh,
 	             speech_noise_thres,
 	             ):
-		
 		self.data_buf_start_frame = 0
 		self.frm_cnt = 0
 		self.latest_confirmed_speech_frame = 0
@@ -255,6 +263,7 @@
 		self.waveform = None
 		self.last_drop_frames = 0
 
+
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
 	"""
@@ -262,6 +271,7 @@
 	Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 	https://arxiv.org/abs/1803.05030
 	"""
+	
 	def __init__(self,
 	             encoder: str = None,
 	             encoder_conf: Optional[Dict] = None,
@@ -274,7 +284,6 @@
 		encoder_class = tables.encoder_classes.get(encoder)
 		encoder = encoder_class(**encoder_conf)
 		self.encoder = encoder
-	
 	
 	def ResetDetection(self, cache: dict = {}):
 		cache["stats"].continous_silence_frame_count = 0
@@ -292,7 +301,8 @@
 			drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
 			real_drop_frames = drop_frames - cache["stats"].last_drop_frames
 			cache["stats"].last_drop_frames = drop_frames
-			cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+			cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(
+				self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
 			cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
 			cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
 	
@@ -300,7 +310,8 @@
 		frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
 		frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
 		if cache["stats"].data_buf_all is None:
-			cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
+			cache["stats"].data_buf_all = cache["stats"].waveform[
+				0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
 			cache["stats"].data_buf = cache["stats"].data_buf_all
 		else:
 			cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
@@ -319,15 +330,16 @@
 		else:
 			cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
 	
-	def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
+	def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
 		while cache["stats"].data_buf_start_frame < frame_idx:
 			if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
 				cache["stats"].data_buf_start_frame += 1
-				cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
-					self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+				cache["stats"].data_buf = cache["stats"].data_buf_all[
+				                          (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
+					                          self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
 	
 	def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
-	                       last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
+	                       last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
 		self.PopDataBufTillFrame(start_frm, cache=cache)
 		expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
 		if last_frm_is_end_point:
@@ -379,14 +391,15 @@
 		cache["stats"].lastest_confirmed_silence_frame = valid_frame
 		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
 			self.PopDataBufTillFrame(valid_frame, cache=cache)
-		# silence_detected_callback_
-		# pass
 	
-	def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
+	# silence_detected_callback_
+	# pass
+	
+	def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
 		cache["stats"].latest_confirmed_speech_frame = valid_frame
 		self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
 	
-	def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
+	def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
 		if self.vad_opts.do_start_point_detection:
 			pass
 		if cache["stats"].confirmed_start_frame != -1:
@@ -397,7 +410,7 @@
 		if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
 			self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
 	
-	def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
+	def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
 		for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
 			self.OnVoiceDetected(t, cache=cache)
 		if self.vad_opts.do_end_point_detection:
@@ -487,7 +500,8 @@
 			segment_batch = []
 			if len(cache["stats"].output_data_buf) > 0:
 				for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
-					if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
+					if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not
+					cache["stats"].output_data_buf[
 						i].contain_seg_end_point):
 						continue
 					segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
@@ -499,9 +513,9 @@
 		#     # reset class variables and clear the dict for the next query
 		#     self.AllResetDetection()
 		return segments
-
+	
 	def init_cache(self, cache: dict = {}, **kwargs):
-    
+		
 		cache["frontend"] = {}
 		cache["prev_samples"] = torch.empty(0)
 		cache["encoder"] = {}
@@ -528,12 +542,12 @@
 	              cache: dict = {},
 	              **kwargs,
 	              ):
-
+		
 		if len(cache) == 0:
 			self.init_cache(cache, **kwargs)
 		
 		meta_data = {}
-		chunk_size = kwargs.get("chunk_size", 60000) # 50ms
+		chunk_size = kwargs.get("chunk_size", 60000)  # 50ms
 		chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
 		
 		time1 = time.perf_counter()
@@ -580,7 +594,6 @@
 			if len(segments_i) > 0:
 				segments.extend(*segments_i)
 		
-		
 		cache["prev_samples"] = audio_sample[:-m]
 		if _is_final:
 			self.init_cache(cache)
@@ -600,16 +613,15 @@
 		if ibest_writer is not None:
 			ibest_writer["text"][key[0]] = segments
 		
-		
 		return results, meta_data
-	
 	
 	def DetectCommonFrames(self, cache: dict = {}) -> int:
 		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
 			return 0
 		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
 			frame_state = FrameState.kFrameStateInvalid
-			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
+			                                 cache=cache)
 			self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
 		
 		return 0
@@ -619,7 +631,8 @@
 			return 0
 		for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
 			frame_state = FrameState.kFrameStateInvalid
-			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+			frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
+			                                 cache=cache)
 			if i != 0:
 				self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
 			else:
@@ -627,7 +640,8 @@
 		
 		return 0
 	
-	def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
+	def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
+	                   cache: dict = {}) -> None:
 		tmp_cur_frm_state = FrameState.kFrameStateInvalid
 		if cur_frm_state == FrameState.kFrameStateSpeech:
 			if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@@ -644,7 +658,8 @@
 			cache["stats"].pre_end_silence_detected = False
 			start_frame = 0
 			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-				start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
+				start_frame = max(cache["stats"].data_buf_start_frame,
+				                  cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
 				self.OnVoiceStart(start_frame, cache=cache)
 				cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
 				for t in range(start_frame + 1, cur_frm_idx + 1):
@@ -696,7 +711,8 @@
 			if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
 				# silence timeout, return zero length decision
 				if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
-					cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+					cache[
+						"stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
 					or (is_final_frame and cache["stats"].number_end_time_detected == 0):
 					for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
 						self.OnSilenceDetected(t, cache=cache)
@@ -707,7 +723,8 @@
 					if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
 						self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
 			elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-				if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
+				if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[
+					"stats"].max_end_sil_frame_cnt_thresh:
 					lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
 					if self.vad_opts.do_extend:
 						lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
@@ -731,6 +748,5 @@
 		if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
 			self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
 			self.ResetDetection(cache=cache)
-
 
 

--
Gitblit v1.9.1