游雁
2023-12-13 806a03609df033d61f824f1ab8527eb88fe837ad
funasr/cli/train_cli.py
@@ -25,17 +25,26 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.utils.download_from_hub import download_model
def preprocess_config(cfg: DictConfig):
   """Normalize a config in place: turn the literal string 'None' into a real None.

   Only top-level keys are touched; nested configs are left as-is.
   """
   for name in list(cfg.keys()):
      if cfg[name] == 'None':
         cfg[name] = None
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
   """Hydra entry point: resolve the model (downloading from a hub if needed), then train.

   Args:
      kwargs: hydra-parsed configuration; must contain "model". A ":" in the
         model id marks a hub-qualified model that is fetched first via
         ``download_model`` (hub defaults to "ms" — TODO confirm hub ids).
   """
   # BUG FIX: removed two leftover `pdb.set_trace()` debug traps that would
   # halt every (including non-interactive/distributed) training launch.
   if ":" in kwargs["model"]:
      logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
      kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
   main(**kwargs)
@hydra.main()
def main(kwargs: DictConfig):
def main(**kwargs):
   # preprocess_config(kwargs)
   # import pdb; pdb.set_trace()
   # set random seed
@@ -46,7 +55,7 @@
   
   local_rank = int(os.environ.get('LOCAL_RANK', 0))
   # Check if we are using DDP or FSDP
   use_ddp = 'WORLD_SIZE' in os.environ
   use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
   use_fsdp = kwargs.get("use_fsdp", None)
   if use_ddp or use_fsdp:
      dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
@@ -76,10 +85,9 @@
   # init_param
   init_param = kwargs.get("init_param", None)
   if init_param is not None:
      init_param = eval(init_param)
      if isinstance(init_param, Sequence):
      if not isinstance(init_param, Sequence):
         init_param = (init_param,)
      logging.info("init_param is not None: ", init_param)
      logging.info("init_param is not None: %s", init_param)
      for p in init_param:
         logging.info(f"Loading pretrained params from {p}")
         load_pretrained_model(
@@ -99,7 +107,7 @@
      freeze_param = eval(freeze_param)
      if isinstance(freeze_param, Sequence):
         freeze_param = (freeze_param,)
      logging.info("freeze_param is not None: ", freeze_param)
      logging.info("freeze_param is not None: %s", freeze_param)
      for t in freeze_param:
         for k, p in model.named_parameters():
            if k.startswith(t + ".") or k == t:
@@ -109,7 +117,8 @@
   if use_ddp:
      model = model.cuda(local_rank)
      model = DDP(model, device_ids=[local_rank])
      model = DDP(model, device_ids=[local_rank],
                  find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
   elif use_fsdp:
      model = FSDP(model).cuda(local_rank)
   else:
@@ -137,9 +146,15 @@
   dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                               collate_fn=dataset_tr.collator,
                                               batch_sampler=batch_sampler,
                                               num_workers=kwargs.get("num_workers", 0),
                                               num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                               pin_memory=True)
   if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
      os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
      yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
      OmegaConf.save(config=kwargs, f=yaml_file)
      logging.info("config.yaml is saved to: %s", yaml_file)
   trainer = Trainer(
       model=model,
       optim=optim,
@@ -157,13 +172,6 @@
      torch.distributed.destroy_process_group()
   
def train(epoch, model, op):
   """Placeholder for a per-epoch training step; currently a no-op.

   Args:
      epoch: epoch index (unused).
      model: model to train (unused).
      op: optimizer (unused).
   """
   return None
def val():
   """Placeholder for the validation pass; currently a no-op."""
   return None
if __name__ == "__main__":
   main()
   main_hydra()