speech_asr
2023-03-16 2ba4683eb2ce42eec91250debe88b424cbc2d67f
funasr/main_funcs/average_nbest_models.py
@@ -66,13 +66,13 @@
            elif n == 1:
                # The averaged model is same as the best model
                e, _ = epoch_and_values[0]
                op = output_dir / f"{e}epoch.pth"
                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pth"
                op = output_dir / f"{e}epoch.pb"
                sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb"
                if sym_op.is_symlink() or sym_op.exists():
                    sym_op.unlink()
                sym_op.symlink_to(op.name)
            else:
                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pth"
                op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb"
                logging.info(
                    f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
                )
@@ -83,12 +83,12 @@
                    if e not in _loaded:
                        if oss_bucket is None:
                            _loaded[e] = torch.load(
                                output_dir / f"{e}epoch.pth",
                                output_dir / f"{e}epoch.pb",
                                map_location="cpu",
                            )
                        else:
                            buffer = BytesIO(
                                oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pth")).read())
                                oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read())
                            _loaded[e] = torch.load(buffer)
                    states = _loaded[e]
@@ -115,13 +115,13 @@
                else:
                    buffer = BytesIO()
                    torch.save(avg, buffer)
                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pth"),
                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
                                          buffer.getvalue())
        # 3. *.*.ave.pth is a symlink to the max ave model
        # 3. *.*.ave.pb is a symlink to the max ave model
        if oss_bucket is None:
            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pth"
            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pth"
            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
            if sym_op.is_symlink() or sym_op.exists():
                sym_op.unlink()
            sym_op.symlink_to(op.name)