speech_asr
2023-04-18 a1a79bbe3e971a00bc315d011a2e0764b3bc3111
update
1个文件已修改
13 ■■■■■ 已修改文件
funasr/train/trainer.py 13 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train/trainer.py
@@ -583,11 +583,14 @@
            if rank == 0:
                if hasattr(model, "num_updates") or (hasattr(model, "module") and hasattr(model.module, "num_updates")):
                    num_batch_updates = model.get_num_updates() if hasattr(model,"num_updates") else model.module.get_num_updates()
                if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None) and options.use_pai:
                    buffer = BytesIO()
                    torch.save(model.state_dict(), buffer)
                    options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}batch.pth"), buffer.getvalue())
                if (num_batch_updates%batch_interval == 0) and (options.oss_bucket is not None):
                    if options.use_pai:
                        buffer = BytesIO()
                        torch.save(model.state_dict(), buffer)
                        options.oss_bucket.put_object(os.path.join(output_dir, f"{num_batch_updates}step.pb"), buffer.getvalue())
                    else:
                        torch.save(model.state_dict(), os.path.join(output_dir, f"{num_batch_updates}step.pb"))
            if distributed:
                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
                if iterator_stop > 0: