funasr/train_utils/trainer_ds.py
@@ -167,6 +167,8 @@ Args: epoch (int): The epoch number at which the checkpoint is being saved. """ if self.use_ddp or self.use_fsdp: dist.barrier() step_in_epoch = None if step is None else step_in_epoch if self.use_deepspeed: @@ -760,6 +762,10 @@ ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}' self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg if self.use_ddp or self.use_fsdp or self.use_deepspeed: dist.barrier() model.train() def log(