From 73a7cf596bd1ebc3bd0c9674a21072f9eaf4cc57 Mon Sep 17 00:00:00 2001
From: dyyzhmm <dyyzhmm@163.com>
Date: 星期四, 16 三月 2023 15:24:56 +0800
Subject: [PATCH] Merge pull request #3 from alibaba-damo-academy/main

---
 funasr/main_funcs/average_nbest_models.py |   18 +++++++++---------
 1 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/funasr/main_funcs/average_nbest_models.py b/funasr/main_funcs/average_nbest_models.py
index 53f9568..d8df949 100644
--- a/funasr/main_funcs/average_nbest_models.py
+++ b/funasr/main_funcs/average_nbest_models.py
@@ -66,13 +66,13 @@
             elif n == 1:
                 # The averaged model is same as the best model
                 e, _ = epoch_and_values[0]
-                op = output_dir / f"{e}epoch.pth"
-                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pth"
+                op = output_dir / f"{e}epoch.pb"
+                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
                 if sym_op.is_symlink() or sym_op.exists():
                     sym_op.unlink()
                 sym_op.symlink_to(op.name)
             else:
-                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pth"
+                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
                 logging.info(
                     f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
                 )
@@ -83,12 +83,12 @@
                     if e not in _loaded:
                         if oss_bucket is None:
                             _loaded[e] = torch.load(
-                                output_dir / f"{e}epoch.pth",
+                                output_dir / f"{e}epoch.pb",
                                 map_location="cpu",
                             )
                         else:
                             buffer = BytesIO(
-                                oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pth")).read())
+                                oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
                             _loaded[e] = torch.load(buffer)
                     states = _loaded[e]
 
@@ -115,13 +115,13 @@
                 else:
                     buffer = BytesIO()
                     torch.save(avg, buffer)
-                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pth"),
+                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
                                           buffer.getvalue())
 
-        # 3. *.*.ave.pth is a symlink to the max ave model
+        # 3. *.*.ave.pb is a symlink to the max ave model
         if oss_bucket is None:
-            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pth"
-            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pth"
+            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
+            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
             if sym_op.is_symlink() or sym_op.exists():
                 sym_op.unlink()
             sym_op.symlink_to(op.name)

--
Gitblit v1.9.1