游雁
2024-03-18 d3d2fe73c08ee51d3a44d7ffb7b31eff32b60404
funasr/train_utils/trainer.py
@@ -146,7 +146,7 @@
        """
        ckpt = os.path.join(resume_path, "model.pt")
        if os.path.isfile(ckpt):
            checkpoint = torch.load(ckpt)
            checkpoint = torch.load(ckpt, map_location="cpu")
            self.start_epoch = checkpoint['epoch'] + 1
            # self.model.load_state_dict(checkpoint['state_dict'])
            src_state = checkpoint['state_dict']
@@ -169,7 +169,8 @@
            print(f"Checkpoint loaded successfully from '{ckpt}'")
        else:
            print(f"No checkpoint found at '{ckpt}', does not resume status!")
        self.model.to(self.device)
        if self.use_ddp or self.use_fsdp:
            dist.barrier()