From 47088b8d1ebe42b6c376236c19184ef4f440cc0d Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 11 一月 2024 00:09:36 +0800
Subject: [PATCH] funasr1.0 paraformer_streaming

---
 runtime/python/onnxruntime/setup.py         |    2 
 funasr/models/paraformer_streaming/model.py |  153 ++++++++++++++++++++++++++++++++++++++++-----------
 2 files changed, 121 insertions(+), 34 deletions(-)

diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 498d363..304c0f7 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -375,7 +375,7 @@
 		
 		return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
 	
-	def calc_predictor_chunk(self, encoder_out, cache=None):
+	def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
 		
 		pre_acoustic_embeds, pre_token_length = \
 			self.predictor.forward_chunk(encoder_out, cache["encoder"])
@@ -389,48 +389,72 @@
 		decoder_out = torch.log_softmax(decoder_out, dim=-1)
 		return decoder_out, ys_pad_lens
 	
-	def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+	def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
 		decoder_outs = self.decoder.forward_chunk(
 			encoder_out, sematic_embeds, cache["decoder"]
 		)
 		decoder_out = decoder_outs
 		decoder_out = torch.log_softmax(decoder_out, dim=-1)
-		return decoder_out
+		return decoder_out, ys_pad_lens
+	
+	def init_cache(self, cache: dict = {}, **kwargs):
+		chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+		encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
+		decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
+		batch_size = 1
 
-	def generate(self,
-	             speech: torch.Tensor,
-	             speech_lengths: torch.Tensor,
-	             tokenizer=None,
-	             **kwargs,
-	             ):
+		enc_output_size = kwargs["encoder_conf"]["output_size"]
+		feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
+		cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+		            "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+		            "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
+		            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+		            "tail_chunk": False}
+		cache["encoder"] = cache_encoder
 		
-		is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
-		print(is_use_ctc)
-		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
+		            "chunk_size": chunk_size}
+		cache["decoder"] = cache_decoder
+		cache["frontend"] = {}
+		cache["prev_samples"] = []
 		
-		if self.beam_search is None and (is_use_lm or is_use_ctc):
-			logging.info("enable beam_search")
-			self.init_beam_search(speech, speech_lengths, **kwargs)
-			self.nbest = kwargs.get("nbest", 1)
+		return cache
+	
+	def generate_chunk(self,
+	                   speech,
+	                   speech_lengths=None,
+	                   key: list = None,
+	                   tokenizer=None,
+	                   frontend=None,
+	                   **kwargs,
+	                   ):
+		cache = kwargs.get("cache", {})
+		speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
 		
-		# Forward Encoder
-		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
 		if isinstance(encoder_out, tuple):
 			encoder_out = encoder_out[0]
 		
 		# predictor
-		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+		predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
 		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
 		                                                                predictor_outs[2], predictor_outs[3]
 		pre_token_length = pre_token_length.round().long()
 		if torch.max(pre_token_length) < 1:
 			return []
-		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
-		                                               pre_token_length)
+		decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
+		                                                     encoder_out_lens,
+		                                                     pre_acoustic_embeds,
+		                                                     pre_token_length,
+		                                                     cache=cache
+		                                                     )
 		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-		
+
 		results = []
 		b, n, d = decoder_out.size()
+		if isinstance(key[0], (list, tuple)):
+			key = key[0]
 		for i in range(b):
 			x = encoder_out[i, :encoder_out_lens[i], :]
 			am_scores = decoder_out[i, :pre_token_length[i], :]
@@ -451,9 +475,11 @@
 					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
 				)
 				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-			for hyp in nbest_hyps:
-				assert isinstance(hyp, (Hypothesis)), type(hyp)
-				
+			for nbest_idx, hyp in enumerate(nbest_hyps):
+				ibest_writer = None
+				if ibest_writer is None and kwargs.get("output_dir") is not None:
+					writer = DatadirWriter(kwargs.get("output_dir"))
+					ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
 				# remove sos/eos and get results
 				last_pos = -1
 				if isinstance(hyp.yseq, list):
@@ -462,15 +488,76 @@
 					token_int = hyp.yseq[1:last_pos].tolist()
 				
 				# remove blank symbol id, which is assumed to be 0
-				token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+				token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
 				
-				# Change integer-ids to tokens
-				token = tokenizer.ids2tokens(token_int)
-				text = tokenizer.tokens2text(token)
-				
-				timestamp = []
-				
-				results.append((text, token, timestamp))
+				if tokenizer is not None:
+					# Change integer-ids to tokens
+					token = tokenizer.ids2tokens(token_int)
+					text = tokenizer.tokens2text(token)
+					
+					text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+					
+					result_i = {"key": key[i], "text": text_postprocessed}
+					
+					if ibest_writer is not None:
+						ibest_writer["token"][key[i]] = " ".join(token)
+						# ibest_writer["text"][key[i]] = text
+						ibest_writer["text"][key[i]] = text_postprocessed
+				else:
+					result_i = {"key": key[i], "token_int": token_int}
+				results.append(result_i)
 		
 		return results
+	
+	def generate(self,
+	             data_in,
+	             data_lengths=None,
+	             key: list = None,
+	             tokenizer=None,
+	             frontend=None,
+	             **kwargs,
+	             ):
+
+		# init beamsearch
+		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		if self.beam_search is None and (is_use_lm or is_use_ctc):
+			logging.info("enable beam_search")
+			self.init_beam_search(**kwargs)
+			self.nbest = kwargs.get("nbest", 1)
+		
+		cache = kwargs.get("cache", {})
+		if len(cache) == 0:
+			self.init_cache(cache, **kwargs)
+		
+		meta_data = {}
+		chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+		chunk_stride_samples = chunk_size[1] * 960  # 600ms
+		
+		time1 = time.perf_counter()
+		audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+		                                                data_type=kwargs.get("data_type", "sound"),
+		                                                tokenizer=tokenizer)
+		time2 = time.perf_counter()
+		meta_data["load_data"] = f"{time2 - time1:0.3f}"
+		assert len(audio_sample_list) == 1, "batch_size must be set 1"
+		
+		audio_sample = cache["prev_samples"] + audio_sample_list[0]
+		
+		n = len(audio_sample) // chunk_stride_samples
+		m = len(audio_sample) % chunk_stride_samples
+		for i in range(n):
+			audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
+
+			# extract fbank feats
+			speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+			                                       frontend=frontend, cache=cache["frontend"])
+			time3 = time.perf_counter()
+			meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+			meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+			
+			result_i = self.generate_chunk(speech, speech_lengths, **kwargs)
+		
+		cache["prev_samples"] = audio_sample[:-m]
+
 
diff --git a/runtime/python/onnxruntime/setup.py b/runtime/python/onnxruntime/setup.py
index b8dc3e1..3c128ee 100644
--- a/runtime/python/onnxruntime/setup.py
+++ b/runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@
 
 
 MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.2.4'
+VERSION_NUM = '0.2.5'
 
 setuptools.setup(
     name=MODULE_NAME,

--
Gitblit v1.9.1