    wandb_model_log_interval: int
    use_pai: bool
    oss_bucket: Union[oss2.Bucket, None]

    batch_interval: int
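
    # Note (assumption): use_pai/oss_bucket support training on Alibaba PAI,
    # where checkpoints are uploaded to OSS instead of the local filesystem.
    # A minimal sketch of constructing the bucket with the standard oss2
    # constructors (credential names here are illustrative):
    #
    #   import oss2
    #   auth = oss2.Auth(access_key_id, access_key_secret)
    #   bucket = oss2.Bucket(auth, endpoint, bucket_name)
    #
    # batch_interval controls how often (in parameter updates) an
    # intermediate model snapshot is uploaded during an epoch.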

class Trainer:
    """Trainer having an optimizer."""

    # --- inside the class's main training entry point (surrounding code elided) ---
        if len(trainer_options.keep_nbest_models) == 0:
            logging.warning("No keep_nbest_models is given. Change to [1]")
            trainer_options.keep_nbest_models = [1]
        keep_nbest_models = trainer_options.keep_nbest_models
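
        # keep_nbest_models is a list so several retention depths can be
        # evaluated at once; the pruning step further below keeps the union
        # of the best max(keep_nbest_models) epochs across all criteria.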

        # Ensure batch_interval was set to a positive number of updates
        assert trainer_options.batch_interval > 0

        output_dir = Path(trainer_options.output_dir)
        reporter = Reporter()
        if trainer_options.use_amp:
            scaler = GradScaler()
        else:
            scaler = None
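
        # GradScaler (torch.cuda.amp) scales the loss to avoid fp16 gradient
        # underflow under AMP; its state is saved in the checkpoint below so
        # a resumed run continues with the same loss-scale schedule.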

        if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
            cls.resume(
                checkpoint=output_dir / "checkpoint.pb",
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                reporter=reporter,
                scaler=scaler,
                ngpu=trainer_options.ngpu,
            )

        # 4. Save/Update the checkpoint (inside the per-epoch training loop)
        if trainer_options.use_pai:
            buffer = BytesIO()
            torch.save(
                {
                    "model": model.state_dict(),
                    "reporter": reporter.state_dict(),
                    "optimizers": [o.state_dict() for o in optimizers],
                    "schedulers": [
                        s.state_dict() if s is not None else None
                        for s in schedulers
                    ],
                    "scaler": scaler.state_dict() if scaler is not None else None,
                },
                buffer,
            )
            trainer_options.oss_bucket.put_object(
                os.path.join(trainer_options.output_dir, "checkpoint.pb"),
                buffer.getvalue(),
            )
        else:
            torch.save(
                {
                    "model": model.state_dict(),
                    "reporter": reporter.state_dict(),
                    "optimizers": [o.state_dict() for o in optimizers],
                    "schedulers": [
                        s.state_dict() if s is not None else None
                        for s in schedulers
                    ],
                    "scaler": scaler.state_dict() if scaler is not None else None,
                },
                output_dir / "checkpoint.pb",
            )
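
        # OSS objects are written from memory: torch.save needs a file-like
        # target, so the checkpoint is serialized into a BytesIO buffer and
        # uploaded with oss2's Bucket.put_object(key, data).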

        # 5. Save and log the model and update the link to the best model
        if trainer_options.use_pai:
            buffer = BytesIO()
            torch.save(model.state_dict(), buffer)
            trainer_options.oss_bucket.put_object(
                os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                buffer.getvalue(),
            )
        else:
            torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")

        # Create a symlink latest.pb -> {iepoch}epoch.pb
        if trainer_options.use_pai:
            p = os.path.join(trainer_options.output_dir, "latest.pb")
            if trainer_options.oss_bucket.object_exists(p):
                trainer_options.oss_bucket.delete_object(p)
            trainer_options.oss_bucket.copy_object(
                trainer_options.oss_bucket.bucket_name,
                os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),
                p,
            )
        else:
            p = output_dir / "latest.pb"
            if p.is_symlink() or p.exists():
                p.unlink()
            p.symlink_to(f"{iepoch}epoch.pb")
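
        # OSS has no symlinks, so on PAI the "latest" pointer is emulated by
        # a server-side copy: Bucket.copy_object(source_bucket_name,
        # source_key, target_key) duplicates the epoch snapshot under the
        # latest.pb key.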

        _improved = []
        for _phase, k, _mode in trainer_options.best_model_criterion:
            # Use the criterion only if the reporter has seen this metric
            if reporter.has(_phase, k):
                best_epoch = reporter.get_best_epoch(_phase, k, _mode)
                # Create symlinks if this epoch is the best result so far
                if best_epoch == iepoch:
                    if trainer_options.use_pai:
                        p = os.path.join(
                            trainer_options.output_dir, f"{_phase}.{k}.best.pb"
                        )
                        if trainer_options.oss_bucket.object_exists(p):
                            trainer_options.oss_bucket.delete_object(p)
                        trainer_options.oss_bucket.copy_object(
                            trainer_options.oss_bucket.bucket_name,
                            os.path.join(
                                trainer_options.output_dir, f"{iepoch}epoch.pb"
                            ),
                            p,
                        )
                    else:
                        p = output_dir / f"{_phase}.{k}.best.pb"
                        if p.is_symlink() or p.exists():
                            p.unlink()
                        p.symlink_to(f"{iepoch}epoch.pb")
                    _improved.append(f"{_phase}.{k}")
        if len(_improved) == 0:
            logging.info("There are no improvements in this epoch")
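
        # best_model_criterion entries are (phase, metric, mode) triples,
        # e.g. ("valid", "acc", "max"): track validation accuracy and treat
        # larger as better. (The example triple is illustrative.)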
        if trainer_options.use_wandb and trainer_options.wandb_model_log_interval > 0:
            if iepoch % trainer_options.wandb_model_log_interval == 0:
                import wandb

                artifact = wandb.Artifact(
                    name=f"model_{wandb.run.id}",
                    type="model",
                    metadata={"improved": _improved},
                )
                artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
                aliases = [
                    f"epoch-{iepoch}",
                    "best" if best_epoch == iepoch else "",
                ]
                wandb.log_artifact(artifact, aliases=aliases)

        # 6. Remove the model files excluding n-best epoch and latest epoch
        _removed = []
        # Union of the n-best epochs over all criteria
        nbests = set().union(
            *[
                set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
                for ph, k, m in trainer_options.best_model_criterion
                if reporter.has(ph, k)
            ]
        )
        for e in range(1, iepoch):
            if trainer_options.use_pai:
                p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
                if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
                    trainer_options.oss_bucket.delete_object(p)
                    _removed.append(str(p))
            else:
                p = output_dir / f"{e}epoch.pb"
                if p.exists() and e not in nbests:
                    p.unlink()
                    _removed.append(str(p))
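
        # Only numbered {e}epoch.pb files are pruned; checkpoint.pb,
        # latest.pb and the *.best.pb pointers created above are never
        # touched, so a usable model always remains after pruning.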

        # --- inside the per-epoch training function (surrounding code elided) ---
        # [For distributed] Because iteration counts are not always equal
        # across processes, send a stop-flag to the other processes if this
        # iterator is finished
        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
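
        # Stop-flag protocol: each rank all-reduces iterator_stop every step.
        # An exhausted rank contributes 1, so the SUM becomes nonzero on all
        # ranks simultaneously and everyone leaves the loop together.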

        # Rank of this process in the distributed group
        rank = distributed_option.dist_rank
        # Number of parameter updates performed so far
        num_batch_updates = 0
        # Output directory
        output_dir = Path(options.output_dir)
        # Interval (in updates) between intermediate model snapshots
        batch_interval = options.batch_interval
        assert batch_interval > 0
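
        # Example: with batch_interval = 5000, rank 0 uploads 5000batch.pb,
        # 10000batch.pb, ... over the course of the epoch (assuming PAI/OSS
        # mode; the interval value here is illustrative).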

        start_time = time.perf_counter()
        for iiter, (_, batch) in enumerate(
            reporter.measure_iter_time(iterator, "iter_time"), 1
        ):
            assert isinstance(batch, dict), type(batch)

            if rank == 0:
                # The model tracks its own update count (unwrap DDP's
                # .module if necessary)
                if hasattr(model, "num_updates") or (
                    hasattr(model, "module") and hasattr(model.module, "num_updates")
                ):
                    num_batch_updates = (
                        model.get_num_updates()
                        if hasattr(model, "num_updates")
                        else model.module.get_num_updates()
                    )
                    if (num_batch_updates % batch_interval == 0) and (
                        options.oss_bucket is not None) and options.use_pai:
                        buffer = BytesIO()
                        torch.save(model.state_dict(), buffer)
                        options.oss_bucket.put_object(
                            os.path.join(output_dir, f"{num_batch_updates}batch.pb"),
                            buffer.getvalue(),
                        )
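
            # Only rank 0 uploads, and only the bare state_dict: these
            # intra-epoch snapshots are inference/recovery points, not full
            # resume checkpoints (no optimizer, scheduler, or scaler state).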

            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0:
                    break
        else:
            if distributed:
                iterator_stop.fill_(1)
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
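
        # Note the for/else: the else branch runs only when the loop finishes
        # without break, i.e. this rank exhausted its iterator normally; it
        # then raises the stop-flag so ranks still mid-epoch break above.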