zhifu gao
2024-04-17 e8f80e96f99cb856423d030c7d055c302a6d3278
funasr/bin/train.py
@@ -55,6 +55,8 @@
     torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
     torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
     torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+    # enable tf32 (TensorFloat-32) matmul
+    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
 
     local_rank = int(os.environ.get('LOCAL_RANK', 0))
     if local_rank == 0:
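
Note: the backend flags above trade numerical precision for speed. Below is a minimal standalone sketch of the same idea; the `enable_tf32` / `cudnn_*` keys are this script's own kwargs, while the `torch.backends` switches are standard PyTorch:

```python
# Sketch: configuring cuDNN / TF32 backends from a config dict, as train.py does above.
import torch

def setup_backends(cfg: dict) -> None:
    # cuDNN autotuner and determinism switches (standard torch.backends.cudnn flags)
    torch.backends.cudnn.enabled = cfg.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = cfg.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = cfg.get("cudnn_deterministic", True)
    # TF32 runs matmuls with a reduced mantissa on Ampere+ GPUs: faster, slightly less precise
    torch.backends.cuda.matmul.allow_tf32 = cfg.get("enable_tf32", True)
    torch.backends.cudnn.allow_tf32 = cfg.get("enable_tf32", True)

if __name__ == "__main__":
    setup_backends({"cudnn_benchmark": True, "enable_tf32": True})
```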
@@ -102,7 +104,7 @@
     if use_ddp:
         model = model.cuda(local_rank)
         model = DDP(model, device_ids=[local_rank],
-                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
+                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True))
     elif use_fsdp:
         # model = FSDP(model).cuda(local_rank)
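
Note: `find_unused_parameters` controls whether DDP tolerates parameters that receive no gradient in a forward pass (needed for some multi-branch models, at the cost of an extra graph traversal per step). A sketch of the wrapping step, assuming a single-node `torchrun` launch that sets `LOCAL_RANK`; `wrap_ddp` is a placeholder helper, not a FunASR API:

```python
# Sketch: wrapping a model in DistributedDataParallel, mirroring the hunk above.
# Assumes launch via `torchrun --nproc_per_node=N`, which sets LOCAL_RANK and rendezvous env vars.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_ddp(model: torch.nn.Module, train_conf: dict) -> torch.nn.Module:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    model = model.cuda(local_rank)
    # find_unused_parameters=True lets DDP skip parameters that get no gradient,
    # avoiding hangs for models with unused branches, at some per-step overhead.
    return DDP(model,
               device_ids=[local_rank],
               find_unused_parameters=train_conf.get("find_unused_parameters", True))
```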
@@ -128,7 +130,8 @@
     else:
         model = model.to(device=kwargs.get("device", "cuda"))
-    logging.info(f"{model}")
+    if local_rank == 0:
+        logging.info(f"{model}")
     kwargs["device"] = next(model.parameters()).device
 
     # optim
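
Note: the change above prints the model summary once per node instead of once per process. A minimal sketch of the same rank-0 gating pattern (the logger setup is illustrative, not FunASR's):

```python
# Sketch: log only from the process with LOCAL_RANK == 0 to avoid duplicated output under DDP.
import logging
import os

def log_on_rank0(message: str) -> None:
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        logging.info(message)

logging.basicConfig(level=logging.INFO)
log_on_rank0("model summary would be printed here")
```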
@@ -149,8 +152,8 @@
     # dataset
     logging.info("Build dataloader")
     dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
-    # dataloader = dataloader_class(**kwargs)
-    dataloader_tr, dataloader_val = dataloader_class(**kwargs)
+    dataloader = dataloader_class(**kwargs)
+    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
     trainer = Trainer(local_rank=local_rank,
                       use_ddp=use_ddp,
                       use_fsdp=use_fsdp,
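
Note: the refactor keeps only the lightweight dataloader object here and defers building the actual train/validation iterators to `build_iter(epoch)` inside the epoch loop (see the loop hunk below). A sketch of that registry-plus-deferred-construction pattern; `DATALOADER_CLASSES` and `ToyMapStyleDataloader` are illustrative stand-ins for FunASR's `tables.dataloader_classes` registry and its DataloaderMapStyle, not the real classes:

```python
# Sketch: a name -> class registry plus a dataloader that builds fresh iterators per epoch.
import torch
from torch.utils.data import DataLoader, TensorDataset

class ToyMapStyleDataloader:
    def __init__(self, **kwargs):
        self.batch_size = kwargs.get("dataset_conf", {}).get("batch_size", 8)
        self.train_data = TensorDataset(torch.randn(64, 10))
        self.val_data = TensorDataset(torch.randn(16, 10))

    def build_iter(self, epoch: int):
        # Re-seeding per epoch keeps shuffling reproducible across resumed runs.
        gen = torch.Generator().manual_seed(epoch)
        tr = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, generator=gen)
        val = DataLoader(self.val_data, batch_size=self.batch_size)
        return tr, val

DATALOADER_CLASSES = {"DataloaderMapStyle": ToyMapStyleDataloader}

kwargs = {"dataset_conf": {"dataloader": "DataloaderMapStyle", "batch_size": 8}}
dataloader_class = DATALOADER_CLASSES[kwargs["dataset_conf"]["dataloader"]]
dataloader = dataloader_class(**kwargs)            # cheap: no iterators built yet
dataloader_tr, dataloader_val = dataloader.build_iter(epoch=0)
```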
@@ -175,12 +178,12 @@
     # if use_ddp or use_fsdp:
     #     context = Join([model])
     # else:
     #     context = nullcontext()
     context = nullcontext()
     for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
         time1 = time.perf_counter()
         with context:
-            # dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
+            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
             trainer.train_epoch(
                                 model=model,
                                 optim=optim,
@@ -191,13 +194,14 @@
                                 epoch=epoch,
                                 writer=writer
                                 )
+        with context:
+            trainer.validate_epoch(
+                model=model,
+                dataloader_val=dataloader_val,
+                epoch=epoch,
+                writer=writer
+            )
         scheduler.step()
-        trainer.validate_epoch(
-            model=model,
-            dataloader_val=dataloader_val,
-            epoch=epoch,
-            writer=writer
-        )
 
         trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
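
Note: taken together, the loop now rebuilds the iterators each epoch and runs both the training and validation passes under the same `context` (currently `nullcontext()`; the commented-out alternative is `torch.distributed.algorithms.join.Join`, which handles ranks that run out of batches at different times). A schematic, single-process sketch of the control flow; the trainer, optimizer, scheduler, scaler, and writer objects and the keyword names passed to them are placeholders, not FunASR's exact signatures:

```python
# Sketch of the per-epoch control flow after this change; all objects are placeholders.
from contextlib import nullcontext
import time

def run_training(trainer, dataloader, model, optim, scheduler, scaler, writer):
    context = nullcontext()  # with DDP/FSDP and uneven data, Join([model]) could replace this
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        time1 = time.perf_counter()
        with context:
            # fresh iterators every epoch (epoch-dependent shuffling / sampling)
            dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
            trainer.train_epoch(model=model, optim=optim, scheduler=scheduler, scaler=scaler,
                                dataloader_train=dataloader_tr, epoch=epoch, writer=writer)
        with context:
            trainer.validate_epoch(model=model, dataloader_val=dataloader_val,
                                   epoch=epoch, writer=writer)
        scheduler.step()  # one LR step per epoch, after validation
        trainer.save_checkpoint(epoch, model=model, optim=optim,
                                scheduler=scheduler, scaler=scaler)
        print(f"epoch {epoch} took {time.perf_counter() - time1:.1f}s")
```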
@@ -212,7 +216,7 @@
     if trainer.rank == 0:
-        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list)
+        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
     trainer.close()
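
Note: checkpoint averaging at the end of training combines the weights of the n best (or last n) checkpoints into a single model, which typically smooths out per-epoch noise. A generic sketch of weight averaging, not FunASR's `average_checkpoints` implementation; the checkpoint file layout and paths are assumptions:

```python
# Sketch: average the parameter tensors of several saved checkpoints into one model.
# Generic illustration only; FunASR's average_checkpoints has its own file layout,
# n-best selection, and naming, which are not reproduced here.
import torch

def average_state_dicts(ckpt_paths, out_path):
    avg = None
    for path in ckpt_paths:
        state = torch.load(path, map_location="cpu")  # assumed to be a plain state_dict
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    for k in avg:
        avg[k] /= len(ckpt_paths)
    torch.save(avg, out_path)

# e.g. average the n best epochs (paths are hypothetical):
# average_state_dicts(["exp/model.ep3.pt", "exp/model.ep7.pt", "exp/model.ep9.pt"], "exp/model.avg.pt")
```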