Merge pull request #391 from alibaba-damo-academy/dev_wjm3
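Fix model save bug: when the model has a `module` attribute, save `model.module.state_dict()` instead of `model.state_dict()` (applies to both the OSS-bucket and local-file checkpoint paths).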
| | |
| | | if num_batch_updates % batch_interval == 0: |
| | | if options.use_pai and options.oss_bucket is not None: |
| | | buffer = BytesIO() |
| | | torch.save(model.state_dict(), buffer) |
| | | if hasattr(model, "module"): |
| | | torch.save(model.module.state_dict(), buffer) |
| | | else: |
| | | torch.save(model.state_dict(), buffer) |
| | | options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue()) |
| | | else: |
| | | torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) |
| | | if hasattr(model, "module"): |
| | | torch.save(model.module.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) |
| | | else: |
| | | torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb")) |
| | | |
| | | if distributed: |
| | | torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) |