嘉渊
2023-04-24 df5f263e5fe3d7961b1aeb3589012400a9905a8f
funasr/build_utils/build_trainer.py
@@ -107,7 +107,6 @@
                 schedulers: Sequence[Optional[AbsScheduler]],
                 train_dataloader: AbsIterFactory,
                 valid_dataloader: AbsIterFactory,
                 trainer_options,
                 distributed_option: DistributedOption):
        self.trainer_options = self.build_options(args)
        self.model = model
@@ -115,7 +114,6 @@
        self.schedulers = schedulers
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.trainer_options = trainer_options
        self.distributed_option = distributed_option
    def build_options(self, args: argparse.Namespace) -> TrainerOptions:
@@ -128,16 +126,15 @@
        """Reserved for future development of another Trainer"""
        pass
    @staticmethod
    def resume(
            checkpoint: Union[str, Path],
            model: torch.nn.Module,
            reporter: Reporter,
            optimizers: Sequence[torch.optim.Optimizer],
            schedulers: Sequence[Optional[AbsScheduler]],
            scaler: Optional[GradScaler],
            ngpu: int = 0,
    ):
    def resume(self,
               checkpoint: Union[str, Path],
               model: torch.nn.Module,
               reporter: Reporter,
               optimizers: Sequence[torch.optim.Optimizer],
               schedulers: Sequence[Optional[AbsScheduler]],
               scaler: Optional[GradScaler],
               ngpu: int = 0,
               ):
        states = torch.load(
            checkpoint,
            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
@@ -800,3 +797,24 @@
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
def build_trainer(
        args,
        model: FunASRModel,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        train_dataloader: AbsIterFactory,
        valid_dataloader: AbsIterFactory,
        distributed_option: DistributedOption
):
    """Create a Trainer wired with the model, optimizers, schedulers,
    data iterators, and distributed configuration.

    All arguments are forwarded to ``Trainer`` by keyword; the constructed
    instance is returned to the caller unchanged.
    """
    # Forward everything by keyword so the Trainer signature stays explicit.
    return Trainer(
        args=args,
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        distributed_option=distributed_option,
    )