| | |
| | | from funasr.build_utils.build_model import build_model |
| | | from funasr.build_utils.build_optimizer import build_optimizer |
| | | from funasr.build_utils.build_scheduler import build_scheduler |
| | | from funasr.build_utils.build_trainer import build_trainer |
| | | from funasr.text.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.model_summary import model_summary |
| | | from funasr.torch_utils.pytorch_version import pytorch_cudnn_version |
| | |
| | | else: |
| | | yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False) |
| | | |
| | | # dataloader for training/validation |
| | | train_dataloader, valid_dataloader = build_dataloader(args) |
| | | |
| | | # Trainer, including model, optimizers, etc. |
| | | trainer = build_trainer( |
| | | args=args, |
| | | model=model, |
| | | optimizers=optimizers, |
| | | schedulers=schedulers, |
| | | train_dataloader=train_dataloader, |
| | | valid_dataloader=valid_dataloader, |
| | | distributed_option=distributed_option |
| | | ) |
| | | |
| | | trainer.run() |
| | |
| | | schedulers: Sequence[Optional[AbsScheduler]], |
| | | train_dataloader: AbsIterFactory, |
| | | valid_dataloader: AbsIterFactory, |
| | | trainer_options, |
| | | distributed_option: DistributedOption): |
| | | self.trainer_options = self.build_options(args) |
| | | self.model = model |
| | |
| | | self.schedulers = schedulers |
| | | self.train_dataloader = train_dataloader |
| | | self.valid_dataloader = valid_dataloader |
| | | self.trainer_options = trainer_options |
| | | self.distributed_option = distributed_option |
| | | |
| | | def build_options(self, args: argparse.Namespace) -> TrainerOptions: |
| | |
| | | schedulers: Sequence[Optional[AbsScheduler]], |
| | | train_dataloader: AbsIterFactory, |
| | | valid_dataloader: AbsIterFactory, |
| | | trainer_options, |
| | | distributed_option: DistributedOption |
| | | ): |
| | | trainer = Trainer( |
| | |
| | | schedulers=schedulers, |
| | | train_dataloader=train_dataloader, |
| | | valid_dataloader=valid_dataloader, |
| | | trainer_options=trainer_options, |
| | | distributed_option=distributed_option |
| | | ) |
| | | return trainer |