| | |
| | | output_dir = Path(options.output_dir) |
| | | #batch interval |
| | | batch_interval = options.batch_interval |
| | | assert batch_interval > 0 |
| | | |
| | | start_time = time.perf_counter() |
| | | for iiter, (_, batch) in enumerate( |
| | |
| | | ): |
| | | assert isinstance(batch, dict), type(batch) |
| | | |
| | | if rank == 0: |
| | | if batch_interval > 0 and (not distributed_option.distributed or rank == 0): |
| | | if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")): |
| | | num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates() |
| | | if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai: |
| | | if num_batch_updates % batch_interval == 0: |
| | | if options.use_pai and options.oss_bucket is not None: |
| | | buffer = BytesIO() |
| | | torch.save(model.state_dict(), buffer) |
| | | options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue()) |
| | | options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue()) |
| | | else: |
| | | torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) |
| | | |
| | | if distributed: |
| | | torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) |