funasr/train_utils/trainer.py
@@ -401,4 +401,6 @@ epoch * len(self.dataloader_val) + batch_idx) for key, var in speed_stats.items(): self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var), epoch * len(self.dataloader_val) + batch_idx) epoch * len(self.dataloader_val) + batch_idx) self.model.train()