From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/main_funcs/average_nbest_models.py |   20 +++++++++-----------
 1 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/funasr/main_funcs/average_nbest_models.py b/funasr/main_funcs/average_nbest_models.py
index 53f9568..96e1384 100644
--- a/funasr/main_funcs/average_nbest_models.py
+++ b/funasr/main_funcs/average_nbest_models.py
@@ -8,7 +8,6 @@
 from io import BytesIO
 
 import torch
-from typeguard import check_argument_types
 from typing import Collection
 
 from funasr.train.reporter import Reporter
@@ -34,7 +33,6 @@
         nbest: Number of best model files to be averaged
         suffix: A suffix added to the averaged model file name
     """
-    assert check_argument_types()
     if isinstance(nbest, int):
         nbests = [nbest]
     else:
@@ -66,13 +64,13 @@
             elif n == 1:
                 # The averaged model is same as the best model
                 e, _ = epoch_and_values[0]
-                op = output_dir / f"{e}epoch.pth"
-                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pth"
+                op = output_dir / f"{e}epoch.pb"
+                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
                 if sym_op.is_symlink() or sym_op.exists():
                     sym_op.unlink()
                 sym_op.symlink_to(op.name)
             else:
-                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pth"
+                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
                 logging.info(
                     f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
                 )
@@ -83,12 +81,12 @@
                     if e not in _loaded:
                         if oss_bucket is None:
                             _loaded[e] = torch.load(
-                                output_dir / f"{e}epoch.pth",
+                                output_dir / f"{e}epoch.pb",
                                 map_location="cpu",
                             )
                         else:
                             buffer = BytesIO(
-                                oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pth")).read())
+                                oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
                             _loaded[e] = torch.load(buffer)
                     states = _loaded[e]
 
@@ -115,13 +113,13 @@
                 else:
                     buffer = BytesIO()
                     torch.save(avg, buffer)
-                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pth"),
+                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
                                           buffer.getvalue())
 
-        # 3. *.*.ave.pth is a symlink to the max ave model
+        # 3. *.*.ave.pb is a symlink to the max ave model
         if oss_bucket is None:
-            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pth"
-            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pth"
+            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
+            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
             if sym_op.is_symlink() or sym_op.exists():
                 sym_op.unlink()
             sym_op.symlink_to(op.name)

--
Gitblit v1.9.1