From a75bbb028e5966ddf02aae5bea05909be9a99826 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 11 一月 2024 17:36:30 +0800
Subject: [PATCH] funasr1.0 paraformer_streaming

---
 /dev/null                                                          |   14 ----
 funasr/models/paraformer_streaming/model.py                        |   82 +++++++++++++++------------
 funasr/models/scama/sanm_encoder.py                                |    2 
 funasr/models/paraformer/cif_predictor.py                          |   11 ++-
 funasr/utils/load_utils.py                                         |    2 
 examples/industrial_data_pretraining/paraformer_streaming/demo.py  |   50 +++++++++-------
 examples/industrial_data_pretraining/paraformer_streaming/infer.sh |    2 
 7 files changed, 85 insertions(+), 78 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 0036e77..9923a04 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -3,36 +3,44 @@
 # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 #  MIT License  (https://opensource.org/licenses/MIT)
 
-# from funasr import AutoModel
-#
-# model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revison="v2.0.0")
-#
-# res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
-# print(res)
+from funasr import AutoModel
 
+chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
+encoder_chunk_look_back = 4 #number of chunks to lookback for encoder self-attention
+decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention
 
-from funasr import AutoFrontend
-
-frontend = AutoFrontend(model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revison="v2.0.0")
-
+model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revison="v2.0.0")
+cache = {}
+res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+            cache=cache,
+            is_final=True,
+            chunk_size=chunk_size,
+            encoder_chunk_look_back=encoder_chunk_look_back,
+            decoder_chunk_look_back=decoder_chunk_look_back,
+            )
+print(res)
 
 
 import soundfile
-speech, sample_rate = soundfile.read("/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav")
+import os
 
-chunk_size = [0, 10, 5] #[0, 10, 5] 600ms, [0, 8, 4] 480ms
+speech, sample_rate = soundfile.read(os.path.expanduser('~')+
+                                     "/.cache/modelscope/hub/damo/"+
+                                     "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/"+
+                                     "example/asr_example.wav")
+
 chunk_stride = chunk_size[1] * 960 # 600ms銆�480ms
-# first chunk, 600ms
 
 cache = {}
 
 for i in range(int(len((speech)-1)/chunk_stride+1)):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
-    fbanks = frontend(input=speech_chunk,
-                      batch_size=2,
-                      cache=cache)
-
-
-# for batch_idx, fbank_dict in enumerate(fbanks):
-# 	res = model(**fbank_dict)
-# 	print(res)
\ No newline at end of file
+    is_final = i == int(len((speech)-1)/chunk_stride+1)
+    res = model(input=speech_chunk,
+                cache=cache,
+                is_final=is_final,
+                chunk_size=chunk_size,
+                encoder_chunk_look_back=encoder_chunk_look_back,
+                decoder_chunk_look_back=decoder_chunk_look_back,
+                )
+    print(res)
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/finetune.sh b/examples/industrial_data_pretraining/paraformer_streaming/finetune.sh
deleted file mode 100644
index 6dca09f..0000000
--- a/examples/industrial_data_pretraining/paraformer_streaming/finetune.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-
-# download model
-local_path_root=../modelscope_models
-mkdir -p ${local_path_root}
-local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
-git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
-
-
-python funasr/bin/train.py \
-+model="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
-+token_list="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \
-+train_data_set_list="data/list/audio_datasets.jsonl" \
-+output_dir="outputs/debug/ckpt/funasr2/exp2" \
-+device="cpu"
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/infer.sh b/examples/industrial_data_pretraining/paraformer_streaming/infer.sh
index 9436628..77e839b 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer_streaming/infer.sh
@@ -1,5 +1,5 @@
 
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
 model_revision="v2.0.0"
 
 python funasr/bin/inference.py \
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 383d9ca..b06fa43 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -205,7 +205,8 @@
 
         return acoustic_embeds, token_num, alphas, cif_peak
 
-    def forward_chunk(self, hidden, cache=None):
+    def forward_chunk(self, hidden, cache=None, **kwargs):
+        is_final = kwargs.get("is_final", False)
         batch_size, len_time, hidden_size = hidden.shape
         h = hidden
         context = h.transpose(1, 2)
@@ -226,14 +227,14 @@
 
         if cache is not None and "chunk_size" in cache:
             alphas[:, :cache["chunk_size"][0]] = 0.0
-            if "is_final" in cache and not cache["is_final"]:
+            if not is_final:
                 alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
         if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
             cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
             cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
             hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
             alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
-        if cache is not None and "is_final" in cache and cache["is_final"]:
+        if cache is not None and is_final:
             tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
             tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
             tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
@@ -277,7 +278,7 @@
 
         max_token_len = max(token_length)
         if max_token_len == 0:
-             return hidden, torch.stack(token_length, 0)
+             return hidden, torch.stack(token_length, 0), None, None
         list_ls = []
         for b in range(batch_size):
             pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device)
@@ -291,7 +292,7 @@
         cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0)
         cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0)
         cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0)
-        return torch.stack(list_ls, 0), torch.stack(token_length, 0)
+        return torch.stack(list_ls, 0), torch.stack(token_length, 0), None, None
 
 
     def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index 304c0f7..927b091 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -64,8 +64,8 @@
 		
 		super().__init__(*args, **kwargs)
 		
-		import pdb;
-		pdb.set_trace()
+		# import pdb;
+		# pdb.set_trace()
 		self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
 
 
@@ -375,11 +375,10 @@
 		
 		return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
 	
