From c20c871e9f963151fa410dd616c6b23d001ecdd2 Mon Sep 17 00:00:00 2001
From: Xian Shi <40013335+R1ckShi@users.noreply.github.com>
Date: 星期二, 04 七月 2023 19:57:04 +0800
Subject: [PATCH] Merge pull request #673 from alibaba-damo-academy/dev_clas
---
funasr/bin/asr_inference_launch.py | 55 ++++++++++++++++++++++++++-----------------------------
1 files changed, 26 insertions(+), 29 deletions(-)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 539e823..a752f29 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -19,8 +19,8 @@
import numpy as np
import torch
import torchaudio
+import soundfile
import yaml
-from typeguard import check_argument_types
from funasr.bin.asr_infer import Speech2Text
from funasr.bin.asr_infer import Speech2TextMFCCA
@@ -35,8 +35,6 @@
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.vad import VADTask
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import asr_utils, postprocess_utils
@@ -81,7 +79,6 @@
param_dict: dict = None,
**kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@@ -241,7 +238,6 @@
param_dict: dict = None,
**kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
@@ -261,6 +257,7 @@
export_mode = param_dict.get("export_mode", False)
else:
hotword_list_or_file = None
+ clas_scale = param_dict.get('clas_scale', 1.0)
if kwargs.get("device", None) == "cpu":
ngpu = 0
@@ -293,6 +290,7 @@
penalty=penalty,
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
+ clas_scale=clas_scale,
)
speech2text = Speech2TextParaformer(**speech2text_kwargs)
@@ -482,7 +480,6 @@
param_dict: dict = None,
**kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
@@ -621,7 +618,12 @@
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
- batch_size_token_ms = batch_size_token * 60
+
+ batch_size_token_ms = batch_size_token*60
+ if speech2text.device == "cpu":
+ batch_size_token_ms = 0
+ batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
+
batch_size_token_ms_cum = 0
beg_idx = 0
for j, _ in enumerate(range(0, n)):
@@ -745,7 +747,6 @@
param_dict: dict = None,
**kwargs,
):
- assert check_argument_types()
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
@@ -860,7 +861,13 @@
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
- raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ try:
+ raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ except:
+ raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
+ if raw_inputs.ndim == 2:
+ raw_inputs = raw_inputs[:, 0]
+ raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
@@ -947,7 +954,6 @@
param_dict: dict = None,
**kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@@ -1116,7 +1122,6 @@
param_dict: dict = None,
**kwargs,
):
- assert check_argument_types()
ncpu = kwargs.get("ncpu", 1)
torch.set_num_threads(ncpu)
if batch_size > 1:
@@ -1304,7 +1309,6 @@
right_context: Number of frames in right context AFTER subsampling.
display_partial_hypotheses: Whether to display partial hypotheses.
"""
- assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
@@ -1360,20 +1364,14 @@
**kwargs,
):
# 3. Build data-iterator
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(
- speech2text.asr_train_args, False
- ),
- collate_fn=ASRTask.build_collate_fn(
- speech2text.asr_train_args, False
- ),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
# 4 .Start for-loop
@@ -1460,7 +1458,6 @@
param_dict: dict = None,
**kwargs,
):
- assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
@@ -1520,18 +1517,16 @@
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=speech2text.asr_train_args,
+ data_path_and_name_and_type=data_path_and_name_and_type,
dtype=dtype,
fs=fs,
mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
)
finish_count = 0
@@ -1611,6 +1606,8 @@
return inference_mfcca(**kwargs)
elif mode == "rnnt":
return inference_transducer(**kwargs)
+ elif mode == "bat":
+ return inference_transducer(**kwargs)
elif mode == "sa_asr":
return inference_sa_asr(**kwargs)
else:
--
Gitblit v1.9.1