        """Reserved for future development of another Trainer"""
        pass

    @staticmethod
    def resume(
        checkpoint: Union[str, Path],
        model: torch.nn.Module,
        reporter: Reporter,

        if distributed:
            # Raise this rank's stop flag and sum it across all workers so that
            # every process leaves the training loop on the same iteration.
            iterator_stop.fill_(1)
            torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)

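# Illustrative sketch (not part of the original file): the two-sided pattern the
# stop flag above belongs to. Every rank reduces `iterator_stop` once per step,
# so the rank whose dataloader runs out first (and raises the flag) makes the
# remaining ranks break on their next step. `train_step` and the function name
# below are hypothetical, for illustration only.
def _epoch_loop_sketch(iterator, iterator_stop, distributed, train_step):
    for batch in iterator:
        if distributed:
            # This all_reduce pairs with the one issued by whichever rank
            # exhausted its iterator; a non-zero sum means "stop now".
            torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
            if iterator_stop > 0:
                break
        train_step(batch)
    else:
        # Reached only when the loop was not broken: this rank ran out of data,
        # so raise the flag and join the other ranks' pending all_reduce.
        if distributed:
            iterator_stop.fill_(1)
            torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
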

def build_trainer(
    args,
    model: FunASRModel,
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    train_dataloader: AbsIterFactory,
    valid_dataloader: AbsIterFactory,
    trainer_options,
    distributed_option: DistributedOption,
):
    """Construct a Trainer from the prepared model, optimizers, schedulers,
    dataloaders and options, and return it to the caller."""
    trainer = Trainer(
        args=args,
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        trainer_options=trainer_options,
        distributed_option=distributed_option,
    )
    return trainer
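
# Usage sketch (assumption, not taken from the original file): build_trainer is
# typically called after the model, optimizers, schedulers and iterator factories
# have been constructed, and the returned Trainer then drives the training loop.
# `run_training` and `trainer.run()` are hypothetical names for illustration.
def run_training(args, model, optimizers, schedulers,
                 train_iter_factory, valid_iter_factory,
                 trainer_options, distributed_option):
    trainer = build_trainer(
        args=args,
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_iter_factory,
        valid_dataloader=valid_iter_factory,
        trainer_options=trainer_options,
        distributed_option=distributed_option,
    )
    # Assumption: the Trainer exposes a run() entry point that iterates over epochs.
    trainer.run()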