| | |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis |
| | | https://arxiv.org/abs/2211.10243 |
| | | TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization |
| | | https://arxiv.org/abs/2303.05397 |
| | | """ |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | import numpy as np |
| | | import torch |
| | | import yaml |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.preprocessor import CommonPreprocessor |
| | |
| | | from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.torch_utils.initialize import initialize |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.train.class_choices import ClassChoices |
| | | from funasr.train.trainer import Trainer |
| | | from funasr.utils.types import float_or_none |
| | |
| | | sond=DiarSondModel, |
| | | eend_ola=DiarEENDOLAModel, |
| | | ), |
| | | type_check=AbsESPnetModel, |
| | | type_check=FunASRModel, |
| | | default="sond", |
| | | ) |
| | | encoder_choices = ClassChoices( |
| | |
| | | [Collection[Tuple[str, Dict[str, np.ndarray]]]], |
| | | Tuple[List[str], Dict[str, torch.Tensor]], |
| | | ]: |
| | | assert check_argument_types() |
| | | # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol |
| | | return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) |
| | | |
| | |
| | | def build_preprocess_fn( |
| | | cls, args: argparse.Namespace, train: bool |
| | | ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: |
| | | assert check_argument_types() |
| | | if args.use_preprocessor: |
| | | retval = CommonPreprocessor( |
| | | train=train, |
| | |
| | | ) |
| | | else: |
| | | retval = None |
| | | assert check_return_type(retval) |
| | | return retval |
| | | |
| | | @classmethod |
| | |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | retval = () |
| | | assert check_return_type(retval) |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | | token_list = [line.rstrip() for line in f] |
| | |
| | | if args.init is not None: |
| | | initialize(model, args.init) |
| | | |
| | | assert check_return_type(model) |
| | | return model |
| | | |
| | | # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ |
| | |
| | | config_file: Union[Path, str] = None, |
| | | model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | device: str = "cpu", |
| | | device: Union[str, torch.device] = "cpu", |
| | | ): |
| | | """Build model from the files. |
| | | |
| | |
| | | device: Device type, "cpu", "cuda", or "cuda:N". |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | if config_file is None: |
| | | assert model_file is not None, ( |
| | | "The argument 'model_file' must be provided " |
| | |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | if not isinstance(model, FunASRModel): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | model.to(device) |
| | | model_dict = dict() |
| | |
| | | model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | model_dict = cls.fileter_model_dict(model_dict, model.state_dict()) |
| | | model.load_state_dict(model_dict) |
| | | if model_name_pth is not None and not os.path.exists(model_name_pth): |
| | | torch.save(model_dict, model_name_pth) |
| | | logging.info("model_file is saved to pth: {}".format(model_name_pth)) |
| | | |
| | | return model, args |
| | | |
| | | @classmethod |
| | | def fileter_model_dict(cls, src_dict: dict, dest_dict: dict): |
| | | from collections import OrderedDict |
| | | new_dict = OrderedDict() |
| | | for key, value in src_dict.items(): |
| | | if key in dest_dict: |
| | | new_dict[key] = value |
| | | else: |
| | | logging.info("{} is no longer needed in this model.".format(key)) |
| | | for key, value in dest_dict.items(): |
| | | if key not in new_dict: |
| | | logging.warning("{} is missed in checkpoint.".format(key)) |
| | | return new_dict |
| | | |
| | | @classmethod |
| | | def convert_tf2torch( |
| | |
| | | [Collection[Tuple[str, Dict[str, np.ndarray]]]], |
| | | Tuple[List[str], Dict[str, torch.Tensor]], |
| | | ]: |
| | | assert check_argument_types() |
| | | # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol |
| | | return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) |
| | | |
| | |
| | | def build_preprocess_fn( |
| | | cls, args: argparse.Namespace, train: bool |
| | | ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: |
| | | assert check_argument_types() |
| | | # if args.use_preprocessor: |
| | | # retval = CommonPreprocessor( |
| | | # train=train, |
| | |
| | | # ) |
| | | # else: |
| | | # retval = None |
| | | # assert check_return_type(retval) |
| | | return None |
| | | |
| | | @classmethod |
| | |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | retval = () |
| | | assert check_return_type(retval) |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | |
| | | # 1. frontend |
| | | if args.input_size is None or args.frontend == "wav_frontend_mel23": |
| | |
| | | if args.init is not None: |
| | | initialize(model, args.init) |
| | | |
| | | assert check_return_type(model) |
| | | return model |
| | | |
| | | # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ |
| | |
| | | device: Device type, "cpu", "cuda", or "cuda:N". |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | if config_file is None: |
| | | assert model_file is not None, ( |
| | | "The argument 'model_file' must be provided " |
| | |
| | | args = yaml.safe_load(f) |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | if not isinstance(model, FunASRModel): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | if model_file is not None: |
| | | if device == "cuda": |