funasr/bin/train.py
@@ -198,7 +198,7 @@ writer = None dataloader_tr, dataloader_val = None, None for epoch in range(trainer.start_epoch, trainer.max_epoch + 1): for epoch in range(trainer.start_epoch, trainer.max_epoch): time1 = time.perf_counter() for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num): @@ -223,6 +223,7 @@ torch.cuda.empty_cache() trainer.start_data_split_i = 0 trainer.validate_epoch( model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer )