雾聪
2024-03-14 0cf5dfec2c8313fc2ed2aab8d10bf3dc4b9c283f
funasr/train_utils/trainer.py
@@ -163,7 +163,7 @@
                self.scaler.load_state_dict(checkpoint['scaler_state'])
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{ckpt}', starting from scratch")
            print(f"No checkpoint found at '{ckpt}', does not resume status!")
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
@@ -401,4 +401,6 @@
                                                   epoch * len(self.dataloader_val) + batch_idx)
                        for key, var in speed_stats.items():
                            self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
                                                   epoch * len(self.dataloader_val) + batch_idx)
                                                   epoch * len(self.dataloader_val) + batch_idx)
        self.model.train()