From d77910eb6d171727f2350e45c31c91436c4c8891 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 十二月 2023 13:42:40 +0800
Subject: [PATCH] funasr2

---
 funasr/cli/train_cli.py |   32 +++++++++++++++++++++++---------
 1 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/funasr/cli/train_cli.py b/funasr/cli/train_cli.py
index 54cd2e8..c62153e 100644
--- a/funasr/cli/train_cli.py
+++ b/funasr/cli/train_cli.py
@@ -25,17 +25,25 @@
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
+from funasr.utils.download_from_hub import download_model
 
 def preprocess_config(cfg: DictConfig):
 	for key, value in cfg.items():
 		if value == 'None':
 			cfg[key] = None
 
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+	# Resolve pretrained model assets (if requested) before dispatching to main().
+	if kwargs.get("model_pretrain"):
+		kwargs = download_model(**kwargs)
+	
+	# kwargs may have been replaced by download_model(); unpack them so that
+	# main() stays independent of hydra's DictConfig wrapper.
+	main(**kwargs)
 
 
-@hydra.main()
-def main(kwargs: DictConfig):
+def main(**kwargs):
 	# preprocess_config(kwargs)
 	# import pdb; pdb.set_trace()
 	# set random seed
@@ -76,10 +84,10 @@
 	# init_param
 	init_param = kwargs.get("init_param", None)
 	if init_param is not None:
-		init_param = eval(init_param)
+		# NOTE(review): init_param is now used as-is; eval() of user input was removed.
 		if isinstance(init_param, Sequence):
 			init_param = (init_param,)
-		logging.info("init_param is not None: ", init_param)
+		logging.info("init_param is not None: %s", init_param)
 		for p in init_param:
 			logging.info(f"Loading pretrained params from {p}")
 			load_pretrained_model(
@@ -99,7 +107,7 @@
 		freeze_param = eval(freeze_param)
 		if isinstance(freeze_param, Sequence):
 			freeze_param = (freeze_param,)
-		logging.info("freeze_param is not None: ", freeze_param)
+		logging.info("freeze_param is not None: %s", freeze_param)
 		for t in freeze_param:
 			for k, p in model.named_parameters():
 				if k.startswith(t + ".") or k == t:
@@ -138,9 +146,15 @@
 	dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
 	                                            collate_fn=dataset_tr.collator,
 	                                            batch_sampler=batch_sampler,
-	                                            num_workers=kwargs.get("num_workers", 0),
+	                                            num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
 	                                            pin_memory=True)
-
+	
+	if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
+		os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+		yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+		OmegaConf.save(config=kwargs, f=yaml_file)
+		logging.info("config.yaml is saved to: %s", yaml_file)
+	
 	trainer = Trainer(
 	    model=model,
 	    optim=optim,
@@ -160,4 +174,4 @@
 	
 
 if __name__ == "__main__":
-	main()
\ No newline at end of file
+	main_hydra()
\ No newline at end of file

--
Gitblit v1.9.1