游雁
2024-06-09 1c8b46a233ac4a782d7170e20533f536761e25c4
funasr/train_utils/trainer_ds.py
@@ -167,6 +167,8 @@
        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        step_in_epoch = None if step is None else step_in_epoch
        if self.use_deepspeed:
@@ -760,6 +762,10 @@
            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
        self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
        self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
            dist.barrier()
        model.train()
    def log(