zhifu gao
2024-04-25 fc68b5ffe453235294a561737d8e84bb6c1689a4
funasr/bin/train.py
@@ -99,7 +99,7 @@
    if freeze_param is not None:
        if "," in freeze_param:
            freeze_param = eval(freeze_param)
        if isinstance(freeze_param, Sequence):
        if not isinstance(freeze_param, Sequence):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
@@ -107,8 +107,9 @@
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False
    if local_rank == 0:
        logging.info(f"{model_summary(model)}")
    logging.info(f"model info: {model_summary(model)}")
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
@@ -145,8 +146,6 @@
    else:
        model = model.to(device=kwargs.get("device", "cuda"))
    if local_rank == 0:
        logging.info(f"{model}")
    kwargs["device"] = next(model.parameters()).device
    # optim
@@ -182,7 +181,12 @@
    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
    trainer.resume_checkpoint(model=model, optim=optim, scheduler=scheduler, scaler=scaler)
    trainer.resume_checkpoint(
        model=model,
        optim=optim,
        scheduler=scheduler,
        scaler=scaler,
    )
    tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
    os.makedirs(tensorboard_dir, exist_ok=True)
@@ -197,8 +201,11 @@
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        time1 = time.perf_counter()
        for data_split_i in range(dataloader.data_split_num):
            dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i)
        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
            dataloader_tr, dataloader_val = dataloader.build_iter(
                epoch, data_split_i=data_split_i, start_step=trainer.start_step
            )
            trainer.start_step = 0
            trainer.train_epoch(
                model=model,
                optim=optim,
@@ -211,9 +218,8 @@
                data_split_i=data_split_i,
                data_split_num=dataloader.data_split_num,
            )
            torch.cuda.empty_cache()
            torch.cuda.empty_cache()
        trainer.validate_epoch(
            model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer