| | |
| | | from io import BytesIO |
| | | |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | from typing import Collection |
| | | |
| | | from funasr.train.reporter import Reporter |
| | |
| | | nbest: Number of best model files to be averaged |
| | | suffix: A suffix added to the averaged model file name |
| | | """ |
| | | assert check_argument_types() |
| | | if isinstance(nbest, int): |
| | | nbests = [nbest] |
| | | else: |
| | |
| | | elif n == 1: |
| | | # The averaged model is the same as the best model |
| | | e, _ = epoch_and_values[0] |
| | | op = output_dir / f"{e}epoch.pth" |
| | | sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pth" |
| | | op = output_dir / f"{e}epoch.pb" |
| | | sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb" |
| | | if sym_op.is_symlink() or sym_op.exists(): |
| | | sym_op.unlink() |
| | | sym_op.symlink_to(op.name) |
| | | else: |
| | | op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pth" |
| | | op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb" |
| | | logging.info( |
| | | f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}' |
| | | ) |
| | |
| | | if e not in _loaded: |
| | | if oss_bucket is None: |
| | | _loaded[e] = torch.load( |
| | | output_dir / f"{e}epoch.pth", |
| | | output_dir / f"{e}epoch.pb", |
| | | map_location="cpu", |
| | | ) |
| | | else: |
| | | buffer = BytesIO( |
| | | oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pth")).read()) |
| | | oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read()) |
| | | _loaded[e] = torch.load(buffer) |
| | | states = _loaded[e] |
| | | |
| | |
| | | else: |
| | | buffer = BytesIO() |
| | | torch.save(avg, buffer) |
| | | oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pth"), |
| | | oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"), |
| | | buffer.getvalue()) |
| | | |
| | | # 3. *.*.ave.pth is a symlink to the max ave model |
| | | # 3. *.*.ave.pb is a symlink to the max ave model |
| | | if oss_bucket is None: |
| | | op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pth" |
| | | sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pth" |
| | | op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb" |
| | | sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb" |
| | | if sym_op.is_symlink() or sym_op.exists(): |
| | | sym_op.unlink() |
| | | sym_op.symlink_to(op.name) |