From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Mar 2024 16:05:29 +0800
Subject: [PATCH] train update

---
 funasr/train_utils/average_nbest_models.py |  137 ++++++---------------------------------------
 1 files changed, 18 insertions(+), 119 deletions(-)

diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 3603a44..013a719 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -16,139 +16,38 @@
 from functools import cmp_to_key
 
 
-# @torch.no_grad()
-# def average_nbest_models(
-#     output_dir: Path,
-#     best_model_criterion: Sequence[Sequence[str]],
-#     nbest: Union[Collection[int], int],
-#     suffix: Optional[str] = None,
-#     oss_bucket=None,
-#     pai_output_dir=None,
-# ) -> None:
-#     """Generate averaged model from n-best models
-#
-#     Args:
-#         output_dir: The directory contains the model file for each epoch
-#         reporter: Reporter instance
-#         best_model_criterion: Give criterions to decide the best model.
-#             e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
-#         nbest: Number of best model files to be averaged
-#         suffix: A suffix added to the averaged model file name
-#     """
-#     if isinstance(nbest, int):
-#         nbests = [nbest]
-#     else:
-#         nbests = list(nbest)
-#     if len(nbests) == 0:
-#         warnings.warn("At least 1 nbest values are required")
-#         nbests = [1]
-#     if suffix is not None:
-#         suffix = suffix + "."
-#     else:
-#         suffix = ""
-#
-#     # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
-#     nbest_epochs = [
-#         (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
-#         for ph, k, m in best_model_criterion
-#         if reporter.has(ph, k)
-#     ]
-#
-#     _loaded = {}
-#     for ph, cr, epoch_and_values in nbest_epochs:
-#         _nbests = [i for i in nbests if i <= len(epoch_and_values)]
-#         if len(_nbests) == 0:
-#             _nbests = [1]
-#
-#         for n in _nbests:
-#             if n == 0:
-#                 continue
-#             elif n == 1:
-#                 # The averaged model is same as the best model
-#                 e, _ = epoch_and_values[0]
-#                 op = output_dir / f"{e}epoch.pb"
-#                 sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
-#                 if sym_op.is_symlink() or sym_op.exists():
-#                     sym_op.unlink()
-#                 sym_op.symlink_to(op.name)
-#             else:
-#                 op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
-#                 logging.info(
-#                     f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
-#                 )
-#
-#                 avg = None
-#                 # 2.a. Averaging model
-#                 for e, _ in epoch_and_values[:n]:
-#                     if e not in _loaded:
-#                         if oss_bucket is None:
-#                             _loaded[e] = torch.load(
-#                                 output_dir / f"{e}epoch.pb",
-#                                 map_location="cpu",
-#                             )
-#                         else:
-#                             buffer = BytesIO(
-#                                 oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
-#                             _loaded[e] = torch.load(buffer)
-#                     states = _loaded[e]
-#
-#                     if avg is None:
-#                         avg = states
-#                     else:
-#                         # Accumulated
-#                         for k in avg:
-#                             avg[k] = avg[k] + states[k]
-#                 for k in avg:
-#                     if str(avg[k].dtype).startswith("torch.int"):
-#                         # For int type, not averaged, but only accumulated.
-#                         # e.g. BatchNorm.num_batches_tracked
-#                         # (If there are any cases that requires averaging
-#                         #  or the other reducing method, e.g. max/min, for integer type,
-#                         #  please report.)
-#                         pass
-#                     else:
-#                         avg[k] = avg[k] / n
-#
-#                 # 2.b. Save the ave model and create a symlink
-#                 if oss_bucket is None:
-#                     torch.save(avg, op)
-#                 else:
-#                     buffer = BytesIO()
-#                     torch.save(avg, buffer)
-#                     oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
-#                                           buffer.getvalue())
-#
-#         # 3. *.*.ave.pb is a symlink to the max ave model
-#         if oss_bucket is None:
-#             op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
-#             sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
-#             if sym_op.is_symlink() or sym_op.exists():
-#                 sym_op.unlink()
-#             sym_op.symlink_to(op.name)
-
 
 def _get_checkpoint_paths(output_dir: str, last_n: int=5):
     """
     Get the paths of the last 'last_n' checkpoints by parsing filenames
     in the output directory.
     """
-    # List all files in the output directory
-    files = os.listdir(output_dir)
-    # Filter out checkpoint files and extract epoch numbers
-    checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
-    # Sort files by epoch number in descending order
-    checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
-    # Get the last 'last_n' checkpoint paths
-    checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
+    try:
+        # torch.load must receive the checkpoint path itself, not the bool from os.path.exists().
+        checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
+        avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
+        val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
+        # Sorted high-to-low by value: "acc" keeps the head (highest), other types keep the tail (lowest).
+        sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
+        sorted_items = sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
+        checkpoint_paths = [os.path.join(output_dir, key) for key, _ in sorted_items[:last_n]]
+    except Exception:
+        # Fallback when "model.pt" or its n-best record cannot be read: take the "model.pt.e*"
+        # files, sort them by the first number in the file name (descending), keep 'last_n'.
+        files = os.listdir(output_dir)
+        checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
+        checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
+        checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
     return checkpoint_paths
 
 @torch.no_grad()
-def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]):
+def average_checkpoints(output_dir: str, last_n: int=5, **kwargs):
     """
     Average the last 'last_n' checkpoints' model state_dicts.
     If a tensor is of type torch.int, perform sum instead of average.
     """
     checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
+    print(f"average_checkpoints: {checkpoint_paths}")
     state_dicts = []
 
     # Load state_dicts from checkpoints

--
Gitblit v1.9.1