游雁
2024-03-18 cbe2ea7e07cbf364827bd89cefc42b3f643ea3be
funasr/train_utils/trainer.py
@@ -88,6 +88,7 @@
         scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
         scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler
         self.scaler = scaler
+        self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
         
     
         try:
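Note: the new interval is read from kwargs with a 5000-step default, so existing callers keep working unchanged. A minimal, self-contained sketch of that fallback pattern (make_trainer_config is a hypothetical helper, not part of this diff):

# Sketch of the kwargs pattern used above: an optional knob with a default.
def make_trainer_config(**kwargs):
    # Falls back to 5000 steps when the caller does not set the interval.
    return {"save_checkpoint_interval": kwargs.get("save_checkpoint_interval", 5000)}

assert make_trainer_config()["save_checkpoint_interval"] == 5000
assert make_trainer_config(save_checkpoint_interval=2000)["save_checkpoint_interval"] == 2000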
@@ -104,7 +105,7 @@
         self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
         
     
-    def _save_checkpoint(self, epoch):
+    def _save_checkpoint(self, epoch, step=None):
         """
         Saves a checkpoint containing the model's state, the optimizer's state,
         and the scheduler's state at the end of the given epoch. This method is
@@ -123,7 +124,11 @@
             state["scaler_state"] = self.scaler.state_dict()
         # Create output directory if it does not exist
         os.makedirs(self.output_dir, exist_ok=True)
-        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+        if step is None:
+            filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+        else:
+            filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
         torch.save(state, filename)
         
         print(f'\nCheckpoint saved to {filename}\n')
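Note: with the optional step argument, end-of-epoch saves keep the old name model.pt.ep{epoch}, while mid-epoch saves get model.pt.ep{epoch}.{step}. The naming rule in isolation, as a runnable sketch (paths are illustrative):

import os

def checkpoint_name(output_dir, epoch, step=None):
    # Mirrors the rule in _save_checkpoint: epoch-level files are
    # "model.pt.ep{epoch}"; mid-epoch saves append the step index.
    if step is None:
        return os.path.join(output_dir, f"model.pt.ep{epoch}")
    return os.path.join(output_dir, f"model.pt.ep{epoch}.{step}")

print(checkpoint_name("exp", 3))        # exp/model.pt.ep3
print(checkpoint_name("exp", 3, 5000))  # exp/model.pt.ep3.5000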
@@ -141,7 +146,7 @@
         """
         ckpt = os.path.join(resume_path, "model.pt")
         if os.path.isfile(ckpt):
-            checkpoint = torch.load(ckpt)
+            checkpoint = torch.load(ckpt, map_location="cpu")
             self.start_epoch = checkpoint['epoch'] + 1
             # self.model.load_state_dict(checkpoint['state_dict'])
             src_state = checkpoint['state_dict']
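Note: by default torch.load restores tensors to the device they were serialized from (e.g. cuda:0), which can collide or run out of memory when several ranks resume from the same file; map_location="cpu" materializes everything on host memory first, and the model is moved to its target device afterwards (the next hunk adds that move). A minimal standalone sketch of the same pattern, with an illustrative file and module:

import torch

# Sketch of the resume pattern: deserialize on CPU, then place explicitly.
# "model.pt" and the Linear module below are illustrative, not from this diff.
model = torch.nn.Linear(4, 2)
torch.save({"state_dict": model.state_dict()}, "model.pt")

checkpoint = torch.load("model.pt", map_location="cpu")  # tensors stay on CPU
model.load_state_dict(checkpoint["state_dict"])
model.to("cuda" if torch.cuda.is_available() else "cpu")  # place after loading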
@@ -164,7 +169,8 @@
             print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
             print(f"No checkpoint found at '{ckpt}', does not resume status!")
+        self.model.to(self.device)
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
         
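Note: under DDP/FSDP the barrier keeps any rank from entering training before every rank has finished restoring its state and placing the model. The diff gates on self.use_ddp / self.use_fsdp; the sketch below uses torch.distributed's own initialization check instead, assuming the process group is already set up:

import torch.distributed as dist

# Sketch: after each rank has loaded its checkpoint and moved the model
# to its device, synchronize so no rank races ahead into training.
if dist.is_available() and dist.is_initialized():
    dist.barrier()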
@@ -337,8 +343,10 @@
                     for key, var in speed_stats.items():
                         self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var), self.batch_total)
+            if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
+                self._save_checkpoint(epoch, step=batch_idx+1)
         pbar.close()
     def _validate_epoch(self, epoch):
         """