From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/train_utils/average_nbest_models.py | 173 ++++++++++++++-------------------------------------------
1 files changed, 44 insertions(+), 129 deletions(-)
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 3603a44..67f1e55 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -16,152 +16,67 @@
from functools import cmp_to_key
-# @torch.no_grad()
-# def average_nbest_models(
-# output_dir: Path,
-# best_model_criterion: Sequence[Sequence[str]],
-# nbest: Union[Collection[int], int],
-# suffix: Optional[str] = None,
-# oss_bucket=None,
-# pai_output_dir=None,
-# ) -> None:
-# """Generate averaged model from n-best models
-#
-# Args:
-# output_dir: The directory contains the model file for each epoch
-# reporter: Reporter instance
-# best_model_criterion: Give criterions to decide the best model.
-# e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
-# nbest: Number of best model files to be averaged
-# suffix: A suffix added to the averaged model file name
-# """
-# if isinstance(nbest, int):
-# nbests = [nbest]
-# else:
-# nbests = list(nbest)
-# if len(nbests) == 0:
-# warnings.warn("At least 1 nbest values are required")
-# nbests = [1]
-# if suffix is not None:
-# suffix = suffix + "."
-# else:
-# suffix = ""
-#
-# # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
-# nbest_epochs = [
-# (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
-# for ph, k, m in best_model_criterion
-# if reporter.has(ph, k)
-# ]
-#
-# _loaded = {}
-# for ph, cr, epoch_and_values in nbest_epochs:
-# _nbests = [i for i in nbests if i <= len(epoch_and_values)]
-# if len(_nbests) == 0:
-# _nbests = [1]
-#
-# for n in _nbests:
-# if n == 0:
-# continue
-# elif n == 1:
-# # The averaged model is same as the best model
-# e, _ = epoch_and_values[0]
-# op = output_dir / f"{e}epoch.pb"
-# sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
-# if sym_op.is_symlink() or sym_op.exists():
-# sym_op.unlink()
-# sym_op.symlink_to(op.name)
-# else:
-# op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
-# logging.info(
-# f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
-# )
-#
-# avg = None
-# # 2.a. Averaging model
-# for e, _ in epoch_and_values[:n]:
-# if e not in _loaded:
-# if oss_bucket is None:
-# _loaded[e] = torch.load(
-# output_dir / f"{e}epoch.pb",
-# map_location="cpu",
-# )
-# else:
-# buffer = BytesIO(
-# oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
-# _loaded[e] = torch.load(buffer)
-# states = _loaded[e]
-#
-# if avg is None:
-# avg = states
-# else:
-# # Accumulated
-# for k in avg:
-# avg[k] = avg[k] + states[k]
-# for k in avg:
-# if str(avg[k].dtype).startswith("torch.int"):
-# # For int type, not averaged, but only accumulated.
-# # e.g. BatchNorm.num_batches_tracked
-# # (If there are any cases that requires averaging
-# # or the other reducing method, e.g. max/min, for integer type,
-# # please report.)
-# pass
-# else:
-# avg[k] = avg[k] / n
-#
-# # 2.b. Save the ave model and create a symlink
-# if oss_bucket is None:
-# torch.save(avg, op)
-# else:
-# buffer = BytesIO()
-# torch.save(avg, buffer)
-# oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
-# buffer.getvalue())
-#
-# # 3. *.*.ave.pb is a symlink to the max ave model
-# if oss_bucket is None:
-# op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
-# sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
-# if sym_op.is_symlink() or sym_op.exists():
-# sym_op.unlink()
-# sym_op.symlink_to(op.name)
-
-
-def _get_checkpoint_paths(output_dir: str, last_n: int=5):
+def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
"""
Get the paths of the last 'last_n' checkpoints by parsing filenames
in the output directory.
"""
- # List all files in the output directory
- files = os.listdir(output_dir)
- # Filter out checkpoint files and extract epoch numbers
- checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
- # Sort files by epoch number in descending order
- checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
- # Get the last 'last_n' checkpoint paths
- checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
+ try:
+ if not use_deepspeed:
+ checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
+ else:
+ checkpoint = torch.load(
+ os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"),
+ map_location="cpu",
+ )
+ avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
+ val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
+ sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
+ sorted_items = (
+ sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
+ )
+ checkpoint_paths = []
+ for key, value in sorted_items[:last_n]:
+ if not use_deepspeed:
+ ckpt = os.path.join(output_dir, key)
+ else:
+ ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
+ checkpoint_paths.append(ckpt)
+
+ except:
+ print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
+ # List all files in the output directory
+ files = os.listdir(output_dir)
+ # Filter out checkpoint files and extract epoch numbers
+ checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
+ # Sort files by epoch number in descending order
+ checkpoint_files.sort(key=lambda x: int(re.search(r"(\d+)", x).group()), reverse=True)
+ # Get the last 'last_n' checkpoint paths
+ checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
return checkpoint_paths
+
@torch.no_grad()
-def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]):
+def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
"""
Average the last 'last_n' checkpoints' model state_dicts.
If a tensor is of type torch.int, perform sum instead of average.
"""
- checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
+ checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
+ print(f"average_checkpoints: {checkpoint_paths}")
state_dicts = []
# Load state_dicts from checkpoints
for path in checkpoint_paths:
if os.path.isfile(path):
- state_dicts.append(torch.load(path, map_location='cpu')['state_dict'])
+ state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
else:
print(f"Checkpoint file {path} not found.")
- continue
# Check if we have any state_dicts to average
- if not state_dicts:
- raise RuntimeError("No checkpoints found for averaging.")
+ if len(state_dicts) < 1:
+ print("No checkpoints found for averaging.")
+ return
# Average or sum weights
avg_state_dict = OrderedDict()
@@ -176,6 +91,6 @@
# Perform average for other types of tensors
stacked_tensors = torch.stack(tensors)
avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
-
- torch.save({'state_dict': avg_state_dict}, os.path.join(output_dir, f"model.pt.avg{last_n}"))
- return avg_state_dict
\ No newline at end of file
+ checkpoint_outpath = os.path.join(output_dir, f"model.pt.avg{last_n}")
+ torch.save({"state_dict": avg_state_dict}, checkpoint_outpath)
+ return checkpoint_outpath
--
Gitblit v1.9.1