aky15
2023-05-24 2f9685797b0c8a420574c2a459c242f90efdf3ee
Support resuming model from PAI (#544)

Co-authored-by: aky15 <ankeyu.aky@11.17.44.249>
1个文件已修改
17 ■■■■ 已修改文件
funasr/train/trainer.py 17 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train/trainer.py
@@ -143,11 +143,23 @@
        schedulers: Sequence[Optional[AbsScheduler]],
        scaler: Optional[GradScaler],
        ngpu: int = 0,
        oss_bucket=None,
    ):
        if oss_bucket is None:
            if os.path.exists(checkpoint):
        states = torch.load(
            checkpoint,
            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
        )
            else:
                return 0
        else:
            if oss_bucket.object_exists(checkpoint):
                buffer = BytesIO(oss_bucket.get_object(checkpoint).read())
                states = torch.load(buffer, map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",)
            else:
                return 0
        model.load_state_dict(states["model"])
        reporter.load_state_dict(states["reporter"])
        for optimizer, state in zip(optimizers, states["optimizers"]):
@@ -206,15 +218,16 @@
        else:
            scaler = None
        if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
        if trainer_options.resume:
            cls.resume(
                checkpoint=output_dir / "checkpoint.pb",
                checkpoint=os.path.join(trainer_options.output_dir, "checkpoint.pb") if trainer_options.use_pai else output_dir / "checkpoint.pb",
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                reporter=reporter,
                scaler=scaler,
                ngpu=trainer_options.ngpu,
                oss_bucket=trainer_options.oss_bucket if trainer_options.use_pai else None,
            )
        start_epoch = reporter.get_epoch() + 1