| | |
| | | |
import copy
import logging
import os
import re
import warnings
from collections import OrderedDict
from functools import cmp_to_key
from io import BytesIO
from pathlib import Path
from typing import Collection
from typing import Optional
from typing import Sequence
from typing import Union

import torch

from funasr.train.reporter import Reporter
| | | |
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
    """Return the paths of up to ``last_n`` checkpoints that should be averaged.

    Two strategies are tried in order:

    1. Load the bookkeeping stored in ``<output_dir>/model.pt`` (or, with
       ``use_deepspeed``, ``<output_dir>/model.pt/mp_rank_00_model_states.pt``).
       Its ``avg_keep_nbest_models_type`` entry names a metric and
       ``val_<metric>_step_or_epoch`` maps checkpoint names to metric values.
       Entries are sorted by value in descending order; for ``"acc"`` the
       first ``last_n`` are kept, for any other metric the last ``last_n``.
    2. If anything in step 1 fails, fall back to listing ``output_dir`` and
       taking the ``last_n`` files named ``model.pt.e*`` with the largest
       first number found in their file name.

    Args:
        output_dir: Directory that holds the checkpoints.
        last_n: Maximum number of checkpoint paths to return.
        use_deepspeed: When true, every checkpoint is treated as a directory
            containing ``mp_rank_00_model_states.pt``.
        **kwargs: Accepted and ignored so callers can forward extra options.

    Returns:
        A list of checkpoint paths (possibly shorter than ``last_n``).
    """
    deepspeed_file = "mp_rank_00_model_states.pt"
    if use_deepspeed:
        meta_path = os.path.join(output_dir, "model.pt", deepspeed_file)
    else:
        meta_path = os.path.join(output_dir, "model.pt")

    try:
        # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
        checkpoint = torch.load(meta_path, map_location="cpu")
        metric = checkpoint["avg_keep_nbest_models_type"]
        scores = checkpoint[f"val_{metric}_step_or_epoch"]
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        ranked = ranked[:last_n] if metric == "acc" else ranked[-last_n:]
        # The second slice matters for last_n == 0, where ranked[-0:] is the
        # whole list rather than an empty one.
        ranked = ranked[:last_n]
        if use_deepspeed:
            return [os.path.join(output_dir, name, deepspeed_file) for name, _ in ranked]
        return [os.path.join(output_dir, name) for name, _ in ranked]
    except Exception as err:
        # Best effort: any problem with the bookkeeping (missing file, missing
        # keys, unreadable pickle, ...) switches to the filename-based fallback.
        # The message reports the path, not the loaded object, because the
        # object does not exist when loading itself failed.
        print(
            f"{meta_path} could not be used ({err!r}), "
            f"averaging the latest checkpoints instead."
        )

    numbered = []
    for name in os.listdir(output_dir):
        if not name.startswith("model.pt.e"):
            continue
        match = re.search(r"\d+", name)
        if match is None:
            # No number to order by; skip instead of crashing on `.group()`.
            continue
        numbered.append((int(match.group()), name))
    # Largest number first; ties keep the directory-listing order.
    numbered.sort(key=lambda pair: pair[0], reverse=True)
    return [os.path.join(output_dir, name) for _, name in numbered[:last_n]]
