| | |
| | | wandb_model_log_interval: int |
| | | use_pai: bool |
| | | oss_bucket: Union[oss2.Bucket, None] |
| | | |
| | | batch_interval: int |
| | | bias_grad_times: float |
| | | |
| | | class Trainer: |
| | | """Trainer having a optimizer. |
| | |
| | | logging.warning("No keep_nbest_models is given. Change to [1]") |
| | | trainer_options.keep_nbest_models = [1] |
| | | keep_nbest_models = trainer_options.keep_nbest_models |
| | | |
| | | |
| | | output_dir = Path(trainer_options.output_dir) |
| | | reporter = Reporter() |
| | | if trainer_options.use_amp: |
| | |
| | | no_forward_run = options.no_forward_run |
| | | ngpu = options.ngpu |
| | | use_wandb = options.use_wandb |
| | | bias_grad_times = options.bias_grad_times |
| | | distributed = distributed_option.distributed |
| | | |
| | | if bias_grad_times != 1.0: |
| | | logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times)) |
| | | if log_interval is None: |
| | | try: |
| | | log_interval = max(len(iterator) // 20, 10) |
| | |
| | | # [For distributed] Because iteration counts are not always equals between |
| | | # processes, send stop-flag to the other processes if iterator is finished |
| | | iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu") |
| | | |
| | | |
| | | #get the rank |
| | | rank = distributed_option.dist_rank |
| | | #get the num batch updates |
| | | num_batch_updates = 0 |
| | | #ouput dir |
| | | output_dir = Path(options.output_dir) |
| | | #batch interval |
| | | batch_interval = options.batch_interval |
| | | |
| | | start_time = time.perf_counter() |
| | | for iiter, (_, batch) in enumerate( |
| | | reporter.measure_iter_time(iterator, "iter_time"), 1 |
| | | ): |
| | | assert isinstance(batch, dict), type(batch) |
| | | |
| | | if batch_interval > 0 and (not distributed_option.distributed or rank == 0): |
| | | if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")): |
| | | num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates() |
| | | if num_batch_updates % batch_interval == 0: |
| | | if options.use_pai and options.oss_bucket is not None: |
| | | buffer = BytesIO() |
| | | if hasattr(model, "module"): |
| | | torch.save(model.module.state_dict(), buffer) |
| | | else: |
| | | torch.save(model.state_dict(), buffer) |
| | | options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue()) |
| | | else: |
| | | if hasattr(model, "module"): |
| | | torch.save(model.module.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) |
| | | else: |
| | | torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) |
| | | |
| | | if distributed: |
| | | torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) |
| | |
| | | eta=1.0, |
| | | scale_factor=0.55, |
| | | ) |
| | | |
| | | # for contextual training |
| | | if bias_grad_times != 1.0: |
| | | # contextual related parameter names |
| | | cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"] |
| | | for name, param in model.named_parameters(): |
| | | for cr_pname in cr_pnames: |
| | | if cr_pname in name: |
| | | param.grad *= bias_grad_times |
| | | continue |
| | | |
| | | # compute the gradient norm to check if it is normal or not |
| | | grad_norm = torch.nn.utils.clip_grad_norm_( |
| | |
| | | else: |
| | | if distributed: |
| | | iterator_stop.fill_(1) |
| | | torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) |
| | | torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) |