funasr/train_utils/trainer.py
@@ -146,7 +146,7 @@ """ ckpt = os.path.join(resume_path, "model.pt") if os.path.isfile(ckpt): checkpoint = torch.load(ckpt) checkpoint = torch.load(ckpt, map_location="cpu") self.start_epoch = checkpoint['epoch'] + 1 # self.model.load_state_dict(checkpoint['state_dict']) src_state = checkpoint['state_dict'] @@ -169,7 +169,8 @@ print(f"Checkpoint loaded successfully from '{ckpt}'") else: print(f"No checkpoint found at '{ckpt}', does not resume status!") self.model.to(self.device) if self.use_ddp or self.use_fsdp: dist.barrier()