From 2ba4683eb2ce42eec91250debe88b424cbc2d67f Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 16 三月 2023 11:14:42 +0800
Subject: [PATCH] update
---
funasr/train/trainer.py | 36 ++++++++++++++++++------------------
1 files changed, 18 insertions(+), 18 deletions(-)
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index 50bce47..efe2009 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -205,9 +205,9 @@
else:
scaler = None
- if trainer_options.resume and (output_dir / "checkpoint.pth").exists():
+ if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
cls.resume(
- checkpoint=output_dir / "checkpoint.pth",
+ checkpoint=output_dir / "checkpoint.pb",
model=model,
optimizers=optimizers,
schedulers=schedulers,
@@ -361,7 +361,7 @@
},
buffer,
)
- trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pth"), buffer.getvalue())
+ trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"), buffer.getvalue())
else:
torch.save(
{
@@ -374,7 +374,7 @@
],
"scaler": scaler.state_dict() if scaler is not None else None,
},
- output_dir / "checkpoint.pth",
+ output_dir / "checkpoint.pb",
)
# 5. Save and log the model and update the link to the best model
@@ -382,22 +382,22 @@
buffer = BytesIO()
torch.save(model.state_dict(), buffer)
trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
- f"{iepoch}epoch.pth"),buffer.getvalue())
+ f"{iepoch}epoch.pb"),buffer.getvalue())
else:
- torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth")
+ torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
- # Creates a sym link latest.pth -> {iepoch}epoch.pth
+ # Creates a sym link latest.pb -> {iepoch}epoch.pb
if trainer_options.use_pai:
- p = os.path.join(trainer_options.output_dir, "latest.pth")
+ p = os.path.join(trainer_options.output_dir, "latest.pb")
if trainer_options.oss_bucket.object_exists(p):
trainer_options.oss_bucket.delete_object(p)
trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
- os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pth"), p)
+ os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"), p)
else:
- p = output_dir / "latest.pth"
+ p = output_dir / "latest.pb"
if p.is_symlink() or p.exists():
p.unlink()
- p.symlink_to(f"{iepoch}epoch.pth")
+ p.symlink_to(f"{iepoch}epoch.pb")
_improved = []
for _phase, k, _mode in trainer_options.best_model_criterion:
@@ -407,16 +407,16 @@
# Creates sym links if it's the best result
if best_epoch == iepoch:
if trainer_options.use_pai:
- p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pth")
+ p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
if trainer_options.oss_bucket.object_exists(p):
trainer_options.oss_bucket.delete_object(p)
trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
- os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pth"),p)
+ os.path.join(trainer_options.output_dir, f"{iepoch}epoch.pb"),p)
else:
- p = output_dir / f"{_phase}.{k}.best.pth"
+ p = output_dir / f"{_phase}.{k}.best.pb"
if p.is_symlink() or p.exists():
p.unlink()
- p.symlink_to(f"{iepoch}epoch.pth")
+ p.symlink_to(f"{iepoch}epoch.pb")
_improved.append(f"{_phase}.{k}")
if len(_improved) == 0:
logging.info("There are no improvements in this epoch")
@@ -438,7 +438,7 @@
type="model",
metadata={"improved": _improved},
)
- artifact.add_file(str(output_dir / f"{iepoch}epoch.pth"))
+ artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
aliases = [
f"epoch-{iepoch}",
"best" if best_epoch == iepoch else "",
@@ -473,12 +473,12 @@
for e in range(1, iepoch):
if trainer_options.use_pai:
- p = os.path.join(trainer_options.output_dir, f"{e}epoch.pth")
+ p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
trainer_options.oss_bucket.delete_object(p)
_removed.append(str(p))
else:
- p = output_dir / f"{e}epoch.pth"
+ p = output_dir / f"{e}epoch.pb"
if p.exists() and e not in nbests:
p.unlink()
_removed.append(str(p))
--
Gitblit v1.9.1