From 806a03609df033d61f824f1ab8527eb88fe837ad Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 十二月 2023 19:43:13 +0800
Subject: [PATCH] funasr2 paraformer biciparaformer contextuaparaformer

---
 funasr/cli/train_cli.py |    8 ++++----
 1 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/funasr/cli/train_cli.py b/funasr/cli/train_cli.py
index c62153e..a22d5d4 100644
--- a/funasr/cli/train_cli.py
+++ b/funasr/cli/train_cli.py
@@ -35,8 +35,9 @@
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
 	import pdb; pdb.set_trace()
-	if kwargs.get("model_pretrain"):
-		kwargs = download_model(**kwargs)
+	if ":" in kwargs["model"]:
+		logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+		kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
 	
 	import pdb;
 	pdb.set_trace()
@@ -84,8 +85,7 @@
 	# init_param
 	init_param = kwargs.get("init_param", None)
 	if init_param is not None:
-		init_param = init_param
-		if isinstance(init_param, Sequence):
+		if not isinstance(init_param, Sequence):
 			init_param = (init_param,)
 		logging.info("init_param is not None: %s", init_param)
 		for p in init_param:

--
Gitblit v1.9.1