From 1af8a233ce99b6c6a8a119eaa7363ebae1f2570f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 21 六月 2023 11:15:06 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/bin/asr_infer.py | 65 ++++++++++++++------------------
1 files changed, 29 insertions(+), 36 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 288034c..c722ebc 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -24,7 +24,7 @@
from packaging.version import parse as V
from typeguard import check_argument_types
from typeguard import check_return_type
-
+from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
@@ -35,9 +35,7 @@
from funasr.modules.beam_search.beam_search_transducer import Hypothesis as HypothesisTransducer
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.tasks.asr import ASRTask
-from funasr.tasks.asr import frontend_choices
-from funasr.tasks.lm import LMTask
+from funasr.build_utils.build_asr_model import frontend_choices
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
@@ -84,7 +82,7 @@
# 1. Build ASR model
scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
@@ -92,7 +90,6 @@
if asr_train_args.frontend == 'wav_frontend':
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
else:
- from funasr.tasks.asr import frontend_choices
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -112,7 +109,7 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
+ lm, lm_train_args = build_model_from_file(
lm_train_config, lm_file, None, device
)
scorers["lm"] = lm.lm
@@ -295,9 +292,8 @@
# 1. Build ASR model
scorers = {}
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
+ asr_model, asr_train_args = build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -319,8 +315,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
scorers["lm"] = lm.lm
@@ -616,9 +612,8 @@
# 1. Build ASR model
scorers = {}
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
+ asr_model, asr_train_args = build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device, mode="paraformer"
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -640,8 +635,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
scorers["lm"] = lm.lm
@@ -873,9 +868,8 @@
# 1. Build ASR model
scorers = {}
- from funasr.tasks.asr import ASRTaskUniASR as ASRTask
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
+ asr_model, asr_train_args = build_model_from_file(
+ asr_train_config, asr_model_file, cmvn_file, device, mode="uniasr"
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
@@ -901,8 +895,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, device, "lm"
)
scorers["lm"] = lm.lm
@@ -1104,9 +1098,8 @@
assert check_argument_types()
# 1. Build ASR model
- from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
@@ -1126,8 +1119,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
lm.to(device)
scorers["lm"] = lm.lm
@@ -1315,8 +1308,7 @@
super().__init__()
assert check_argument_types()
- from funasr.tasks.asr import ASRTransducerTask
- asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
@@ -1350,8 +1342,8 @@
asr_model.to(dtype=getattr(torch, dtype)).eval()
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
lm_scorer = lm.lm
else:
@@ -1638,15 +1630,16 @@
assert check_argument_types()
# 1. Build ASR model
- from funasr.tasks.sa_asr import ASRTask
scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
+ asr_model, asr_train_args = build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- if asr_train_args.frontend == 'wav_frontend':
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+ from funasr.tasks.sa_asr import frontend_choices
+ if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
+ frontend_class = frontend_choices.get_class(asr_train_args.frontend)
+ frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
else:
frontend_class = frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -1667,8 +1660,8 @@
# 2. Build Language model
if lm_train_config is not None:
- lm, lm_train_args = LMTask.build_model_from_file(
- lm_train_config, lm_file, None, device
+ lm, lm_train_args = build_model_from_file(
+ lm_train_config, lm_file, None, device, task_name="lm"
)
scorers["lm"] = lm.lm
--
Gitblit v1.9.1