From 3e77fd44304a67a2b2253b4e56fede9762bb8464 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 20 四月 2023 16:41:22 +0800
Subject: [PATCH] update

---
 funasr/utils/build_asr_model.py |   15 +++++++++++----
 1 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/funasr/utils/build_asr_model.py b/funasr/utils/build_asr_model.py
index 2da050c..e0275a0 100644
--- a/funasr/utils/build_asr_model.py
+++ b/funasr/utils/build_asr_model.py
@@ -210,7 +210,6 @@
 
     # frontend
     if args.input_size is None:
-        # Extract features in the model
         frontend_class = frontend_choices.get_class(args.frontend)
         if args.frontend == 'wav_frontend':
             frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
@@ -218,7 +217,6 @@
             frontend = frontend_class(**args.frontend_conf)
         input_size = frontend.output_size()
     else:
-        # Give features from data-loader
         args.frontend = None
         args.frontend_conf = {}
         frontend = None
@@ -268,7 +266,7 @@
             token_list=token_list,
             **args.model_conf,
         )
-    elif args.model == "paraformer":
+    elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
         # predictor
         predictor_class = predictor_choices.get_class(args.predictor)
         predictor = predictor_class(**args.predictor_conf)
@@ -336,9 +334,18 @@
             stride_conv=stride_conv,
             **args.model_conf,
         )
-
+    elif args.model == "timestamp_prediction":
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            frontend=frontend,
+            encoder=encoder,
+            token_list=token_list,
+            **args.model_conf,
+        )
     else:
         raise NotImplementedError("Not supported model: {}".format(args.model))
 
     if args.init is not None:
         initialize(model, args.init)
+
+    return model
\ No newline at end of file

--
Gitblit v1.9.1