From 8dab6d184a034ca86eafa644ea0d2100aadfe27d Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 09 May 2023 10:58:33 +0800
Subject: [PATCH] Merge pull request #473 from alibaba-damo-academy/dev_smohan
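
Build the frontend on the inference side instead of inside the model:
when asr_train_args names a frontend, construct a WavFrontend (or the
class registered under frontend_choices) from frontend_conf, extract
features in __call__ before batching, and clear the model's own
frontend so features are not processed twice. The beam_search device
placement and eval-mode setup (and its logging) is dropped from this
path, and the previously hard-coded mc=True becomes a caller-supplied
flag on inference() and inference_modelscope(); note its default is
False, so multi-channel callers now have to pass mc=True explicitly.

A minimal sketch of the frontend dispatch this patch adds
(build_frontend is an illustrative helper name, not a function in the
file; asr_train_args, cmvn_file, WavFrontend and frontend_choices all
come from the patched code):

    from funasr.models.frontend.wav_frontend import WavFrontend
    from funasr.tasks.asr import frontend_choices

    def build_frontend(asr_train_args, cmvn_file):
        # No frontend configured: fall back to precomputed features.
        if asr_train_args.frontend is None or asr_train_args.frontend_conf is None:
            return None
        if asr_train_args.frontend == 'wav_frontend':
            # WavFrontend reads CMVN stats from cmvn_file and takes its
            # remaining options from frontend_conf.
            return WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        # Any other registered frontend is resolved through the task's
        # choices table and switched to eval mode for inference.
        frontend_class = frontend_choices.get_class(asr_train_args.frontend)
        return frontend_class(**asr_train_args.frontend_conf).eval()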
---
funasr/bin/sa_asr_inference.py | 31 ++++++++++++++++++++++---------
1 file changed, 22 insertions(+), 9 deletions(-)
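
Note on the __call__ change: when a frontend was built above, features
are extracted before the batch is assembled, and asr_model.frontend is
set to None so the model does not run its internal frontend on already
extracted features. A sketch of that flow (extract_feats is an
illustrative name; .to(device) stands in for the file's to_device
helper, and the 80-dim fbank assumption behind lfr_factor is taken
from the patch itself):

    def extract_feats(frontend, asr_model, device, speech, speech_lengths):
        if frontend is not None:
            # Raw waveform in: run the external frontend, lengths as int32.
            feats, feats_len = frontend.forward(speech, speech_lengths)
            feats = feats.to(device)
            feats_len = feats_len.int()
            # Disable the model-internal frontend to avoid double extraction.
            asr_model.frontend = None
        else:
            # Features were precomputed upstream; pass them through as-is.
            feats, feats_len = speech, speech_lengths
        # Recover the LFR stacking factor from the feature dimension
        # (multiples of an 80-dim fbank), floored at 1.
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        return {"speech": feats, "speech_lengths": feats_len}, lfr_factor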
diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py
index be63af1..c894f54 100644
--- a/funasr/bin/sa_asr_inference.py
+++ b/funasr/bin/sa_asr_inference.py
@@ -35,6 +35,8 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@@ -85,6 +87,12 @@
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
+if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+    if asr_train_args.frontend == 'wav_frontend':
+        frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+    else:
+        frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+        frontend = frontend_class(**asr_train_args.frontend_conf).eval()
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
@@ -133,13 +141,6 @@
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
-
-beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-for scorer in scorers.values():
-    if isinstance(scorer, torch.nn.Module):
-        scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-logging.info(f"Beam_search: {beam_search}")
-logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
@@ -201,7 +202,16 @@
if isinstance(profile, np.ndarray):
profile = torch.tensor(profile)
- batch = {"speech": speech, "speech_lengths": speech_lengths}
+if self.frontend is not None:
+    feats, feats_len = self.frontend.forward(speech, speech_lengths)
+    feats = to_device(feats, device=self.device)
+    feats_len = feats_len.int()
+    self.asr_model.frontend = None
+else:
+    feats = speech
+    feats_len = speech_lengths
+lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
@@ -308,6 +318,7 @@
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
+ mc: bool = False,
**kwargs,
):
inference_pipeline = inference_modelscope(
@@ -338,6 +349,7 @@
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
+ mc=mc,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -370,6 +382,7 @@
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
+ mc: bool = False,
param_dict: dict = None,
**kwargs,
):
@@ -437,7 +450,7 @@
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
- mc=True,
+ mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
--
Gitblit v1.9.1