From 9a9c3b75b5b3359701844a91a9fae6d2979866cd Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 17 Jan 2024 18:28:28 +0800
Subject: [PATCH] Funasr1.0 (#1261)
---
funasr/train_utils/average_nbest_models.py | 268 ++++++++++++++++++++++++++++++++---------------------
 1 file changed, 162 insertions(+), 106 deletions(-)
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 96e1384..f117804 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -9,117 +9,173 @@
import torch
from typing import Collection
+import os
+import torch
+import re
+from collections import OrderedDict
+from functools import cmp_to_key
-from funasr.train.reporter import Reporter
+# @torch.no_grad()
+# def average_nbest_models(
+# output_dir: Path,
+# best_model_criterion: Sequence[Sequence[str]],
+# nbest: Union[Collection[int], int],
+# suffix: Optional[str] = None,
+# oss_bucket=None,
+# pai_output_dir=None,
+# ) -> None:
+# """Generate averaged model from n-best models
+#
+# Args:
+# output_dir: The directory contains the model file for each epoch
+# reporter: Reporter instance
+# best_model_criterion: Give criterions to decide the best model.
+# e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
+# nbest: Number of best model files to be averaged
+# suffix: A suffix added to the averaged model file name
+# """
+# if isinstance(nbest, int):
+# nbests = [nbest]
+# else:
+# nbests = list(nbest)
+# if len(nbests) == 0:
+# warnings.warn("At least 1 nbest values are required")
+# nbests = [1]
+# if suffix is not None:
+# suffix = suffix + "."
+# else:
+# suffix = ""
+#
+# # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
+# nbest_epochs = [
+# (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
+# for ph, k, m in best_model_criterion
+# if reporter.has(ph, k)
+# ]
+#
+# _loaded = {}
+# for ph, cr, epoch_and_values in nbest_epochs:
+# _nbests = [i for i in nbests if i <= len(epoch_and_values)]
+# if len(_nbests) == 0:
+# _nbests = [1]
+#
+# for n in _nbests:
+# if n == 0:
+# continue
+# elif n == 1:
+# # The averaged model is same as the best model
+# e, _ = epoch_and_values[0]
+# op = output_dir / f"{e}epoch.pb"
+# sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
+# if sym_op.is_symlink() or sym_op.exists():
+# sym_op.unlink()
+# sym_op.symlink_to(op.name)
+# else:
+# op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
+# logging.info(
+# f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
+# )
+#
+# avg = None
+# # 2.a. Averaging model
+# for e, _ in epoch_and_values[:n]:
+# if e not in _loaded:
+# if oss_bucket is None:
+# _loaded[e] = torch.load(
+# output_dir / f"{e}epoch.pb",
+# map_location="cpu",
+# )
+# else:
+# buffer = BytesIO(
+# oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
+# _loaded[e] = torch.load(buffer)
+# states = _loaded[e]
+#
+# if avg is None:
+# avg = states
+# else:
+# # Accumulated
+# for k in avg:
+# avg[k] = avg[k] + states[k]
+# for k in avg:
+# if str(avg[k].dtype).startswith("torch.int"):
+# # For int type, not averaged, but only accumulated.
+# # e.g. BatchNorm.num_batches_tracked
+# # (If there are any cases that requires averaging
+# # or the other reducing method, e.g. max/min, for integer type,
+# # please report.)
+# pass
+# else:
+# avg[k] = avg[k] / n
+#
+# # 2.b. Save the ave model and create a symlink
+# if oss_bucket is None:
+# torch.save(avg, op)
+# else:
+# buffer = BytesIO()
+# torch.save(avg, buffer)
+# oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
+# buffer.getvalue())
+#
+# # 3. *.*.ave.pb is a symlink to the max ave model
+# if oss_bucket is None:
+# op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
+# sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
+# if sym_op.is_symlink() or sym_op.exists():
+# sym_op.unlink()
+# sym_op.symlink_to(op.name)
+
+
+def _get_checkpoint_paths(output_dir: str, last_n: int=5):
+ """
+ Get the paths of the last 'last_n' checkpoints by parsing filenames
+ in the output directory.
+ """
+ # List all files in the output directory
+ files = os.listdir(output_dir)
+ # Filter out checkpoint files and extract epoch numbers
+ checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
+ # Sort files by epoch number in descending order
+ checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
+ # Get the last 'last_n' checkpoint paths
+ checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
+ return checkpoint_paths
@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int = 5):
    """Average the state_dicts of the newest ``last_n`` checkpoints.

    Floating-point tensors are arithmetically averaged; integer tensors
    (e.g. ``BatchNorm.num_batches_tracked``) are summed instead, because
    averaging a counter is meaningless.

    The result is written to ``<output_dir>/model.pt.avg<last_n>`` wrapped
    as ``{'state_dict': ...}`` — the same layout the input checkpoints use.

    Args:
        output_dir: Directory containing ``model.pt.e<epoch>`` checkpoints.
        last_n: Number of most recent checkpoints to average.

    Returns:
        The averaged state_dict as an ``OrderedDict``.

    Raises:
        RuntimeError: If no checkpoint file could be loaded.
    """
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)

    state_dicts = []
    for path in checkpoint_paths:
        if os.path.isfile(path):
            # Checkpoints are saved as {'state_dict': ...}; load onto CPU so
            # averaging never requires a GPU.
            state_dicts.append(torch.load(path, map_location='cpu')['state_dict'])
        else:
            # Best-effort: a file may have been pruned between listing and
            # loading; skip it rather than abort the whole average.
            print(f"Checkpoint file {path} not found.")

    if not state_dicts:
        raise RuntimeError("No checkpoints found for averaging.")

    n_loaded = len(state_dicts)
    avg_state_dict = OrderedDict()
    # NOTE(review): assumes every loaded checkpoint shares the same keys —
    # a mismatched key raises KeyError, same as the previous implementation.
    for key in state_dicts[0].keys():
        # Accumulate in place instead of torch.stack(...) + mean(): stacking
        # holds an extra full copy of every tensor in memory at once.
        acc = state_dicts[0][key].cpu().clone()
        for other in state_dicts[1:]:
            acc = acc + other[key].cpu()
        if str(acc.dtype).startswith("torch.int"):
            # Integer tensors (counters) are summed, not averaged.
            avg_state_dict[key] = acc
        else:
            avg_state_dict[key] = acc / n_loaded

    torch.save({'state_dict': avg_state_dict}, os.path.join(output_dir, f"model.pt.avg{last_n}"))
    return avg_state_dict
--
Gitblit v1.9.1