From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/main_funcs/average_nbest_models.py | 20 +++++++++-----------
1 files changed, 9 insertions(+), 11 deletions(-)
diff --git a/funasr/main_funcs/average_nbest_models.py b/funasr/main_funcs/average_nbest_models.py
index 53f9568..96e1384 100644
--- a/funasr/main_funcs/average_nbest_models.py
+++ b/funasr/main_funcs/average_nbest_models.py
@@ -8,7 +8,6 @@
from io import BytesIO
import torch
-from typeguard import check_argument_types
from typing import Collection
from funasr.train.reporter import Reporter
@@ -34,7 +33,6 @@
nbest: Number of best model files to be averaged
suffix: A suffix added to the averaged model file name
"""
- assert check_argument_types()
if isinstance(nbest, int):
nbests = [nbest]
else:
@@ -66,13 +64,13 @@
elif n == 1:
# The averaged model is same as the best model
e, _ = epoch_and_values[0]
- op = output_dir / f"{e}epoch.pth"
- sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pth"
+ op = output_dir / f"{e}epoch.pb"
+ sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
if sym_op.is_symlink() or sym_op.exists():
sym_op.unlink()
sym_op.symlink_to(op.name)
else:
- op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pth"
+ op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
logging.info(
f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
)
@@ -83,12 +81,12 @@
if e not in _loaded:
if oss_bucket is None:
_loaded[e] = torch.load(
- output_dir / f"{e}epoch.pth",
+ output_dir / f"{e}epoch.pb",
map_location="cpu",
)
else:
buffer = BytesIO(
- oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pth")).read())
+ oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
_loaded[e] = torch.load(buffer)
states = _loaded[e]
@@ -115,13 +113,13 @@
else:
buffer = BytesIO()
torch.save(avg, buffer)
- oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pth"),
+ oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
buffer.getvalue())
- # 3. *.*.ave.pth is a symlink to the max ave model
+ # 3. *.*.ave.pb is a symlink to the max ave model
if oss_bucket is None:
- op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pth"
- sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pth"
+ op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
+ sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
if sym_op.is_symlink() or sym_op.exists():
sym_op.unlink()
sym_op.symlink_to(op.name)
--
Gitblit v1.9.1