From d77910eb6d171727f2350e45c31c91436c4c8891 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 11 Dec 2023 13:42:40 +0800
Subject: [PATCH] funasr2
---
funasr/cli/train_cli.py | 32 +++++++++++++++++++++++---------
 1 file changed, 23 insertions(+), 9 deletions(-)
diff --git a/funasr/cli/train_cli.py b/funasr/cli/train_cli.py
index 54cd2e8..c62153e 100644
--- a/funasr/cli/train_cli.py
+++ b/funasr/cli/train_cli.py
@@ -25,17 +25,25 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
+from funasr.utils.download_from_hub import download_model
def preprocess_config(cfg: DictConfig):
for key, value in cfg.items():
if value == 'None':
cfg[key] = None
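+
+
+# Example launch (sketch only; the config path/name below are illustrative and not part of this patch):
+#   python funasr/cli/train_cli.py --config-path conf --config-name train ++model_pretrain=<model_id> ++output_dir=./exp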
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+    # If a pretrained model is requested, resolve it via download_model, which returns the updated kwargs.
+    if kwargs.get("model_pretrain"):
+        kwargs = download_model(**kwargs)
+
+    main(**kwargs)
+
+
-@hydra.main()
-def main(kwargs: DictConfig):
+def main(**kwargs):
# preprocess_config(kwargs)
# import pdb; pdb.set_trace()
# set random seed
@@ -76,10 +84,10 @@
# init_param
init_param = kwargs.get("init_param", None)
if init_param is not None:
- init_param = eval(init_param)
+        # keep init_param as given instead of passing it through eval()
if isinstance(init_param, Sequence):
init_param = (init_param,)
- logging.info("init_param is not None: ", init_param)
+ logging.info("init_param is not None: %s", init_param)
for p in init_param:
logging.info(f"Loading pretrained params from {p}")
load_pretrained_model(
@@ -99,7 +107,7 @@
freeze_param = eval(freeze_param)
if isinstance(freeze_param, Sequence):
freeze_param = (freeze_param,)
- logging.info("freeze_param is not None: ", freeze_param)
+ logging.info("freeze_param is not None: %s", freeze_param)
for t in freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
@@ -138,9 +146,15 @@
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,
- num_workers=kwargs.get("num_workers", 0),
+ num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
pin_memory=True)
-
+
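+    # Save the resolved training config once, from the master process only (global rank 0 under DDP/FSDP, else local_rank 0).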
+    if ((use_ddp or use_fsdp) and dist.get_rank() == 0) or (not (use_ddp or use_fsdp) and local_rank == 0):
+ os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+ yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+ OmegaConf.save(config=kwargs, f=yaml_file)
+ logging.info("config.yaml is saved to: %s", yaml_file)
+
trainer = Trainer(
model=model,
optim=optim,
@@ -160,4 +174,4 @@
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main_hydra()
\ No newline at end of file
--
Gitblit v1.9.1