From 40d1f80030d38b3377a95ead8837e82c67aa59f6 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 12 一月 2024 12:05:08 +0800
Subject: [PATCH] funasr1.0 streaming demo

---
 funasr/models/fsmn_vad_streaming/model.py                       |   93 +++++++++---------------------
 funasr/models/fsmn_vad/model.py                                 |    7 +-
 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py |   31 +++++++++-
 funasr/models/paraformer_streaming/model.py                     |    8 +-
 funasr/utils/load_utils.py                                      |    3 
 5 files changed, 67 insertions(+), 75 deletions(-)

diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index 2a157ee..01c6c39 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -4,8 +4,33 @@
 #  MIT License  (https://opensource.org/licenses/MIT)
 
 from funasr import AutoModel
+wav_file = "/Users/zhifu/funasr_github/test_local/asr_example.wav"
+chunk_size = 60000 # ms
+model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-streaming", model_revision="v2.0.0")
 
-model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.0")
+res = model(input=wav_file,
+            chunk_size=chunk_size,
+            )
+print(res)
 
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav")
-print(res)
\ No newline at end of file
+
+#
+# import soundfile
+# import os
+#
+# # wav_file = os.path.join(model.model_path, "example/vad_example.wav")
+# speech, sample_rate = soundfile.read(wav_file)
+#
+# chunk_stride = int(chunk_size * 16000 / 1000)
+#
+# cache = {}
+#
+# for i in range(int(len((speech)-1)/chunk_stride+1)):
+#     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
+#     is_final = i == int(len((speech)-1)/chunk_stride+1)
+#     res = model(input=speech_chunk,
+#                 cache=cache,
+#                 is_final=is_final,
+#                 chunk_size=chunk_size,
+#                 )
+#     print(res)
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index 1ed0773..75c6f4a 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -593,15 +593,16 @@
         results = []
         for i in range(batch_size):
             
-            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
-                results[i] = json.dumps(results[i])
                 
             if ibest_writer is not None:
                 ibest_writer["text"][key[i]] = segments[i]
 
             result_i = {"key": key[i], "value": segments[i]}
             results.append(result_i)
- 
+            
+            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+                results[i] = json.dumps(results[i])
+                
         return results, meta_data
 
     def DetectCommonFrames(self) -> int:
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 4c7e943..13b3f3a 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -496,7 +496,7 @@
     def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
                 is_final: bool = False
                 ):
-        if not cache:
+        if len(cache) == 0:
             self.AllResetDetection()
         self.waveform = waveform  # compute decibel for each frame
         self.ComputeDecibel()
@@ -521,13 +521,15 @@
         if is_final:
             # reset class variables and clear the dict for the next query
             self.AllResetDetection()
-        return segments, cache
+        return segments
 
     def init_cache(self, cache: dict = {}, **kwargs):
         cache["frontend"] = {}
         cache["prev_samples"] = torch.empty(0)
+        cache["encoder"] = {}
         
         return cache
+    
     def generate(self,
                  data_in,
                  data_lengths=None,
@@ -543,7 +545,7 @@
 
         meta_data = {}
         chunk_size = kwargs.get("chunk_size", 50) # 50ms
-        chunk_stride_samples = chunk_size * 16
+        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
 
         time1 = time.perf_counter()
         cfg = {"is_final": kwargs.get("is_final", False)}
@@ -552,7 +554,7 @@
                                                         audio_fs=kwargs.get("fs", 16000),
                                                         data_type=kwargs.get("data_type", "sound"),
                                                         tokenizer=tokenizer,
-                                                        **cfg,
+                                                        cache=cfg,
                                                         )
         _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
 
@@ -562,9 +564,9 @@
 
         audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
 
-        n = len(audio_sample) // chunk_stride_samples + int(_is_final)
-        m = len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))
-        tokens = []
+        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
+        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
+        segments = []
         for i in range(n):
             kwargs["is_final"] = _is_final and i == n - 1
             audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
@@ -576,58 +578,22 @@
             time3 = time.perf_counter()
             meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
             meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
