嘉渊
2023-04-23 ccd4c4d240af6414c86af606e6ad9a01ac52e991
update
1个文件已修改
26 ■■■■■ 已修改文件
funasr/build_utils/build_trainer.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_trainer.py
@@ -128,8 +128,7 @@
        """Reserved for future development of another Trainer"""
        pass
    @staticmethod
    def resume(
    def resume(self,
            checkpoint: Union[str, Path],
            model: torch.nn.Module,
            reporter: Reporter,
@@ -800,3 +799,26 @@
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
def build_trainer(
        args,
        model: FunASRModel,
        optimizers: Sequence[torch.optim.Optimizer],
        schedulers: Sequence[Optional[AbsScheduler]],
        train_dataloader: AbsIterFactory,
        valid_dataloader: AbsIterFactory,
        trainer_options,
        distributed_option: DistributedOption
):
    """Assemble a Trainer from its collaborators.

    Thin factory: every argument is forwarded unchanged to the Trainer
    constructor, and the constructed instance is returned.

    Args:
        args: parsed command-line / config namespace.
        model: the FunASR model to train.
        optimizers: one optimizer per parameter group configuration.
        schedulers: matching (possibly absent) LR schedulers.
        train_dataloader: iterator factory for training batches.
        valid_dataloader: iterator factory for validation batches.
        trainer_options: trainer-specific option bundle.
        distributed_option: distributed-training configuration.

    Returns:
        A fully wired Trainer instance.
    """
    # No logic of its own — just forward everything by keyword so the
    # call site stays robust to Trainer parameter reordering.
    return Trainer(
        args=args,
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        trainer_options=trainer_options,
        distributed_option=distributed_option,
    )