| | |
| | | import torch |
| | | import torch.nn |
| | | import torch.optim |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.iterators.abs_iter_factory import AbsIterFactory |
| | | from funasr.main_funcs.average_nbest_models import average_nbest_models |
| | |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.recursive_op import recursive_average |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.train.distributed_utils import DistributedOption |
| | | from funasr.train.reporter import Reporter |
| | | from funasr.train.reporter import SubReporter |
| | | from funasr.utils.build_dataclass import build_dataclass |
| | | from funasr.utils.kwargs2args import kwargs2args |
| | | |
| | | if torch.distributed.is_available(): |
| | | from torch.distributed import ReduceOp |
| | |
| | | use_pai: bool |
| | | oss_bucket: Union[oss2.Bucket, None] |
| | | batch_interval: int |
| | | bias_grad_times: float |
| | | |
| | | class Trainer: |
| | | """Trainer having an optimizer. |
| | |
@classmethod
def build_options(cls, args: argparse.Namespace) -> TrainerOptions:
    """Build options consumed by train(), eval().

    Args:
        args: Parsed argument namespace that is handed to
            ``build_dataclass`` together with ``TrainerOptions``.

    Returns:
        Whatever ``build_dataclass(TrainerOptions, args)`` produces
        (annotated as ``TrainerOptions``).
    """
    # typeguard check of the call arguments; it must stay inside an
    # ``assert`` so that it is skipped when Python runs with -O.
    assert check_argument_types()
    trainer_options = build_dataclass(TrainerOptions, args)
    return trainer_options
| | | |
| | | @classmethod |
| | |
| | | schedulers: Sequence[Optional[AbsScheduler]], |
| | | scaler: Optional[GradScaler], |
| | | ngpu: int = 0, |
| | | oss_bucket=None, |
| | | ): |
| | | states = torch.load( |
| | | checkpoint, |
| | | map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu", |
| | | ) |
| | | if oss_bucket is None: |
| | | if os.path.exists(checkpoint): |
| | | states = torch.load( |
| | | checkpoint, |
| | | map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu", |
| | | ) |
| | | |
| | | else: |
| | | return 0 |
| | | else: |
| | | if oss_bucket.object_exists(checkpoint): |
| | | buffer = BytesIO(oss_bucket.get_object(checkpoint).read()) |
| | | states = torch.load(buffer, map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",) |
| | | else: |
| | | return 0 |
| | | model.load_state_dict(states["model"]) |
| | | reporter.load_state_dict(states["reporter"]) |
| | | for optimizer, state in zip(optimizers, states["optimizers"]): |
| | |
| | | @classmethod |
| | | def run( |
| | | cls, |
| | | model: AbsESPnetModel, |
| | | model: FunASRModel, |
| | | optimizers: Sequence[torch.optim.Optimizer], |
| | | schedulers: Sequence[Optional[AbsScheduler]], |
| | | train_iter_factory: AbsIterFactory, |
| | |
| | | distributed_option: DistributedOption, |
| | | ) -> None: |
| | | """Perform training. This method performs the main process of training.""" |
| | | assert check_argument_types() |
| | | # NOTE(kamo): Don't check the type more strictly as far as trainer_options is concerned |
| | | assert is_dataclass(trainer_options), type(trainer_options) |
| | | assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers)) |
| | |
| | | logging.warning("No keep_nbest_models is given. Change to [1]") |
| | | trainer_options.keep_nbest_models = [1] |
| | | keep_nbest_models = trainer_options.keep_nbest_models |
| | | |
| | | #assert batch_interval is set and >0 |
| | | assert trainer_options.batch_interval > 0 |
| | | |
| | | output_dir = Path(trainer_options.output_dir) |
| | | reporter = Reporter() |
| | |
| | | else: |
| | | scaler = None |
| | | |
| | | if trainer_options.resume and (output_dir / "checkpoint.pb").exists(): |
| | | if trainer_options.resume: |
| | | cls.resume( |
| | | checkpoint=output_dir / "checkpoint.pb", |
| | | checkpoint=os.path.join(trainer_options.output_dir, "checkpoint.pb") if trainer_options.use_pai else output_dir / "checkpoint.pb", |
| | | model=model, |
| | | optimizers=optimizers, |
| | | schedulers=schedulers, |
| | | reporter=reporter, |
| | | scaler=scaler, |
| | | ngpu=trainer_options.ngpu, |
| | | oss_bucket=trainer_options.oss_bucket if trainer_options.use_pai else None, |
| | | ) |
| | | |
| | | start_epoch = reporter.get_epoch() + 1 |
| | |
| | | for iepoch in range(start_epoch, trainer_options.max_epoch + 1): |
| | | if iepoch != start_epoch: |
| | | logging.info( |
| | | "{}/{}epoch started. Estimated time to finish: {}".format( |
| | | "{}/{}epoch started. Estimated time to finish: {} hours".format( |
| | | iepoch, |
| | | trainer_options.max_epoch, |
| | | humanfriendly.format_timespan( |
| | | (time.perf_counter() - start_time) |
| | | / (iepoch - start_epoch) |
| | | * (trainer_options.max_epoch - iepoch + 1) |
| | | ), |
| | | (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * ( |
| | | trainer_options.max_epoch - iepoch + 1), |
| | | ) |
| | | ) |
| | | else: |
| | |
| | | ], |
| | | "scaler": scaler.state_dict() if scaler is not None else None, |
| | | "ema_model": model.encoder.ema.model.state_dict() |
| | | if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None, |
| | | if hasattr(model, "encoder") and hasattr(model.encoder, "ema") and model.encoder.ema is not None else None, |
| | | }, |
| | | buffer, |
| | | ) |
| | |
| | | options: TrainerOptions, |
| | | distributed_option: DistributedOption, |
| | | ) -> Tuple[bool, bool]: |
| | | assert check_argument_types() |
| | | |
| | | grad_noise = options.grad_noise |
| | | accum_grad = options.accum_grad |
| | |
| | | no_forward_run = options.no_forward_run |
| | | ngpu = options.ngpu |
| | | use_wandb = options.use_wandb |
| | | bias_grad_times = options.bias_grad_times |
| | | distributed = distributed_option.distributed |
| | | |
| | | if bias_grad_times != 1.0: |
| | | logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times)) |
| | | if log_interval is None: |
| | | try: |
| | | log_interval = max(len(iterator) // 20, 10) |
| | |
| | | # output dir |
| | | output_dir = Path(options.output_dir) |
| | | #batch interval |
| | | batch_interval = options.batch_interval |
| | | assert batch_interval > 0 |
| | | batch_interval = options.batch_interval |
| | | |
| | | start_time = time.perf_counter() |
| | | for iiter, (_, batch) in enumerate( |
| | |
| | | ): |
| | | assert isinstance(batch, dict), type(batch) |
| | | |
| | | if rank == 0: |
| | | if batch_interval > 0 and (not distributed_option.distributed or rank == 0): |
| | | if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")): |
| | | num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates() |
| | | if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai: |
| | | buffer = BytesIO() |
| | | torch.save(model.state_dict(), buffer) |
| | | options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue()) |
| | | |
| | | if num_batch_updates % batch_interval == 0: |
| | | if options.use_pai and options.oss_bucket is not None: |
| | | buffer = BytesIO() |
| | | if hasattr(model, "module"): |
| | | torch.save(model.module.state_dict(), buffer) |
| | | else: |
| | | torch.save(model.state_dict(), buffer) |
| | | options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue()) |
| | | else: |
| | | if hasattr(model, "module"): |
| | | torch.save(model.module.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) |
| | | else: |
| | | torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) |
| | | |
| | | if distributed: |
| | | torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) |
| | | if iterator_stop > 0: |
| | |
| | | if no_forward_run: |
| | | all_steps_are_invalid = False |
| | | continue |
| | | |
| | | if iiter == 1 and summary_writer is not None: |
| | | try: |
| | | args = kwargs2args(model.forward, batch) |
| | | except (ValueError, TypeError): |
| | | logging.warning( |
| | | "inpect.signature() is failed for the model. " |
| | | "The graph can't be added for tensorboard." |
| | | ) |
| | | else: |
| | | try: |
| | | summary_writer.add_graph(model, args, use_strict_trace=False) |
| | | except Exception: |
| | | logging.warning( |
| | | "summary_writer.add_graph() is failed for the model. " |
| | | "The graph can't be added for tensorboard." |
| | | ) |
| | | del args |
| | | |
| | | with autocast(scaler is not None): |
| | | with reporter.measure_time("forward_time"): |
| | |
| | | eta=1.0, |
| | | scale_factor=0.55, |
| | | ) |
| | | |
| | | # for contextual training |
| | | if bias_grad_times != 1.0: |
| | | # contextual related parameter names |
| | | cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"] |
| | | for name, param in model.named_parameters(): |
| | | for cr_pname in cr_pnames: |
| | | if cr_pname in name: |
| | | param.grad *= bias_grad_times |
| | | continue |
| | | |
| | | # compute the gradient norm to check if it is normal or not |
| | | grad_norm = torch.nn.utils.clip_grad_norm_( |
| | |
| | | options: TrainerOptions, |
| | | distributed_option: DistributedOption, |
| | | ) -> None: |
| | | assert check_argument_types() |
| | | ngpu = options.ngpu |
| | | no_forward_run = options.no_forward_run |
| | | distributed = distributed_option.distributed |