zhifu gao
2024-05-14 2f27b165559cd53afab52047309ebe4ac838ebb8
funasr/bin/train.py
@@ -198,14 +198,13 @@
        writer = None
    dataloader_tr, dataloader_val = None, None
    for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
    for epoch in range(trainer.start_epoch, trainer.max_epoch):
        time1 = time.perf_counter()
        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
            dataloader_tr, dataloader_val = dataloader.build_iter(
                epoch, data_split_i=data_split_i, start_step=trainer.start_step
            )
            trainer.start_step = 0
            trainer.train_epoch(
                model=model,
@@ -218,16 +217,20 @@
                writer=writer,
                data_split_i=data_split_i,
                data_split_num=dataloader.data_split_num,
                start_step=trainer.start_step,
            )
            trainer.start_step = 0
            torch.cuda.empty_cache()
        trainer.validate_epoch(
            model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
            model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
        )
        scheduler.step()
        trainer.step_cur_in_epoch = 0
        trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
        trainer.step_in_epoch = 0
        trainer.save_checkpoint(
            epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
        )
        time2 = time.perf_counter()
        time_escaped = (time2 - time1) / 3600.0