From c3442d9566f5a2011c95b0d2998958a1b5348564 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期五, 12 一月 2024 18:04:42 +0800
Subject: [PATCH] update device

---
 funasr/models/paraformer_streaming/model.py |  174 ++++++++++++++++++++++++++++++++++++++++++++++-----------
 1 files changed, 139 insertions(+), 35 deletions(-)

diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 498d363..b736aa9 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -64,8 +64,8 @@
 		
 		super().__init__(*args, **kwargs)
 		
-		import pdb;
-		pdb.set_trace()
+		# import pdb;
+		# pdb.set_trace()
 		self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
 
 
@@ -375,11 +375,10 @@
 		
 		return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
 	
-	def calc_predictor_chunk(self, encoder_out, cache=None):
-		
-		pre_acoustic_embeds, pre_token_length = \
-			self.predictor.forward_chunk(encoder_out, cache["encoder"])
-		return pre_acoustic_embeds, pre_token_length
+	def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
+		is_final = kwargs.get("is_final", False)
+
+		return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
 	
 	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
 		decoder_outs = self.decoder(
@@ -389,48 +388,72 @@
 		decoder_out = torch.log_softmax(decoder_out, dim=-1)
 		return decoder_out, ys_pad_lens
 	
-	def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+	def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
 		decoder_outs = self.decoder.forward_chunk(
 			encoder_out, sematic_embeds, cache["decoder"]
 		)
 		decoder_out = decoder_outs
 		decoder_out = torch.log_softmax(decoder_out, dim=-1)
-		return decoder_out
+		return decoder_out, ys_pad_lens
+	
+	def init_cache(self, cache: dict = {}, **kwargs):
+		chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+		encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
+		decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
+		batch_size = 1
 
-	def generate(self,
-	             speech: torch.Tensor,
-	             speech_lengths: torch.Tensor,
-	             tokenizer=None,
-	             **kwargs,
-	             ):
+		enc_output_size = kwargs["encoder_conf"]["output_size"]
+		feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
+		cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+		            "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+		            "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
+		            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+		            "tail_chunk": False}
+		cache["encoder"] = cache_encoder
 		
-		is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
-		print(is_use_ctc)
-		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
+		            "chunk_size": chunk_size}
+		cache["decoder"] = cache_decoder
+		cache["frontend"] = {}
+		cache["prev_samples"] = torch.empty(0)
 		
-		if self.beam_search is None and (is_use_lm or is_use_ctc):
-			logging.info("enable beam_search")
-			self.init_beam_search(speech, speech_lengths, **kwargs)
-			self.nbest = kwargs.get("nbest", 1)
+		return cache
+	
+	def generate_chunk(self,
+	                   speech,
+	                   speech_lengths=None,
+	                   key: list = None,
+	                   tokenizer=None,
+	                   frontend=None,
+	                   **kwargs,
+	                   ):
+		cache = kwargs.get("cache", {})
+		speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
 		
-		# Forward Encoder
-		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False))
 		if isinstance(encoder_out, tuple):
 			encoder_out = encoder_out[0]
 		
 		# predictor
-		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+		predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False))
 		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
 		                                                                predictor_outs[2], predictor_outs[3]
 		pre_token_length = pre_token_length.round().long()
 		if torch.max(pre_token_length) < 1:
 			return []
-		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
-		                                               pre_token_length)
+		decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
+		                                                     encoder_out_lens,
+		                                                     pre_acoustic_embeds,
+		                                                     pre_token_length,
+		                                                     cache=cache
+		                                                     )
 		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-		
+
 		results = []
 		b, n, d = decoder_out.size()
+		if isinstance(key[0], (list, tuple)):
+			key = key[0]
 		for i in range(b):
 			x = encoder_out[i, :encoder_out_lens[i], :]
 			am_scores = decoder_out[i, :pre_token_length[i], :]
@@ -451,8 +474,7 @@
 					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
 				)
 				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-			for hyp in nbest_hyps:
-				assert isinstance(hyp, (Hypothesis)), type(hyp)
+			for nbest_idx, hyp in enumerate(nbest_hyps):
 				
 				# remove sos/eos and get results
 				last_pos = -1
@@ -462,15 +484,97 @@
 					token_int = hyp.yseq[1:last_pos].tolist()
 				
 				# remove blank symbol id, which is assumed to be 0
-				token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+				token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
 				
+
 				# Change integer-ids to tokens
 				token = tokenizer.ids2tokens(token_int)
-				text = tokenizer.tokens2text(token)
+				# text = tokenizer.tokens2text(token)
 				
-				timestamp = []
-				
-				results.append((text, token, timestamp))
+				result_i = token
+
+
+				results.extend(result_i)
 		
 		return results
+	
+	def generate(self,
+	             data_in,
+	             data_lengths=None,
+	             key: list = None,
+	             tokenizer=None,
+	             frontend=None,
+	             cache: dict={},
+	             **kwargs,
+	             ):
+
+		# init beamsearch
+		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		if self.beam_search is None and (is_use_lm or is_use_ctc):
+			logging.info("enable beam_search")
+			self.init_beam_search(**kwargs)
+			self.nbest = kwargs.get("nbest", 1)
+		
+
+		if len(cache) == 0:
+			self.init_cache(cache, **kwargs)
+		
+		
+		meta_data = {}
+		chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+		chunk_stride_samples = int(chunk_size[1] * 960)  # 600ms
+		
+		time1 = time.perf_counter()
+		cfg = {"is_final": kwargs.get("is_final", False)}
+		audio_sample_list = load_audio_text_image_video(data_in,
+														fs=frontend.fs,
+														audio_fs=kwargs.get("fs", 16000),
+														data_type=kwargs.get("data_type", "sound"),
+														tokenizer=tokenizer,
+														cache=cfg,
+														)
+		_is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
+		
+		time2 = time.perf_counter()
+		meta_data["load_data"] = f"{time2 - time1:0.3f}"
+		assert len(audio_sample_list) == 1, "batch_size must be set 1"
+		
+		audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+		
+		n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
+		m = int(len(audio_sample) % chunk_stride_samples * (1-int(_is_final)))
+		tokens = []
+		for i in range(n):
+			kwargs["is_final"] = _is_final and i == n -1
+			audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
+
+			# extract fbank feats
+			speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+			                                       frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
+			time3 = time.perf_counter()
+			meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+			meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+			
+			tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs)
+			tokens.extend(tokens_i)
+			
+		text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
+		
+		result_i = {"key": key[0], "text": text_postprocessed}
+		result = [result_i]
+		
+		
+		cache["prev_samples"] = audio_sample[:-m]
+		if _is_final:
+			self.init_cache(cache, **kwargs)
+		
+		if kwargs.get("output_dir"):
+			writer = DatadirWriter(kwargs.get("output_dir"))
+			ibest_writer = writer[f"{1}best_recog"]
+			ibest_writer["token"][key[0]] = " ".join(tokens)
+			ibest_writer["text"][key[0]] = text_postprocessed
+		
+		return result, meta_data
+
 

--
Gitblit v1.9.1