import argparse
import dataclasses
import logging
from distutils.version import LooseVersion
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union

import oss2
import torch
import torch.nn
import torch.optim
from torch.cuda.amp import GradScaler, autocast  # AMP utilities; require torch>=1.6.0
from typeguard import check_argument_types

from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.main_funcs.average_nbest_models import average_nbest_models


@dataclasses.dataclass
class TrainerOptions:
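    # Hyper-parameters consumed by run()/train_one_epoch()/validate_one_epoch();
    # they are filled from the argparse namespace via build_dataclass() in
    # Trainer.build_options() below.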
    grad_clip: float
    grad_clip_type: float
    log_interval: Optional[int]
    no_forward_run: bool
    use_tensorboard: bool
    use_wandb: bool
    output_dir: Union[Path, str]
    max_epoch: int
    max_update: int
    seed: int
    sharded_ddp: bool
    patience: Optional[int]
    keep_nbest_models: Union[int, List[int]]
    nbest_averaging_interval: int

    best_model_criterion: Sequence[Sequence[str]]
    val_scheduler_criterion: Sequence[str]
    unused_parameters: bool
    wandb_model_log_interval: int
    use_pai: bool
    oss_bucket: Union[oss2.Bucket, None]


class Trainer:
    def __init__(self,
                 model: torch.nn.Module,
                 optimizers: Sequence[torch.optim.Optimizer],
                 schedulers: Sequence[Optional[AbsScheduler]],
                 train_dataloader: AbsIterFactory,
                 valid_dataloader: AbsIterFactory,
                 trainer_options,
                 distributed_option: DistributedOption):
        self.model = model
        self.optimizers = optimizers
        self.schedulers = schedulers
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.trainer_options = trainer_options
        self.distributed_option = distributed_option

    def build_options(self, args: argparse.Namespace) -> TrainerOptions:
        """Build options consumed by train(), eval()"""
        assert check_argument_types()
        return build_dataclass(TrainerOptions, args)

    @classmethod

    def run(self) -> None:
        """Perform training. This method performs the main process of training."""
        assert check_argument_types()
        # NOTE(kamo): Don't check the type more strictly as far as trainer_options is concerned
        model = self.model
        optimizers = self.optimizers
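
        # Build a gradient scaler for automatic mixed precision (AMP): a fairscale
        # ShardedGradScaler when sharded DDP is enabled, otherwise torch's GradScaler,
        # and no scaler at all when AMP is off.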
        trainer_options = self.trainer_options

        if trainer_options.use_amp:
            if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
                raise RuntimeError(
                    "Require torch>=1.6.0 for Automatic Mixed Precision"
                )
            if trainer_options.sharded_ddp:
                if fairscale is None:
                    raise RuntimeError(
                        "Requiring fairscale. Do 'pip install fairscale'"
                    )
                scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
            else:
                scaler = GradScaler()
        else:
            scaler = None

                    )
                elif isinstance(scheduler, AbsEpochStepScheduler):
                    scheduler.step()
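            # fairscale's OSS optimizer shards its state across ranks, so the state
            # must be consolidated before it can be checkpointed as a whole.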
            if trainer_options.sharded_ddp:
                for optimizer in optimizers:
                    if isinstance(optimizer, fairscale.optim.oss.OSS):
                        optimizer.consolidate_state_dict()

            if not distributed_option.distributed or distributed_option.dist_rank == 0:
                # 3. Report the results

                if train_summary_writer is not None:
                    reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
                    reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
                if trainer_options.use_wandb:
                    reporter.wandb_log()

                # save tensorboard on oss
                if trainer_options.use_pai and train_summary_writer is not None:

                    "The best model has been updated: " + ", ".join(_improved)
                )

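                # Optionally log this epoch's checkpoint to Weights & Biases as a model
                # artifact, tagged "best" when this epoch is the current best.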
                log_model = (
                    trainer_options.wandb_model_log_interval > 0
                    and iepoch % trainer_options.wandb_model_log_interval == 0
                )
                if log_model and trainer_options.use_wandb:
                    import wandb

                    logging.info("Logging Model on this epoch :::::")
                    artifact = wandb.Artifact(
                        name=f"model_{wandb.run.id}",
                        type="model",
                        metadata={"improved": _improved},
                    )
                    artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
                    aliases = [
                        f"epoch-{iepoch}",
                        "best" if best_epoch == iepoch else "",
                    ]
                    wandb.log_artifact(artifact, aliases=aliases)

                # 6. Remove the model files excluding n-best epoch and latest epoch
                _removed = []

        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> Tuple[bool, bool]:
        assert check_argument_types()

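        # Pull the options used inside the batch loop into locals.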
        grad_noise = options.grad_noise
        accum_grad = options.accum_grad
        grad_clip = options.grad_clip
        grad_clip_type = options.grad_clip_type
        log_interval = options.log_interval
        no_forward_run = options.no_forward_run
        ngpu = options.ngpu
        use_wandb = options.use_wandb
        distributed = distributed_option.distributed

        if log_interval is None:

                break

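            # Move the batch to the GPU when available; with no_forward_run only the
            # data pipeline is exercised and the model itself is skipped.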
            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                all_steps_are_invalid = False
                continue
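            # Run the forward pass under autocast so it executes in reduced precision
            # whenever AMP is active (i.e. a GradScaler was created).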
            with autocast(scaler is not None):
                with reporter.measure_time("forward_time"):

                logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
                if summary_writer is not None:
                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
                if use_wandb:
                    reporter.wandb_log()

            if max_update_stop:
                break

        options: TrainerOptions,
        distributed_option: DistributedOption,
    ) -> None:
        assert check_argument_types()
        ngpu = options.ngpu
        no_forward_run = options.no_forward_run
        distributed = distributed_option.distributed

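        # Evaluation mode disables dropout and uses running batch-norm statistics;
        # validation only runs the forward pass, with no backward step.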
        model.eval()

                break

            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
            if no_forward_run:
                continue

            retval = model(**batch)
            if isinstance(retval, dict):


    schedulers: Sequence[Optional[AbsScheduler]],
    train_dataloader: AbsIterFactory,
    valid_dataloader: AbsIterFactory,
    trainer_options,
    distributed_option: DistributedOption
):
    trainer = Trainer(

        schedulers=schedulers,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        trainer_options=trainer_options,
        distributed_option=distributed_option,
    )
    return trainer
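

# Usage sketch (an assumption, not shown in this excerpt): with the model, optimizers,
# schedulers, dataloaders and DistributedOption built elsewhere, training boils down to
# constructing a Trainer and calling run():
#
#     trainer = Trainer(
#         model=model,
#         optimizers=optimizers,
#         schedulers=schedulers,
#         train_dataloader=train_dataloader,
#         valid_dataloader=valid_dataloader,
#         trainer_options=trainer_options,
#         distributed_option=distributed_option,
#     )
#     trainer.run()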