funasr/bin/train.py
@@ -205,7 +205,6 @@ dataloader_tr, dataloader_val = dataloader.build_iter( epoch, data_split_i=data_split_i, start_step=trainer.start_step ) trainer.start_step = 0 trainer.train_epoch( model=model, @@ -218,7 +217,9 @@ writer=writer, data_split_i=data_split_i, data_split_num=dataloader.data_split_num, start_step=trainer.start_step, ) trainer.start_step = 0 torch.cuda.empty_cache()