| | |
| | | if num_batch_updates % batch_interval == 0: |
| | | if options.use_pai and options.oss_bucket is not None: |
| | | buffer = BytesIO() |
| | | if hasattr(model, "module"): |
| | | torch.save(model.module.state_dict(), buffer) |
| | | else: |
| | | torch.save(model.state_dict(), buffer) |
| | | options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue()) |
| | | else: |
| | | if hasattr(model, "module"): |
| | | torch.save(model.module.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) |
| | | else: |
| | | torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) |
| | | |
| | | if distributed: |