-        meta_data = {}
-        audio_sample_list = [data_in]
-        if isinstance(data_in, torch.Tensor):  # fbank
-            speech, speech_lengths = data_in, data_lengths
-            if len(speech.shape) < 3:
-                speech = speech[None, :, :]
-            if speech_lengths is None:
-                speech_lengths = speech.shape[1]
-        else:
-            # extract fbank feats
-            time1 = time.perf_counter()
-            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
-            time2 = time.perf_counter()
-            meta_data["load_data"] = f"{time2 - time1:0.3f}"
-            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend)
-            time3 = time.perf_counter()
-            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-            meta_data[
-                "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
-        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
-
-        # b. Forward Encoder streaming
-        t_offset = 0
-        feats = speech
-        feats_len = speech_lengths.max().item()
-        waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
-        cache = kwargs.get("cache", {})
-        batch_size = kwargs.get("batch_size", 1)
-        step = min(feats_len, 6000)
-        segments = [[]] * batch_size
-
-        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
-            if t_offset + step >= feats_len - 1:
-                step = feats_len - t_offset
-                is_final = True
-            else:
-                is_final = False
+            speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+            
             batch = {
-                "feats": feats[:, t_offset:t_offset + step, :],
-                "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
-                "is_final": is_final,
-                "cache": cache
+                "feats": speech,
+                "waveform": cache["frontend"]["waveforms"],
+                "is_final": kwargs["is_final"],
+                "cache": cache["encoder"]
             }
+            segments_i = self.forward(**batch)
+            print(segments_i)
+            segments.extend(segments_i)
 
 
-            segments_part, cache = self.forward(**batch)
-            if segments_part:
-                for batch_num in range(0, batch_size):
-                    segments[batch_num] += segments_part[batch_num]
+        cache["prev_samples"] = audio_sample[:-m]
+        if _is_final:
+            self.init_cache(cache, **kwargs)
 
         ibest_writer = None
         if ibest_writer is None and kwargs.get("output_dir") is not None:
@@ -635,16 +601,15 @@
             ibest_writer = writer[f"{1}best_recog"]
 
         results = []
-        for i in range(batch_size):
-            
-            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
-                results[i] = json.dumps(results[i])
-                
-            if ibest_writer is not None:
-                ibest_writer["text"][key[i]] = segments[i]
+        result_i = {"key": key[0], "value": segments}
+        if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+            result_i = json.dumps(result_i)
 
-            result_i = {"key": key[i], "value": segments[i]}
-            results.append(result_i)
+        results.append(result_i)
+            
+        if ibest_writer is not None:
+            ibest_writer["text"][key[0]] = segments
+
  
         return results, meta_data
 
diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py
index fdc0c93..b736aa9 100644
--- a/funasr/models/paraformer_streaming/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -523,7 +523,7 @@
 		
 		meta_data = {}
 		chunk_size = kwargs.get("chunk_size", [0, 10, 5])
-		chunk_stride_samples = chunk_size[1] * 960  # 600ms
+		chunk_stride_samples = int(chunk_size[1] * 960)  # 600ms
 		
 		time1 = time.perf_counter()
 		cfg = {"is_final": kwargs.get("is_final", False)}
@@ -532,7 +532,7 @@
 														audio_fs=kwargs.get("fs", 16000),
 														data_type=kwargs.get("data_type", "sound"),
 														tokenizer=tokenizer,
-														**cfg,
+														cache=cfg,
 														)
 		_is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
 		
@@ -542,8 +542,8 @@
 		
 		audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
 		
-		n = len(audio_sample) // chunk_stride_samples + int(_is_final)
-		m = len(audio_sample) % chunk_stride_samples * (1-int(_is_final))
+		n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
+		m = int(len(audio_sample) % chunk_stride_samples * (1-int(_is_final)))
 		tokens = []
 		for i in range(n):
 			kwargs["is_final"] = _is_final and i == n -1
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 638e0ac..4e131a8 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -48,7 +48,8 @@
 			pass
 		
 		# if data_in is a file or url, set is_final=True
-		kwargs["is_final"] = True
+		if "cache" in kwargs:
+			kwargs["cache"]["is_final"] = True
 	elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
 		data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
 	elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point

--
Gitblit v1.9.1