From 9624eba825069e64a64fb40dc01df51063e9271f Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: Thu, 27 Apr 2023 10:46:18 +0800
Subject: [PATCH] Update asr_inference_paraformer_streaming.py
---
funasr/bin/asr_inference_paraformer_streaming.py | 25 +++++++++++++++++++------
1 file changed, 19 insertions(+), 6 deletions(-)
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index c70baf0..ff8bb8c 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -8,6 +8,7 @@
import codecs
import tempfile
import requests
+import yaml
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -462,13 +463,23 @@
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
+ def _read_yaml(yaml_path: Union[str, Path]) -> Dict:
+     """Load the YAML file at ``yaml_path`` and return the parsed data; raises FileNotFoundError if the path is missing."""
+     if not Path(yaml_path).exists():
+         raise FileNotFoundError(f'The {yaml_path} does not exist.')
+     with open(str(yaml_path), 'rb') as f:
+         data = yaml.load(f, Loader=yaml.Loader)  # NOTE(review): yaml.Loader can build arbitrary Python objects; prefer yaml.safe_load if the config may be untrusted
+     return data
+
def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
if len(cache) > 0:
return cache
-
- cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
+ config = _read_yaml(asr_train_config)
+ enc_output_size = config["encoder_conf"]["output_size"]
+ feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
+ cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False}
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
cache["encoder"] = cache_en
cache_de = {"decode_fsmn": None}
@@ -478,9 +489,12 @@
def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
if len(cache) > 0:
- cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
+ config = _read_yaml(asr_train_config)
+ enc_output_size = config["encoder_conf"]["output_size"]
+ feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
+ cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False}
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
cache["encoder"] = cache_en
cache_de = {"decode_fsmn": None}
@@ -720,4 +734,3 @@
#
# rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
# print(rec_result)
-
--
Gitblit v1.9.1