From f4710b4180d83328b48f97a8463ba1a2285c429e Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 28 四月 2023 11:24:43 +0800
Subject: [PATCH] Merge pull request #440 from alibaba-damo-academy/dev_lhn
---
funasr/bin/asr_inference_paraformer_streaming.py | 32 +++++++++++++++++++++++++++-----
1 files changed, 27 insertions(+), 5 deletions(-)
diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py
index ff8bb8c..4aae8e9 100644
--- a/funasr/bin/asr_inference_paraformer_streaming.py
+++ b/funasr/bin/asr_inference_paraformer_streaming.py
@@ -20,6 +20,7 @@
import numpy as np
import torch
+import torchaudio
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
@@ -515,6 +516,8 @@
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
raw_inputs = torch.tensor(raw_inputs)
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+ raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
@@ -531,13 +534,32 @@
# 7 .Start for-loop
# FIXME(kamo): The output format should be discussed about
raw_inputs = torch.unsqueeze(raw_inputs, axis=0)
- input_lens = torch.tensor([raw_inputs.shape[1]])
asr_result_list = []
-
cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
- cache["encoder"]["is_final"] = is_final
- asr_result = speech2text(cache, raw_inputs, input_lens)
- item = {'key': "utt", 'value': asr_result}
+ item = {}
+ if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
+ sample_offset = 0
+ speech_length = raw_inputs.shape[1]
+ stride_size = chunk_size[1] * 960
+ cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1)
+ final_result = ""
+ for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
+ if sample_offset + stride_size >= speech_length - 1:
+ stride_size = speech_length - sample_offset
+ cache["encoder"]["is_final"] = True
+ else:
+ cache["encoder"]["is_final"] = False
+ input_lens = torch.tensor([stride_size])
+ asr_result = speech2text(cache, raw_inputs[:, sample_offset: sample_offset + stride_size], input_lens)
+ if len(asr_result) != 0:
+ final_result += asr_result[0]
+ item = {'key': "utt", 'value': [final_result]}
+ else:
+ input_lens = torch.tensor([raw_inputs.shape[1]])
+ cache["encoder"]["is_final"] = is_final
+ asr_result = speech2text(cache, raw_inputs, input_lens)
+ item = {'key': "utt", 'value': asr_result}
+
asr_result_list.append(item)
if is_final:
cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1)
--
Gitblit v1.9.1