游雁
2025-01-10 d4f13c2e444f972b272273bce76b05f52f5164aa
funasr/train_utils/trainer.py
@@ -161,8 +161,8 @@
            # self.step_or_epoch += 1
            state = {
                "epoch": epoch,
                'step': step,
                'total_step': self.batch_total,
                "step": step,
                "total_step": self.batch_total,
                "state_dict": model.state_dict(),
                "optimizer": optim.state_dict(),
                "scheduler": scheduler.state_dict(),
@@ -171,7 +171,6 @@
                "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
                "best_step_or_epoch": self.best_step_or_epoch,
                "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
                "step": step,
                "step_in_epoch": step_in_epoch,
                "data_split_i": kwargs.get("data_split_i", 0),
                "data_split_num": kwargs.get("data_split_num", 1),
@@ -194,9 +193,9 @@
                ckpt_name = f"model.pt.ep{epoch}.{step}"
            filename = os.path.join(self.output_dir, ckpt_name)
            torch.save(state, filename)
            logging.info(f'Checkpoint saved to {filename}')
            logging.info(f"Checkpoint saved to {filename}")
            latest = Path(os.path.join(self.output_dir, f'model.pt'))
            latest = Path(os.path.join(self.output_dir, f"model.pt"))
            torch.save(state, latest)
            if self.best_step_or_epoch == "":
@@ -332,7 +331,6 @@
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
    def train_epoch(
        self,
@@ -591,9 +589,9 @@
                time4 = time.perf_counter()
                if torch.isfinite(loss):
                    self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
                        batch_idx + 1
                    )
                    self.val_loss_avg = (
                        self.val_loss_avg * batch_idx + loss.detach().cpu().item()
                    ) / (batch_idx + 1)
                    if "acc" in stats:
                        self.val_acc_avg = (