zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/bin/train.py
@@ -34,16 +34,18 @@
from funasr.utils.misc import prepare_model_dir
from funasr import AutoModel
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point for training.

    Optionally drops into ``pdb`` when ``debug=true`` is passed on the
    command line, downloads the model (and its config) from the model hub
    when no local ``model_conf`` is present, then hands the fully resolved
    configuration to ``main``.

    Args:
        kwargs: Hydra-populated configuration. Keys read here: ``debug``
            (bool, default False), ``model`` (required), ``model_conf``,
            ``hub`` (default ``"ms"``) and ``is_training`` (default True).
    """
    if kwargs.get("debug", False):
        # Interactive debugging hook, enabled via `debug=true` on the CLI.
        # NOTE: the diff residue had this duplicated (one-liner plus the
        # two-line form), which would stop the debugger twice; keep one.
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        # No local model config: fetch the model and its config from the hub
        # ("ms" = ModelScope by default) and merge the result into kwargs.
        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)
@@ -58,15 +60,15 @@
    # open tf32
    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
    
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()
    # Check if we are using DDP or FSDP
    use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", False)
    # use_ddp = False if use_fsdp else use_fsdp
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
        torch.cuda.set_device(local_rank)
    logging.info("Build model, frontend, tokenizer")
@@ -74,9 +76,13 @@
    kwargs["device"] = "cpu"
    model = AutoModel(**kwargs)
    
    # save config.yaml
    if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
    if (
        (use_ddp or use_fsdp)
        and dist.get_rank() == 0
        or not (use_ddp or use_fsdp)
        and local_rank == 0
    ):
        prepare_model_dir(**kwargs)
    
    # parse kwargs
@@ -101,11 +107,15 @@
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False
    
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=kwargs.get("train_conf", {}).get(
                "find_unused_parameters", False
            ),
        )
    elif use_fsdp:
        # model = FSDP(model).cuda(local_rank)
@@ -124,10 +134,12 @@
        # Configure a custom `min_num_params`
        my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
        torch.cuda.set_device(local_rank)
        model = FSDP(model,
        model = FSDP(
            model,
                     auto_wrap_policy=custom_auto_wrap_policy,
                     mixed_precision=None,
                     device_id=torch.cuda.current_device())
            device_id=torch.cuda.current_device(),
        )
    else:
        model = model.to(device=kwargs.get("device", "cuda"))
@@ -149,13 +161,15 @@
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
    # dataset
    logging.info("Build dataloader")
    dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
    dataloader_class = tables.dataloader_classes.get(
        kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")
    )
    dataloader = dataloader_class(**kwargs)
    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
    trainer = Trainer(local_rank=local_rank,
    trainer = Trainer(
        local_rank=local_rank,
                      use_ddp=use_ddp,
                      use_fsdp=use_fsdp,
                      device=kwargs["device"],
@@ -172,11 +186,12 @@
    os.makedirs(tensorboard_dir, exist_ok=True)
    try:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(tensorboard_dir) if trainer.rank == 0 else None
    except:
        writer = None
    dataloader_tr, dataloader_val = None, None
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        time1 = time.perf_counter()
        
@@ -196,13 +211,9 @@
                                )
        
        trainer.validate_epoch(
            model=model,
            dataloader_val=dataloader_val,
            epoch=epoch,
            writer=writer
            model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
        )
        scheduler.step()
        
        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
@@ -212,15 +223,13 @@
            f"rank: {local_rank}, "
            f"time_escaped_epoch: {time_escaped:.3f} hours, "
            f"estimated to finish {trainer.max_epoch} "
            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n")
            f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n"
        )
    if trainer.rank == 0:
        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
    trainer.close()
    
if __name__ == "__main__":