-	def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
-		
-		pre_acoustic_embeds, pre_token_length = \
-			self.predictor.forward_chunk(encoder_out, cache["encoder"])
-		return pre_acoustic_embeds, pre_token_length
+	def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
+		is_final = kwargs.get("is_final", False)
+
+		return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
 	
 	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
 		decoder_outs = self.decoder(
@@ -416,7 +415,7 @@
 		            "chunk_size": chunk_size}
 		cache["decoder"] = cache_decoder
 		cache["frontend"] = {}
-		cache["prev_samples"] = []
+		cache["prev_samples"] = torch.empty(0)
 		
 		return cache
 	
@@ -432,12 +431,12 @@
 		speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
 		
 		# Encoder
-		encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
+		encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False))
 		if isinstance(encoder_out, tuple):
 			encoder_out = encoder_out[0]
 		
 		# predictor
-		predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
+		predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False))
 		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
 		                                                                predictor_outs[2], predictor_outs[3]
 		pre_token_length = pre_token_length.round().long()
@@ -476,10 +475,7 @@
 				)
 				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
 			for nbest_idx, hyp in enumerate(nbest_hyps):
-				ibest_writer = None
-				if ibest_writer is None and kwargs.get("output_dir") is not None:
-					writer = DatadirWriter(kwargs.get("output_dir"))
-					ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+				
 				# remove sos/eos and get results
 				last_pos = -1
 				if isinstance(hyp.yseq, list):
@@ -490,22 +486,15 @@
 				# remove blank symbol id, which is assumed to be 0
 				token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
 				
-				if tokenizer is not None:
-					# Change integer-ids to tokens
-					token = tokenizer.ids2tokens(token_int)
-					text = tokenizer.tokens2text(token)
-					
-					text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-					
-					result_i = {"key": key[i], "text": text_postprocessed}
-					
-					if ibest_writer is not None:
-						ibest_writer["token"][key[i]] = " ".join(token)
-						# ibest_writer["text"][key[i]] = text
-						ibest_writer["text"][key[i]] = text_postprocessed
-				else:
-					result_i = {"key": key[i], "token_int": token_int}
-				results.append(result_i)
+
+				# Change integer-ids to tokens
+				token = tokenizer.ids2tokens(token_int)
+				# text = tokenizer.tokens2text(token)
+				
+				result_i = token
+
+
+				results.extend(result_i)
 		
 		return results
 	
@@ -515,6 +504,7 @@
 	             key: list = None,
 	             tokenizer=None,
 	             frontend=None,
+	             cache: dict={},
 	             **kwargs,
 	             ):
 
@@ -526,9 +516,10 @@
 			self.init_beam_search(**kwargs)
 			self.nbest = kwargs.get("nbest", 1)
 		
-		cache = kwargs.get("cache", {})
+
 		if len(cache) == 0:
 			self.init_cache(cache, **kwargs)
+		_is_final = kwargs.get("is_final", False)
 		
 		meta_data = {}
 		chunk_size = kwargs.get("chunk_size", [0, 10, 5])
@@ -542,22 +533,41 @@
 		meta_data["load_data"] = f"{time2 - time1:0.3f}"
 		assert len(audio_sample_list) == 1, "batch_size must be set 1"
 		
-		audio_sample = cache["prev_samples"] + audio_sample_list[0]
+		audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
 		
-		n = len(audio_sample) // chunk_stride_samples
-		m = len(audio_sample) % chunk_stride_samples
+		n = len(audio_sample) // chunk_stride_samples + int(_is_final)
+		m = len(audio_sample) % chunk_stride_samples * (1-int(_is_final))
+		tokens = []
 		for i in range(n):
+			kwargs["is_final"] = _is_final and i == n -1
 			audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
 
 			# extract fbank feats
 			speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
-			                                       frontend=frontend, cache=cache["frontend"])
+			                                       frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
 			time3 = time.perf_counter()
 			meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
 			meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
 			
-			result_i = self.generate_chunk(speech, speech_lengths, **kwargs)
+			tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs)
+			tokens.extend(tokens_i)
+			
+		text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
+		
+		result_i = {"key": key[0], "text": text_postprocessed}
+		result = [result_i]
+		
 		
 		cache["prev_samples"] = audio_sample[:-m]
+		if _is_final:
+			self.init_cache(cache, **kwargs)
+		
+		if kwargs.get("output_dir"):
+			writer = DatadirWriter(kwargs.get("output_dir"))
+			ibest_writer = writer[f"{1}best_recog"]
+			ibest_writer["token"][key[0]] = " ".join(tokens)
+			ibest_writer["text"][key[0]] = text_postprocessed
+		
+		return result, meta_data
 
 
diff --git a/funasr/models/scama/sanm_encoder.py b/funasr/models/scama/sanm_encoder.py
index 4bf6ef0..5e28db7 100644
--- a/funasr/models/scama/sanm_encoder.py
+++ b/funasr/models/scama/sanm_encoder.py
@@ -423,7 +423,9 @@
                       xs_pad: torch.Tensor,
                       ilens: torch.Tensor,
                       cache: dict = None,
+                      **kwargs,
                       ):
+        is_final = kwargs.get("is_final", False)
         xs_pad *= self.output_size() ** 0.5
         if self.embed is None:
             xs_pad = xs_pad
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 39b708a..bb9cf01 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -43,7 +43,7 @@
 	elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
 		data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
 	elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
-		data_or_path_or_list = np.squeeze(data_or_path_or_list)  # [n_samples,]
+		data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
 	else:
 		pass
 		# print(f"unsupport data type: {data_or_path_or_list}, return raw data")

--
Gitblit v1.9.1