| | |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis |
| | | https://arxiv.org/abs/2211.10243 |
| | | TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization |
| | | https://arxiv.org/abs/2303.05397 |
| | | """ |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | |
| | | import numpy as np |
| | | import torch |
| | | import yaml |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.collate_fn import DiarCollateFn |
| | | from funasr.datasets.preprocessor import CommonPreprocessor |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.layers.global_mvn import GlobalMVN |
| | | from funasr.layers.label_aggregation import LabelAggregate |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.models.e2e_diar_sond import DiarSondModel |
| | | from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | | from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar |
| | | from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN |
| | | from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer |
| | | from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder |
| | | from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder |
| | | from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar |
| | | from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer |
| | | from funasr.models.e2e_diar_sond import DiarSondModel |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | | from funasr.models.encoder.rnn_encoder import RNNEncoder |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt |
| | | from funasr.models.encoder.transformer_encoder import TransformerEncoder |
| | |
| | | from funasr.models.frontend.fused import FusedFrontends |
| | | from funasr.models.frontend.s3prl import S3prlFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontendMel23 |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | | from funasr.models.postencoder.hugging_face_transformers_postencoder import ( |
| | | HuggingFaceTransformersPostEncoder, # noqa: H301 |
| | | ) |
| | | from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | from funasr.models.preencoder.linear import LinearProjection |
| | | from funasr.models.preencoder.sinc import LightweightSincConvs |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder |
| | | from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor |
| | | from funasr.models.specaug.abs_profileaug import AbsProfileAug |
| | | from funasr.models.specaug.profileaug import ProfileAug |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.torch_utils.initialize import initialize |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.train.class_choices import ClassChoices |
| | | from funasr.train.trainer import Trainer |
| | | from funasr.utils.types import float_or_none |
| | |
| | | s3prl=S3prlFrontend, |
| | | fused=FusedFrontends, |
| | | wav_frontend=WavFrontend, |
| | | wav_frontend_mel23=WavFrontendMel23, |
| | | ), |
| | | type_check=AbsFrontend, |
| | | default="default", |
| | |
| | | specaug_lfr=SpecAugLFR, |
| | | ), |
| | | type_check=AbsSpecAug, |
| | | default=None, |
| | | optional=True, |
| | | ) |
# Registry behind the optional --profileaug / --profileaug_conf options:
# data augmentation applied to speaker profiles. Nothing is selected by
# default, so profile augmentation stays off unless explicitly requested.
profileaug_choices = ClassChoices(
    name="profileaug",
    classes={"profileaug": ProfileAug},
    type_check=AbsProfileAug,
    optional=True,
    default=None,
)
| | |
| | | label_aggregator_choices = ClassChoices( |
| | | "label_aggregator", |
| | | classes=dict( |
| | | label_aggregator=LabelAggregate |
| | | label_aggregator=LabelAggregate, |
| | | label_aggregator_max_pool=LabelAggregateMaxPooling, |
| | | ), |
| | | type_check=torch.nn.Module, |
| | | default=None, |
| | |
| | | "model", |
| | | classes=dict( |
| | | sond=DiarSondModel, |
| | | eend_ola=DiarEENDOLAModel, |
| | | ), |
| | | type_check=FunASRModel, |
| | | type_check=torch.nn.Module, |
| | | default="sond", |
| | | ) |
| | | encoder_choices = ClassChoices( |
| | |
| | | sanm_chunk_opt=SANMEncoderChunkOpt, |
| | | data2vec_encoder=Data2VecEncoder, |
| | | ecapa_tdnn=ECAPA_TDNN, |
| | | eend_ola_transformer=EENDOLATransformerEncoder, |
| | | ), |
| | | type_check=torch.nn.Module, |
| | | default="resnet34", |
| | |
| | | type_check=torch.nn.Module, |
| | | default="fsmn", |
| | | ) |
# Registry behind --encoder_decoder_attractor / --encoder_decoder_attractor_conf,
# consumed by the EEND-OLA task. "eda" (EncoderDecoderAttractor) is both the
# only registered implementation and the default.
encoder_decoder_attractor_choices = ClassChoices(
    name="encoder_decoder_attractor",
    classes={"eda": EncoderDecoderAttractor},
    type_check=torch.nn.Module,
    default="eda",
)
| | | |
| | | |
| | | class DiarTask(AbsTask): |
| | |
| | | frontend_choices, |
| | | # --specaug and --specaug_conf |
| | | specaug_choices, |
| | | # --profileaug and --profileaug_conf |
| | | profileaug_choices, |
| | | # --normalize and --normalize_conf |
| | | normalize_choices, |
| | | # --label_aggregator and --label_aggregator_conf |
| | |
| | | [Collection[Tuple[str, Dict[str, np.ndarray]]]], |
| | | Tuple[List[str], Dict[str, torch.Tensor]], |
| | | ]: |
| | | assert check_argument_types() |
| | | # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol |
| | | return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) |
| | | return DiarCollateFn(float_pad_value=0.0, int_pad_value=-1) |
| | | |
| | | @classmethod |
| | | def build_preprocess_fn( |
| | | cls, args: argparse.Namespace, train: bool |
| | | ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: |
| | | assert check_argument_types() |
| | | if args.use_preprocessor: |
| | | retval = CommonPreprocessor( |
| | | train=train, |
| | |
| | | ) |
| | | else: |
| | | retval = None |
| | | assert check_return_type(retval) |
| | | return retval |
| | | |
| | | @classmethod |
| | |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | retval = () |
| | | assert check_return_type(retval) |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_optimizers( |
| | | cls, |
| | | args: argparse.Namespace, |
| | | model: torch.nn.Module, |
| | | ) -> List[torch.optim.Optimizer]: |
| | | if cls.num_optimizers != 1: |
| | | raise RuntimeError( |
| | | "build_optimizers() must be overridden if num_optimizers != 1" |
| | | ) |
| | | from funasr.tasks.abs_task import optim_classes |
| | | optim_class = optim_classes.get(args.optim) |
| | | if optim_class is None: |
| | | raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}") |
| | | else: |
| | | if (hasattr(model, "model_regularizer_weight") and |
| | | model.model_regularizer_weight > 0.0 and |
| | | hasattr(model, "get_regularize_parameters") |
| | | ): |
| | | to_regularize_parameters, normal_parameters = model.get_regularize_parameters() |
| | | logging.info(f"Set weight decay {model.model_regularizer_weight} for parameters: " |
| | | f"{[name for name, value in to_regularize_parameters]}") |
| | | module_optim_config = [ |
| | | {"params": [value for name, value in to_regularize_parameters], |
| | | "weight_decay": model.model_regularizer_weight}, |
| | | {"params": [value for name, value in normal_parameters], |
| | | "weight_decay": 0.0} |
| | | ] |
| | | optim = optim_class(module_optim_config, **args.optim_conf) |
| | | else: |
| | | optim = optim_class(model.parameters(), **args.optim_conf) |
| | | |
| | | optimizers = [optim] |
| | | return optimizers |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | | token_list = [line.rstrip() for line in f] |
| | |
| | | specaug = specaug_class(**args.specaug_conf) |
| | | else: |
| | | specaug = None |
| | | |
| | | # 2b. Data augmentation for Profiles |
| | | if hasattr(args, "profileaug") and args.profileaug is not None: |
| | | profileaug_class = profileaug_choices.get_class(args.profileaug) |
| | | profileaug = profileaug_class(**args.profileaug_conf) |
| | | else: |
| | | profileaug = None |
| | | |
| | | # 3. Normalization layer |
| | | if args.normalize is not None: |
| | |
| | | vocab_size=vocab_size, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | profileaug=profileaug, |
| | | normalize=normalize, |
| | | label_aggregator=label_aggregator, |
| | | encoder=encoder, |
| | |
| | | # 10. Initialize |
| | | if args.init is not None: |
| | | initialize(model, args.init) |
| | | logging.info(f"Init model parameters with {args.init}.") |
| | | |
| | | assert check_return_type(model) |
| | | return model |
| | | |
| | | # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ |
| | |
| | | device: Device type, "cpu", "cuda", or "cuda:N". |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | if config_file is None: |
| | | assert model_file is not None, ( |
| | | "The argument 'model_file' must be provided " |
| | |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, FunASRModel): |
| | | if not isinstance(model, torch.nn.Module): |
| | | raise RuntimeError( |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {torch.nn.Module.__name__}, but got {type(model)}" |
| | | ) |
| | | model.to(device) |
| | | model_dict = dict() |
| | |
| | | if ".bin" in model_name: |
| | | model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb')) |
| | | else: |
| | | model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name)) |
| | | model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name)) |
| | | if os.path.exists(model_name_pth): |
| | | logging.info("model_file is load from pth: {}".format(model_name_pth)) |
| | | model_dict = torch.load(model_name_pth, map_location=device) |
| | | else: |
| | | model_dict = cls.convert_tf2torch(model, model_file) |
| | | model.load_state_dict(model_dict) |
| | | # model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | model_dict = cls.fileter_model_dict(model_dict, model.state_dict()) |
| | |
| | | var_dict_torch_update.update(var_dict_torch_update_local) |
| | | |
| | | return var_dict_torch_update |
| | | |
| | | |
class EENDOLADiarTask(AbsTask):
    """Task definition for EEND-OLA speaker diarization.

    Registers the command-line options for the frontend, model, encoder and
    encoder-decoder attractor (EDA) choices. Builds the model from those
    choices for training, or from a saved config/checkpoint for inference.
    """

    # If you need more than 1 optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --encoder_decoder_attractor and --encoder_decoder_attractor_conf
        encoder_decoder_attractor_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register the task-specific command-line options on ``parser``.

        This adds:

        - a "Task related" group (token list, init method, input size, ...);
        - a "Preprocess related" group;
        - several augmentation/normalization options directly on ``parser``;
        - a ``--<name>`` / ``--<name>_conf`` pair for every entry in
          ``cls.class_choices_list``.
        """
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="whether to split text using <space>",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )

        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="char",
            choices=["char"],
            help="The text will be tokenized in the specified level token",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of rir scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability for applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of cmvn file.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability applying Noise adding.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return the batch collate function.

        Float tensors are padded with 0.0 and integer tensors with -1.
        """
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        """Return the preprocessing function for this task.

        EEND-OLA does not use a preprocessor, so this always returns
        ``None``, regardless of ``args.use_preprocessor``.
        """
        # NOTE: a CommonPreprocessor-based implementation used to live here as
        # commented-out code; preprocessing is intentionally disabled.
        return None

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data keys every sample must provide.

        Training and inference (recognition) both need only ``"speech"``.
        """
        return ("speech",)

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data keys a sample may provide; none for this task."""
        return ()

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        """Instantiate the EEND-OLA model described by ``args``.

        A frontend is built when ``args.input_size`` is None or
        ``args.frontend`` is ``"wav_frontend_mel23"``. For the
        ``"wav_frontend"`` case, ``args.cmvn_file`` is forwarded if present.
        Otherwise features are expected from the data loader, and
        ``args.frontend`` / ``args.frontend_conf`` are reset accordingly.
        The encoder and encoder-decoder attractor are then built from their
        registries and passed to the selected model class together with
        ``args.model_conf``. ``args.init`` applies if it is set.

        Returns:
            The constructed model instance.
        """
        # 1. Frontend
        if args.input_size is None or args.frontend == "wav_frontend_mel23":
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                # getattr: configs saved without a cmvn_file entry must still load.
                frontend = frontend_class(
                    cmvn_file=getattr(args, "cmvn_file", None), **args.frontend_conf
                )
            else:
                frontend = frontend_class(**args.frontend_conf)
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(**args.encoder_conf)

        # 3. EncoderDecoderAttractor
        eda_class = encoder_decoder_attractor_choices.get_class(
            args.encoder_decoder_attractor
        )
        encoder_decoder_attractor = eda_class(**args.encoder_decoder_attractor_conf)

        # 4. Build model
        model_class = model_choices.get_class(args.model)
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            encoder_decoder_attractor=encoder_decoder_attractor,
            **args.model_conf,
        )

        # 5. Initialize
        if args.init is not None:
            initialize(model, args.init)

        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training. Defaults to
                ``config.yaml`` next to ``model_file``.
            model_file: The model file saved when training. If given, its
                weights are loaded. Both a bare state dict and a dict with a
                ``"state_dict"`` entry are accepted.
            cmvn_file: The cmvn file for the front-end. When given, it
                overrides any ``cmvn_file`` stored in the config.
            device: Device type, "cpu", "cuda", or "cuda:N".

        Returns:
            A ``(model, args)`` tuple, where ``args`` is the config as an
            ``argparse.Namespace``.

        Raises:
            RuntimeError: If the built model is not a ``FunASRModel``.
        """
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)

        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        # Previously the cmvn_file argument was accepted but silently ignored.
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        if model_file is not None:
            if device == "cuda":
                # Pin a bare "cuda" request to the currently selected GPU index.
                device = f"cuda:{torch.cuda.current_device()}"
            checkpoint = torch.load(model_file, map_location=device)
            if "state_dict" in checkpoint:
                model.load_state_dict(checkpoint["state_dict"])
            else:
                model.load_state_dict(checkpoint)
        model.to(device)
        return model, args