From 30aa982bf29ceefaf52c0013c12c19adc57dea0e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 27 四月 2023 21:11:04 +0800
Subject: [PATCH] docs

---
 funasr/bin/asr_inference_paraformer_streaming.py |   35 +++++++++++++++++++++++++----------
 1 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index 939ffe9..ff8bb8c 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -8,6 +8,7 @@
 import codecs
 import tempfile
 import requests
+import yaml
 from pathlib import Path
 from typing import Optional
 from typing import Sequence
@@ -202,10 +203,12 @@
         assert check_argument_types()
         results = []
         cache_en = cache["encoder"]
-        if speech.shape[1] < 16 * 60 and cache["is_final"]:
-            cache["last_chunk"] = True
-            feats = cache["feats"]
+        if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
+            cache_en["tail_chunk"] = True
+            feats = cache_en["feats"]
             feats_len = torch.tensor([feats.shape[1]])
+            results = self.infer(feats, feats_len, cache)
+            return results
         else:
             if self.frontend is not None:
                 feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
@@ -232,7 +235,7 @@
                         feats_len = torch.tensor([feats_chunk2.shape[1]])
                         results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
 
-                        return results_chunk1 + results_chunk2
+                        return ["".join(results_chunk1 + results_chunk2)]
 
                 results = self.infer(feats, feats_len, cache)
 
@@ -460,13 +463,23 @@
         array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
         return array
 
+    def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
+        if not Path(yaml_path).exists():
+            raise FileExistsError(f'The {yaml_path} does not exist.')
+
+        with open(str(yaml_path), 'rb') as f:
+            data = yaml.load(f, Loader=yaml.Loader)
+        return data
+
     def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
         if len(cache) > 0:
             return cache
-
-        cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
+        config = _read_yaml(asr_train_config)
+        enc_output_size = config["encoder_conf"]["output_size"]
+        feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
+        cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                     "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
+                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
         cache["encoder"] = cache_en
 
         cache_de = {"decode_fsmn": None}
@@ -476,9 +489,12 @@
 
     def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
         if len(cache) > 0:
-            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
+            config = _read_yaml(asr_train_config)
+            enc_output_size = config["encoder_conf"]["output_size"]
+            feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
+            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
+                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
             cache["encoder"] = cache_en
 
             cache_de = {"decode_fsmn": None}
@@ -718,4 +734,3 @@
     #
     # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
     # print(rec_result)
-

--
Gitblit v1.9.1