wandb_model_log_interval: int  # interval at which the model is logged to Weights & Biases
use_pai: bool  # whether training runs on Alibaba PAI (intermediate checkpoints are pushed to OSS)
oss_bucket: Union[oss2.Bucket, None]  # OSS bucket used for checkpoint uploads, or None when unused

batch_interval: int  # number of batch updates between intermediate checkpoint uploads
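
# Illustrative sketch (not part of the original file): the `oss_bucket` option above
# is typically built with the oss2 SDK roughly as follows; the credential, endpoint,
# and bucket-name arguments are placeholders.
def build_oss_bucket_sketch(access_key_id, access_key_secret, endpoint, bucket_name):
    """Return an oss2.Bucket suitable for the `oss_bucket` option."""
    import oss2
    auth = oss2.Auth(access_key_id, access_key_secret)
    return oss2.Bucket(auth, endpoint, bucket_name)
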
class Trainer:
    """Trainer having an optimizer.

logging.warning("No keep_nbest_models is given. Changing to [1]")
trainer_options.keep_nbest_models = [1]
keep_nbest_models = trainer_options.keep_nbest_models


# batch_interval must be configured and positive (it gates the periodic checkpoint upload below)
assert trainer_options.batch_interval > 0

output_dir = Path(trainer_options.output_dir)
reporter = Reporter()
if trainer_options.use_amp:

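# The body of the `use_amp` branch above is elided in this excerpt; in the standard
# torch.cuda.amp pattern it creates a GradScaler (and sets `scaler = None` otherwise),
# which the training step then consumes roughly as follows.  This is a hedged sketch,
# not code from this file; `optimizer` and `compute_loss` are placeholders:
#
#     scaler = torch.cuda.amp.GradScaler() if trainer_options.use_amp else None
#     with torch.cuda.amp.autocast(enabled=scaler is not None):
#         loss = compute_loss(model, batch)  # placeholder forward pass
#     if scaler is not None:
#         scaler.scale(loss).backward()   # scale the loss to avoid fp16 underflow
#         scaler.step(optimizer)          # unscales gradients, skips the step on inf/NaN
#         scaler.update()                 # adjust the scale factor for the next step
#     else:
#         loss.backward()
#         optimizer.step()
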
# [For distributed] Because iteration counts are not always equal across
# processes, send a stop flag to the other processes once this iterator is exhausted
iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")

# rank of this process in the distributed group
rank = distributed_option.dist_rank
# number of parameter updates performed so far
num_batch_updates = 0
# output directory (also used as the OSS key prefix for uploads)
output_dir = Path(options.output_dir)
# upload an intermediate checkpoint every `batch_interval` updates
batch_interval = options.batch_interval
assert batch_interval > 0

start_time = time.perf_counter()
for iiter, (_, batch) in enumerate(
    reporter.measure_iter_time(iterator, "iter_time"), 1
):
    assert isinstance(batch, dict), type(batch)

    if rank == 0 and hasattr(model.module, "num_updates"):
        num_batch_updates = model.module.get_num_updates()
        # On PAI, push an intermediate checkpoint to OSS every `batch_interval`
        # updates (a standalone upload/download sketch follows after the loop).
        if num_batch_updates % batch_interval == 0 and options.oss_bucket is not None and options.use_pai:
            buffer = BytesIO()
            torch.save(model.state_dict(), buffer)
            options.oss_bucket.put_object(
                os.path.join(output_dir, f"{num_batch_updates}batch.pth"),
                buffer.getvalue(),
            )

    if distributed:
        torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
        if iterator_stop > 0:
            # another process has finished its iterator; stop this one as well
            break

else:
    # for-else: the loop ran to completion, so tell the other processes to stop too
    if distributed:
        iterator_stop.fill_(1)
        torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
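
# Illustrative sketch (not part of the original file): saving a state_dict to OSS
# exactly as the loop above does, and reading such a checkpoint back.  The helper
# names are hypothetical; only the file's existing `torch` / `BytesIO` imports and
# the real oss2.Bucket.put_object / get_object APIs are assumed.
def upload_state_dict_sketch(bucket, key, model):
    """Serialize model.state_dict() into memory and upload it to OSS."""
    buffer = BytesIO()
    torch.save(model.state_dict(), buffer)
    bucket.put_object(key, buffer.getvalue())


def download_state_dict_sketch(bucket, key):
    """Download a checkpoint uploaded as above and deserialize it on the CPU."""
    data = bucket.get_object(key).read()
    return torch.load(BytesIO(data), map_location="cpu")


# Usage (sketch): state = download_state_dict_sketch(options.oss_bucket, "exp/1000batch.pth")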