嘉渊
2023-04-23 ccd4c4d240af6414c86af606e6ad9a01ac52e991
funasr/build_utils/build_trainer.py
@@ -128,16 +128,15 @@
         """Reserved for future development of another Trainer"""
         pass
 
-    @staticmethod
-    def resume(
-            checkpoint: Union[str, Path],
-            model: torch.nn.Module,
-            reporter: Reporter,
-            optimizers: Sequence[torch.optim.Optimizer],
-            schedulers: Sequence[Optional[AbsScheduler]],
-            scaler: Optional[GradScaler],
-            ngpu: int = 0,
-    ):
+    def resume(self,
+               checkpoint: Union[str, Path],
+               model: torch.nn.Module,
+               reporter: Reporter,
+               optimizers: Sequence[torch.optim.Optimizer],
+               schedulers: Sequence[Optional[AbsScheduler]],
+               scaler: Optional[GradScaler],
+               ngpu: int = 0,
+               ):
         states = torch.load(
             checkpoint,
             map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
@@ -800,3 +799,26 @@
             if distributed:
                 iterator_stop.fill_(1)
                 torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+
+
+def build_trainer(
+        args,
+        model: FunASRModel,
+        optimizers: Sequence[torch.optim.Optimizer],
+        schedulers: Sequence[Optional[AbsScheduler]],
+        train_dataloader: AbsIterFactory,
+        valid_dataloader: AbsIterFactory,
+        trainer_options,
+        distributed_option: DistributedOption
+):
+    trainer = Trainer(
+        args=args,
+        model=model,
+        optimizers=optimizers,
+        schedulers=schedulers,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        trainer_options=trainer_options,
+        distributed_option=distributed_option
+    )
+    return trainer
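
For context, a hedged sketch of how the new build_trainer factory might be
wired into a training entry point. The surrounding objects (args, the two
AbsIterFactory iterators, trainer_options, distributed_option) are
assumptions about the caller's setup code, not part of this commit:

    # Hypothetical call site; the factory simply forwards everything to Trainer.
    optimizers = [torch.optim.Adam(model.parameters(), lr=1e-3)]
    schedulers = [None]  # one (possibly None) scheduler per optimizer
    trainer = build_trainer(
        args=args,
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_iter_factory,  # AbsIterFactory (assumed name)
        valid_dataloader=valid_iter_factory,  # AbsIterFactory (assumed name)
        trainer_options=trainer_options,
        distributed_option=distributed_option,
    )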