游雁
2024-03-24 ed22e34d654c47017962d3e5758d3a351d8826ab
funasr/bin/train.py
@@ -128,7 +128,8 @@
    else:
        model = model.to(device=kwargs.get("device", "cuda"))
    logging.info(f"{model}")
    if local_rank == 0:
        logging.info(f"{model}")
    kwargs["device"] = next(model.parameters()).device
        
    # optim
@@ -175,8 +176,8 @@
    # if use_ddp or use_fsdp:
    #     context = Join([model])
    # else:
    #     context = nullcontext()
    context = nullcontext()
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
        time1 = time.perf_counter()
        with context:
@@ -191,13 +192,14 @@
                                epoch=epoch,
                                writer=writer
                                )
        with context:
            trainer.validate_epoch(
                model=model,
                dataloader_val=dataloader_val,
                epoch=epoch,
                writer=writer
            )
        scheduler.step()
        trainer.validate_epoch(
            model=model,
            dataloader_val=dataloader_val,
            epoch=epoch,
            writer=writer
        )
        
        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)