zhifu gao
2024-04-28 b7ae3d52681ef4f5611b059762788af7d6a37190
funasr/bin/train.py
@@ -223,11 +223,13 @@
            torch.cuda.empty_cache()
        trainer.validate_epoch(
            model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
            model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
        )
        scheduler.step()
        trainer.step_cur_in_epoch = 0
        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
        trainer.step_in_epoch = 0
        trainer.save_checkpoint(
            epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
        )
        time2 = time.perf_counter()
        time_escaped = (time2 - time1) / 3600.0