From 47343b5c2f4e1256f60f46d8da0aa2e5de39b6c7 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期六, 05 八月 2023 17:53:08 +0800
Subject: [PATCH] init repo
---
funasr/bin/asr_infer.py | 68 ++++++----------------------------
1 files changed, 12 insertions(+), 56 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index e12dbb5..02ca63d 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -22,9 +22,7 @@
import requests
import torch
from packaging.version import parse as V
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -78,7 +76,6 @@
frontend_conf: dict = None,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -192,7 +189,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -248,7 +244,6 @@
text = None
results.append((text, token, token_int, hyp))
- assert check_return_type(results)
return results
@@ -285,10 +280,10 @@
nbest: int = 1,
frontend_conf: dict = None,
hotword_list_or_file: str = None,
+ clas_scale: float = 1.0,
decoding_ind: int = 0,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -382,6 +377,7 @@
# 6. [Optional] Build hotword list from str, local file or url
self.hotword_list = None
self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
+ self.clas_scale = clas_scale
is_use_lm = lm_weight != 0.0 and lm_file is not None
if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
@@ -413,7 +409,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -446,16 +441,20 @@
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- if not isinstance(self.asr_model, ContextualParaformer) and not isinstance(self.asr_model,
- NeatContextualParaformer):
+ if not isinstance(self.asr_model, ContextualParaformer) and \
+ not isinstance(self.asr_model, NeatContextualParaformer):
if self.hotword_list:
logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
else:
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds,
- pre_token_length, hw_list=self.hotword_list)
+ decoder_outs = self.asr_model.cal_decoder_with_predictor(enc,
+ enc_len,
+ pre_acoustic_embeds,
+ pre_token_length,
+ hw_list=self.hotword_list,
+ clas_scale=self.clas_scale)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
if isinstance(self.asr_model, BiCifParaformer):
@@ -516,7 +515,6 @@
vad_offset=begin_time)
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
- # assert check_return_type(results)
return results
def generate_hotwords_list(self, hotword_list_or_file):
@@ -609,7 +607,7 @@
hotword_str_list = []
for hw in hotword_list_or_file.strip().split():
hotword_str_list.append(hw)
- hw_list = hw
+ hw_list = hw.strip().split()
if seg_dict is not None:
hw_list = seg_tokenize(hw_list, seg_dict)
hotword_list.append(self.converter.tokens2ids(hw_list))
@@ -656,7 +654,6 @@
hotword_list_or_file: str = None,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -776,7 +773,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
results = []
cache_en = cache["encoder"]
if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
@@ -871,7 +867,6 @@
results.append(postprocessed_result)
- # assert check_return_type(results)
return results
@@ -912,7 +907,6 @@
frontend_conf: dict = None,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -1036,7 +1030,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -1104,7 +1097,6 @@
text = None
results.append((text, token, token_int, hyp))
- assert check_return_type(results)
return results
@@ -1143,7 +1135,6 @@
streaming: bool = False,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -1248,7 +1239,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -1298,7 +1288,6 @@
text = None
results.append((text, token, token_int, hyp))
- assert check_return_type(results)
return results
@@ -1355,7 +1344,6 @@
"""Construct a Speech2Text object."""
super().__init__()
- assert check_argument_types()
asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
@@ -1534,7 +1522,6 @@
Returns:
nbest_hypothesis: N-best hypothesis.
"""
- assert check_argument_types()
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -1566,7 +1553,6 @@
Returns:
nbest_hypothesis: N-best hypothesis.
"""
- assert check_argument_types()
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -1608,35 +1594,8 @@
text = None
results.append((text, token, token_int, hyp))
- assert check_return_type(results)
return results
-
- @staticmethod
- def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
- ) -> Speech2Text:
- """Build Speech2Text instance from the pretrained model.
- Args:
- model_tag: Model tag of the pretrained models.
- Return:
- : Speech2Text instance.
- """
- if model_tag is not None:
- try:
- from espnet_model_zoo.downloader import ModelDownloader
-
- except ImportError:
- logging.error(
- "`espnet_model_zoo` is not installed. "
- "Please install via `pip install -U espnet_model_zoo`."
- )
- raise
- d = ModelDownloader()
- kwargs.update(**d.download_and_unpack(model_tag))
-
- return Speech2TextTransducer(**kwargs)
class Speech2TextSAASR:
@@ -1675,7 +1634,6 @@
frontend_conf: dict = None,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -1793,7 +1751,6 @@
text, text_id, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -1886,5 +1843,4 @@
results.append((text, text_id, token, token_int, hyp))
- assert check_return_type(results)
return results
--
Gitblit v1.9.1