From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/bin/asr_infer.py | 92 +++++++++++++++++++--------------------------
1 files changed, 39 insertions(+), 53 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 0ce8dd8..7746821 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -22,9 +22,7 @@
import requests
import torch
from packaging.version import parse as V
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -78,7 +76,6 @@
frontend_conf: dict = None,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -192,7 +189,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -248,7 +244,6 @@
text = None
results.append((text, token, token_int, hyp))
- assert check_return_type(results)
return results
@@ -289,7 +284,6 @@
decoding_ind: int = 0,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -405,7 +399,7 @@
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
- begin_time: int = 0, end_time: int = None,
+ decoding_ind: int = None, begin_time: int = 0, end_time: int = None,
):
"""Inference
@@ -415,7 +409,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -436,7 +429,9 @@
batch = to_device(batch, device=self.device)
# b. Forward Encoder
- enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
+ if decoding_ind is None:
+ decoding_ind = self.decoding_ind
+ enc, enc_len = self.asr_model.encode(**batch, ind=decoding_ind)
if isinstance(enc, tuple):
enc = enc[0]
# assert len(enc) == 1, len(enc)
@@ -522,7 +517,6 @@
vad_offset=begin_time)
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
- # assert check_return_type(results)
return results
def generate_hotwords_list(self, hotword_list_or_file):
@@ -662,7 +656,6 @@
hotword_list_or_file: str = None,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -782,7 +775,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
results = []
cache_en = cache["encoder"]
if speech.shape[1] < 16 * 60 and cache_en["is_final"]:
@@ -877,7 +869,6 @@
results.append(postprocessed_result)
- # assert check_return_type(results)
return results
@@ -918,7 +909,6 @@
frontend_conf: dict = None,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -1042,7 +1032,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -1110,7 +1099,6 @@
text = None
results.append((text, token, token_int, hyp))
- assert check_return_type(results)
return results
@@ -1149,7 +1137,6 @@
streaming: bool = False,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -1254,7 +1241,6 @@
text, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -1304,7 +1290,6 @@
text = None
results.append((text, token, token_int, hyp))
- assert check_return_type(results)
return results
@@ -1353,6 +1338,7 @@
nbest: int = 1,
streaming: bool = False,
simu_streaming: bool = False,
+ full_utt: bool = False,
chunk_size: int = 16,
left_context: int = 32,
right_context: int = 0,
@@ -1361,7 +1347,6 @@
"""Construct a Speech2Text object."""
super().__init__()
- assert check_argument_types()
asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
@@ -1448,6 +1433,7 @@
self.beam_search = beam_search
self.streaming = streaming
self.simu_streaming = simu_streaming
+ self.full_utt = full_utt
self.chunk_size = max(chunk_size, 0)
self.left_context = left_context
self.right_context = max(right_context, 0)
@@ -1467,6 +1453,7 @@
self._ctx = self.asr_model.encoder.get_encoder_input_size(
self.window_size
)
+ self._right_ctx = right_context
self.last_chunk_length = (
self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
@@ -1540,7 +1527,6 @@
Returns:
nbest_hypothesis: N-best hypothesis.
"""
- assert check_argument_types()
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -1565,7 +1551,7 @@
return nbest_hyps
@torch.no_grad()
- def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+ def full_utt_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
"""Speech2Text call.
Args:
speech: Speech data. (S)
@@ -1573,6 +1559,36 @@
nbest_hypothesis: N-best hypothesis.
"""
assert check_argument_types()
+
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ if self.frontend is not None:
+ speech = torch.unsqueeze(speech, axis=0)
+ speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ if self.asr_model.normalize is not None:
+ feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+ enc_out = self.asr_model.encoder.full_utt_forward(feats, feats_lengths)
+ nbest_hyps = self.beam_search(enc_out[0])
+
+ return nbest_hyps
+
+ @torch.no_grad()
+ def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+ """Speech2Text call.
+ Args:
+ speech: Speech data. (S)
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
+ """
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
@@ -1614,35 +1630,8 @@
text = None
results.append((text, token, token_int, hyp))
- assert check_return_type(results)
return results
-
- @staticmethod
- def from_pretrained(
- model_tag: Optional[str] = None,
- **kwargs: Optional[Any],
- ) -> Speech2Text:
- """Build Speech2Text instance from the pretrained model.
- Args:
- model_tag: Model tag of the pretrained models.
- Return:
- : Speech2Text instance.
- """
- if model_tag is not None:
- try:
- from espnet_model_zoo.downloader import ModelDownloader
-
- except ImportError:
- logging.error(
- "`espnet_model_zoo` is not installed. "
- "Please install via `pip install -U espnet_model_zoo`."
- )
- raise
- d = ModelDownloader()
- kwargs.update(**d.download_and_unpack(model_tag))
-
- return Speech2TextTransducer(**kwargs)
class Speech2TextSAASR:
@@ -1681,7 +1670,6 @@
frontend_conf: dict = None,
**kwargs,
):
- assert check_argument_types()
# 1. Build ASR model
scorers = {}
@@ -1799,7 +1787,6 @@
text, text_id, token, token_int, hyp
"""
- assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
@@ -1892,5 +1879,4 @@
results.append((text, text_id, token, token_int, hyp))
- assert check_return_type(results)
return results
--
Gitblit v1.9.1