From ed25726d2992f8a395a8d4ed9bd5e85d231c471e Mon Sep 17 00:00:00 2001
From: root <root@localhost.localdomain>
Date: Fri, 28 Apr 2023 17:00:56 +0800
Subject: [PATCH] add offline websocket support
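
Feed whole audio files ("sound" inputs) through the streaming model chunk
by chunk, so the websocket server can also serve offline requests. Cache
tensor shapes are now read from the training config instead of being
hardcoded, and a short final chunk is decoded straight from the cached
feature buffer.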
---
funasr/bin/asr_inference_paraformer_streaming.py | 67 ++++++++++++++++++++++++++-------
 1 file changed, 52 insertions(+), 15 deletions(-)
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index 939ffe9..4aae8e9 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -8,6 +8,7 @@
 import codecs
 import tempfile
 import requests
+import yaml
 from pathlib import Path
 from typing import Optional
 from typing import Sequence
@@ -19,6 +20,7 @@
 import numpy as np
 import torch
+import torchaudio
 from typeguard import check_argument_types
 from funasr.fileio.datadir_writer import DatadirWriter
@@ -202,10 +204,12 @@
         assert check_argument_types()
         results = []
         cache_en = cache["encoder"]
-        if speech.shape[1] < 16 * 60 and cache["is_final"]:
-            cache["last_chunk"] = True
-            feats = cache["feats"]
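+        # The final-chunk flags live in the encoder cache ("cache_en"), not the
+        # top-level cache. A final chunk shorter than 960 samples (60 ms at an
+        # assumed 16 kHz rate) carries no new frames, so decode the cached
+        # feature buffer directly and return.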
+        if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
+            cache_en["tail_chunk"] = True
+            feats = cache_en["feats"]
             feats_len = torch.tensor([feats.shape[1]])
+            results = self.infer(feats, feats_len, cache)
+            return results
         else:
             if self.frontend is not None:
                 feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"])
@@ -232,7 +236,7 @@
                 feats_len = torch.tensor([feats_chunk2.shape[1]])
                 results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
-                return results_chunk1 + results_chunk2
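+                # Concatenate the two half-chunk hypotheses into a single string
+                # so the caller always receives a one-element result list.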
+ return ["".join(results_chunk1 + results_chunk2)]
results = self.infer(feats, feats_len, cache)
@@ -460,13 +464,23 @@
         array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
         return array
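+
+    # Helper to load the model's training config; the cache-building code
+    # below uses it to size tensors from encoder_conf/frontend_conf instead
+    # of hardcoding the dimensions.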
+    def _read_yaml(yaml_path: Union[str, Path]) -> dict:
+        if not Path(yaml_path).exists():
+            raise FileNotFoundError(f'{yaml_path} does not exist.')
+
+        with open(str(yaml_path), 'rb') as f:
+            data = yaml.load(f, Loader=yaml.Loader)
+        return data
+
     def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
         if len(cache) > 0:
             return cache
-
-        cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
+        config = _read_yaml(asr_train_config)
+        enc_output_size = config["encoder_conf"]["output_size"]
+        feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
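+        # e.g. n_mels=80 with lfr_m=7 gives 560-dim LFR features, matching the
+        # previously hardcoded value; encoder output_size likewise replaces 320.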
+ cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
"cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
- "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
cache["encoder"] = cache_en
cache_de = {"decode_fsmn": None}
@@ -476,9 +490,12 @@
     def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1):
         if len(cache) > 0:
-            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)),
+            config = _read_yaml(asr_train_config)
+            enc_output_size = config["encoder_conf"]["output_size"]
+            feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
+            cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                         "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))}
+                        "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
             cache["encoder"] = cache_en
             cache_de = {"decode_fsmn": None}
@@ -499,6 +516,8 @@
         if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
             raw_inputs = _load_bytes(data_path_and_name_and_type[0])
             raw_inputs = torch.tensor(raw_inputs)
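+        # torchaudio.load returns (waveform, sample_rate); take the first
+        # channel of the waveform. No resampling is done here, so the file
+        # is assumed to already match the model's sample rate.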
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, np.ndarray):
                 raw_inputs = torch.tensor(raw_inputs)
@@ -515,13 +534,32 @@
         # 7 .Start for-loop
         # FIXME(kamo): The output format should be discussed about
         raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
-        input_lens = torch.tensor([raw_inputs.shape[1]])
         asr_result_list = []
-
         cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
-        cache["encoder"]["is_final"] = is_final
-        asr_result = speech2text(cache, raw_inputs, input_lens)
-        item = {'key': "utt", 'value': asr_result}
+        item = {}
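+        # Offline path for "sound" inputs: simulate streaming by slicing the
+        # waveform into strides of chunk_size[1] * 960 samples (10 * 960 = 9600
+        # samples, i.e. 600 ms at an assumed 16 kHz, for the default [5, 10, 5]
+        # config) and accumulating the partial results into one transcript.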
+        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+            sample_offset = 0
+            speech_length = raw_inputs.shape[1]
+            stride_size = chunk_size[1] * 960
+            cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
+            final_result = ""
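+            # range() evaluates its step once (sample_offset is still 0), so
+            # the step is just min(stride_size, speech_length); the last,
+            # possibly shorter, chunk is marked final below.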
+            for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
+                if sample_offset + stride_size >= speech_length - 1:
+                    stride_size = speech_length - sample_offset
+                    cache["encoder"]["is_final"] = True
+                else:
+                    cache["encoder"]["is_final"] = False
+                input_lens = torch.tensor([stride_size])
+                asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
+                if len(asr_result) != 0:
+                    final_result += asr_result[0]
+            item = {'key': "utt", 'value': [final_result]}
+        else:
+            input_lens = torch.tensor([raw_inputs.shape[1]])
+            cache["encoder"]["is_final"] = is_final
+            asr_result = speech2text(cache, raw_inputs, input_lens)
+            item = {'key': "utt", 'value': asr_result}
+
         asr_result_list.append(item)
         if is_final:
             cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
@@ -718,4 +756,3 @@
 #
 #     rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
 #     print(rec_result)
-
--
Gitblit v1.9.1