From d6cc6896e4d55498d6d36331b5c661579906525f Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 20 四月 2023 16:33:30 +0800
Subject: [PATCH] update
---
funasr/utils/build_asr_model.py | 13 +++++++++++--
1 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/funasr/utils/build_asr_model.py b/funasr/utils/build_asr_model.py
index 2da050c..f333969 100644
--- a/funasr/utils/build_asr_model.py
+++ b/funasr/utils/build_asr_model.py
@@ -268,7 +268,7 @@
token_list=token_list,
**args.model_conf,
)
- elif args.model == "paraformer":
+ elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
# predictor
predictor_class = predictor_choices.get_class(args.predictor)
predictor = predictor_class(**args.predictor_conf)
@@ -336,9 +336,18 @@
stride_conv=stride_conv,
**args.model_conf,
)
-
+ elif args.model == "timestamp_prediction":
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ frontend=frontend,
+ encoder=encoder,
+ token_list=token_list,
+ **args.model_conf,
+ )
else:
raise NotImplementedError("Not supported model: {}".format(args.model))
if args.init is not None:
initialize(model, args.init)
+
+ return model
\ No newline at end of file
--
Gitblit v1.9.1