From 73a7cf596bd1ebc3bd0c9674a21072f9eaf4cc57 Mon Sep 17 00:00:00 2001
From: dyyzhmm <dyyzhmm@163.com>
Date: 星期四, 16 三月 2023 15:24:56 +0800
Subject: [PATCH] Merge pull request #3 from alibaba-damo-academy/main
---
funasr/main_funcs/average_nbest_models.py | 18 +++++++++---------
1 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/funasr/main_funcs/average_nbest_models.py b/funasr/main_funcs/average_nbest_models.py
index 53f9568..d8df949 100644
--- a/funasr/main_funcs/average_nbest_models.py
+++ b/funasr/main_funcs/average_nbest_models.py
@@ -66,13 +66,13 @@
elif n == 1:
# The averaged model is same as the best model
e, _ = epoch_and_values[0]
- op = output_dir / f"{e}epoch.pth"
- sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pth"
+ op = output_dir / f"{e}epoch.pb"
+ sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
if sym_op.is_symlink() or sym_op.exists():
sym_op.unlink()
sym_op.symlink_to(op.name)
else:
- op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pth"
+ op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
logging.info(
f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
)
@@ -83,12 +83,12 @@
if e not in _loaded:
if oss_bucket is None:
_loaded[e] = torch.load(
- output_dir / f"{e}epoch.pth",
+ output_dir / f"{e}epoch.pb",
map_location="cpu",
)
else:
buffer = BytesIO(
- oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pth")).read())
+ oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
_loaded[e] = torch.load(buffer)
states = _loaded[e]
@@ -115,13 +115,13 @@
else:
buffer = BytesIO()
torch.save(avg, buffer)
- oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pth"),
+ oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
buffer.getvalue())
- # 3. *.*.ave.pth is a symlink to the max ave model
+ # 3. *.*.ave.pb is a symlink to the max ave model
if oss_bucket is None:
- op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pth"
- sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pth"
+ op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
+ sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
if sym_op.is_symlink() or sym_op.exists():
sym_op.unlink()
sym_op.symlink_to(op.name)
--
Gitblit v1.9.1