| | |
| | | import torch |
| | | import torch.nn |
| | | import torch.optim |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.iterators.abs_iter_factory import AbsIterFactory |
| | | from funasr.main_funcs.average_nbest_models import average_nbest_models |
| | |
| | | |
def build_options(self, args: argparse.Namespace) -> TrainerOptions:
    """Build the options consumed by train() and eval().

    Args:
        args: Parsed command-line namespace; passed unchanged to
            ``build_dataclass`` together with ``TrainerOptions``.

    Returns:
        Whatever ``build_dataclass(TrainerOptions, args)`` produces
        (annotated as ``TrainerOptions``).
    """
    # Runtime argument type check (typeguard); skipped under ``python -O``.
    assert check_argument_types()
    trainer_options = build_dataclass(TrainerOptions, args)
    return trainer_options
| | | |
| | | @classmethod |
| | |
| | | |
| | | def run(self) -> None: |
| | | """Perform training. This method performs the main process of training.""" |
| | | assert check_argument_types() |
| | | # NOTE(kamo): Don't check the type more strictly, as far as trainer_options is concerned |
| | | model = self.model |
| | | optimizers = self.optimizers |
| | |
| | | options: TrainerOptions, |
| | | distributed_option: DistributedOption, |
| | | ) -> Tuple[bool, bool]: |
| | | assert check_argument_types() |
| | | |
| | | grad_noise = options.grad_noise |
| | | accum_grad = options.accum_grad |
| | |
| | | options: TrainerOptions, |
| | | distributed_option: DistributedOption, |
| | | ) -> None: |
| | | assert check_argument_types() |
| | | ngpu = options.ngpu |
| | | # no_forward_run = options.no_forward_run |
| | | distributed = distributed_option.distributed |