import os
import logging

import hydra
import torch
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from funasr.utils.download_from_hub import download_model
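# The import paths below are assumptions for this excerpt; point them at wherever
# the Trainer class and the load_pretrained_model helper used further down actually
# live in this codebase.
from funasr.train_utils.trainer import Trainer
from funasr.train_utils.load_pretrained_model import load_pretrained_model
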
def preprocess_config(cfg: DictConfig):
    """Convert literal 'None' strings (e.g. from YAML/CLI overrides) into real None values."""
    for key, value in cfg.items():
        if value == 'None':
            cfg[key] = None

@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    # Optionally fetch a pretrained model from the hub before training starts.
    if kwargs.get("model_pretrain"):
        kwargs = download_model(**kwargs)

    main(**kwargs)

def main(**kwargs):
    # preprocess_config(kwargs)
    # set random seed
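    # A minimal sketch of the seeding step, assuming the config carries an integer
    # "seed" entry; the full script may use its own seeding helper instead.
    seed = kwargs.get("seed", 0)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)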

    # init_param: optionally initialize the model from pretrained checkpoint(s)
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        if isinstance(init_param, str):
            try:
                # allow a tuple/list literal passed as a string, e.g. "('enc.pt', 'dec.pt')"
                init_param = eval(init_param)
            except (NameError, SyntaxError):
                pass  # a plain checkpoint path; keep it as-is
        if isinstance(init_param, str):
            init_param = (init_param,)
        logging.info("init_param is not None: %s", init_param)
        for p in init_param:
            logging.info(f"Loading pretrained params from {p}")
            load_pretrained_model(
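                # The keyword names below are assumptions for this sketch; match them
                # to the actual signature of load_pretrained_model in the full script.
                model=model,
                path=p,
            )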

    # freeze_param: disable gradients for the listed sub-modules of the model
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        if isinstance(freeze_param, str):
            try:
                # allow a tuple/list literal passed as a string, e.g. "('encoder',)"
                freeze_param = eval(freeze_param)
            except (NameError, SyntaxError):
                pass  # a plain sub-module name; keep it as-is
        if isinstance(freeze_param, str):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

    # dataset_tr and batch_sampler are constructed earlier in the full script (elided here).
    dataloader_tr = torch.utils.data.DataLoader(
        dataset_tr,
        collate_fn=dataset_tr.collator,
        batch_sampler=batch_sampler,
        num_workers=kwargs.get("dataset_conf", {}).get("num_workers", 4),
        pin_memory=True,
    )
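    # A matching validation loader, sketched under the assumption that a dataset_val /
    # batch_sampler_val pair is built the same way as the training side (these names
    # do not appear in this excerpt).
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        collate_fn=dataset_val.collator,
        batch_sampler=batch_sampler_val,
        num_workers=kwargs.get("dataset_conf", {}).get("num_workers", 4),
        pin_memory=True,
    )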

    # Save the resolved config once, from the rank-0 process only
    # (use_ddp, use_fsdp and local_rank come from the distributed setup elided earlier).
    rank_zero = dist.get_rank() == 0 if (use_ddp or use_fsdp) else local_rank == 0
    if rank_zero:
        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        logging.info("config.yaml is saved to: %s", yaml_file)

    trainer = Trainer(
        model=model,
        optim=optim,
        # remaining Trainer constructor arguments are elided in this excerpt
    )
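    # The hand-off to the training loop is not shown in this excerpt; whatever entry
    # point this Trainer exposes (its name is not confirmed here, e.g. something like
    # trainer.run()) would be invoked at this point.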

if __name__ == "__main__":
    main_hydra()