| | | |
| | | |
@torch.no_grad()
def average_nbest_models(
    output_dir: Path,
    reporter: "Reporter",
    best_model_criterion: Sequence[Sequence[str]],
    nbest: Union[Collection[int], int],
    suffix: Optional[str] = None,
    oss_bucket=None,
    pai_output_dir=None,
) -> None:
    """Generate averaged model from n-best models

    For every criterion known to ``reporter`` and every requested ``n``, the
    ``n`` best ``<epoch>epoch.pb`` files are summed key by key; non-integer
    tensors are then divided by ``n`` while integer tensors stay accumulated.
    The result is stored as ``<phase>.<criterion>.ave_<n>best.<suffix>pb``.

    Args:
        output_dir: The directory contains the model file for each epoch
        reporter: Reporter instance; its ``has`` and ``sort_epochs_and_values``
            methods are used.
        best_model_criterion: Give criterions to decide the best model.
            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
        nbest: Number of best model files to be averaged
        suffix: A suffix added to the averaged model file name
        oss_bucket: Optional object exposing ``get_object`` / ``put_object``.
            When given, epoch models are read from it and averaged models are
            written to it instead of ``output_dir``.
        pai_output_dir: Key prefix used with ``oss_bucket``.
    """
    nbests = [nbest] if isinstance(nbest, int) else list(nbest)
    if len(nbests) == 0:
        warnings.warn("At least 1 nbest values are required")
        nbests = [1]
    suffix = suffix + "." if suffix is not None else ""

    # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]]
    nbest_epochs = [
        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)])
        for ph, k, m in best_model_criterion
        if reporter.has(ph, k)
    ]

    # Cache of loaded epoch models, shared across criteria and values of n.
    _loaded = {}
    for ph, cr, epoch_and_values in nbest_epochs:
        _nbests = [i for i in nbests if i <= len(epoch_and_values)]
        if len(_nbests) == 0:
            _nbests = [1]

        for n in _nbests:
            if n == 0:
                continue
            elif n == 1:
                # The averaged model is same as the best model
                e, _ = epoch_and_values[0]
                op = output_dir / f"{e}epoch.pb"
                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
                if sym_op.is_symlink() or sym_op.exists():
                    sym_op.unlink()
                sym_op.symlink_to(op.name)
            else:
                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
                logging.info(
                    f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
                )

                avg = None
                # 2.a. Averaging model
                for e, _ in epoch_and_values[:n]:
                    if e not in _loaded:
                        if oss_bucket is None:
                            _loaded[e] = torch.load(
                                output_dir / f"{e}epoch.pb",
                                map_location="cpu",
                            )
                        else:
                            buffer = BytesIO(
                                oss_bucket.get_object(
                                    os.path.join(pai_output_dir, f"{e}epoch.pb")
                                ).read()
                            )
                            # Same device handling as the local branch.
                            _loaded[e] = torch.load(buffer, map_location="cpu")
                    states = _loaded[e]

                    if avg is None:
                        # Shallow copy: accumulating into the cached dict
                        # itself would corrupt `_loaded[e]` for the next n
                        # or criterion that reuses this epoch.
                        avg = copy.copy(states)
                    else:
                        # Accumulated
                        for k in avg:
                            avg[k] = avg[k] + states[k]
                for k in avg:
                    if str(avg[k].dtype).startswith("torch.int"):
                        # For int type, not averaged, but only accumulated.
                        # e.g. BatchNorm.num_batches_tracked
                        # (If there are any cases that requires averaging
                        #  or the other reducing method, e.g. max/min, for integer type,
                        #  please report.)
                        pass
                    else:
                        avg[k] = avg[k] / n

                # 2.b. Save the ave model and create a symlink
                if oss_bucket is None:
                    torch.save(avg, op)
                else:
                    buffer = BytesIO()
                    torch.save(avg, buffer)
                    oss_bucket.put_object(
                        os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
                        buffer.getvalue(),
                    )

        # 3. *.*.ave.pb is a symlink to the max ave model
        if oss_bucket is None:
            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
            if sym_op.is_symlink() or sym_op.exists():
                sym_op.unlink()
            sym_op.symlink_to(op.name)


@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.

    Args:
        output_dir: Directory holding the checkpoints; the result is written
            there as ``model.pt.avg<last_n>``.
        last_n: Number of checkpoints requested from ``_get_checkpoint_paths``.
        **kwargs: Forwarded to ``_get_checkpoint_paths``.

    Returns:
        The path of the saved averaged checkpoint, or ``None`` when none of
        the selected checkpoint files exists.
    """
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
    print(f"average_checkpoints: {checkpoint_paths}")

    # Load state_dicts from checkpoints, skipping paths that are not files.
    state_dicts = []
    for path in checkpoint_paths:
        if os.path.isfile(path):
            state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
        else:
            print(f"Checkpoint file {path} not found.")

    # Check if we have any state_dicts to average
    if not state_dicts:
        print("No checkpoints found for averaging.")
        return None

    # Average or sum weights; the key set of the first checkpoint is used.
    avg_state_dict = OrderedDict()
    for key in state_dicts[0]:
        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
        if str(tensors[0].dtype).startswith("torch.int"):
            # Perform sum for integer tensors
            avg_state_dict[key] = sum(tensors)
        else:
            # Perform average for other types of tensors
            avg_state_dict[key] = torch.mean(torch.stack(tensors), dim=0)
    checkpoint_outpath = os.path.join(output_dir, f"model.pt.avg{last_n}")
    torch.save({"state_dict": avg_state_dict}, checkpoint_outpath)
    return checkpoint_outpath