From 4bc6db3ef88795eb570f92f9576f8bc7c56f96bc Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: Tue, 1 Aug 2023 17:03:39 +0800
Subject: [PATCH] TOLD: add TOLD/SOND recipe on callhome

---
 funasr/tasks/diar.py |  414 ++++++++++++----------------------------------------------
 1 file changed, 90 insertions(+), 324 deletions(-)

diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index a486a46..2d10435 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -1,11 +1,3 @@
-"""
-Author: Speech Lab, Alibaba Group, China
-SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
-https://arxiv.org/abs/2211.10243
-TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
-https://arxiv.org/abs/2303.05397
-"""
-
 import argparse
 import logging
 import os
@@ -21,24 +13,26 @@
 import numpy as np
 import torch
 import yaml
+from typeguard import check_argument_types
+from typeguard import check_return_type
 
-from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.collate_fn import DiarCollateFn
 from funasr.datasets.preprocessor import CommonPreprocessor
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.label_aggregation import LabelAggregate
 from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.e2e_diar_sond import DiarSondModel
-from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
+from funasr.models.ctc import CTC
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
 from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
-from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
 from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
 from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
 from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
+from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
+from funasr.models.e2e_diar_sond import DiarSondModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
 from funasr.models.encoder.rnn_encoder import RNNEncoder
 from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -47,16 +41,21 @@
 from funasr.models.frontend.fused import FusedFrontends
 from funasr.models.frontend.s3prl import S3prlFrontend
 from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.frontend.wav_frontend import WavFrontendMel23
 from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+    HuggingFaceTransformersPostEncoder,  # noqa: H301
+)
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.specaug.specaug import SpecAug
 from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
-from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.models.specaug.abs_profileaug import AbsProfileAug
+from funasr.models.specaug.profileaug import ProfileAug
 from funasr.tasks.abs_task import AbsTask
 from funasr.torch_utils.initialize import initialize
-from funasr.models.base_model import FunASRModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
 from funasr.utils.types import float_or_none
@@ -72,7 +71,6 @@
         s3prl=S3prlFrontend,
         fused=FusedFrontends,
         wav_frontend=WavFrontend,
-        wav_frontend_mel23=WavFrontendMel23,
     ),
     type_check=AbsFrontend,
     default="default",
@@ -84,6 +82,15 @@
         specaug_lfr=SpecAugLFR,
     ),
     type_check=AbsSpecAug,
+    default=None,
+    optional=True,
+)
+profileaug_choices = ClassChoices(
+    name="profileaug",
+    classes=dict(
+        profileaug=ProfileAug,
+    ),
+    type_check=AbsProfileAug,
     default=None,
     optional=True,
 )
@@ -100,7 +107,8 @@
 label_aggregator_choices = ClassChoices(
     "label_aggregator",
     classes=dict(
-        label_aggregator=LabelAggregate
+        label_aggregator=LabelAggregate,
+        label_aggregator_max_pool=LabelAggregateMaxPooling,
     ),
     type_check=torch.nn.Module,
     default=None,
@@ -110,9 +118,8 @@
     "model",
     classes=dict(
         sond=DiarSondModel,
-        eend_ola=DiarEENDOLAModel,
     ),
-    type_check=FunASRModel,
+    type_check=torch.nn.Module,
     default="sond",
 )
 encoder_choices = ClassChoices(
@@ -130,7 +137,6 @@
         sanm_chunk_opt=SANMEncoderChunkOpt,
         data2vec_encoder=Data2VecEncoder,
         ecapa_tdnn=ECAPA_TDNN,
-        eend_ola_transformer=EENDOLATransformerEncoder,
     ),
     type_check=torch.nn.Module,
     default="resnet34",
@@ -182,15 +188,6 @@
     type_check=torch.nn.Module,
     default="fsmn",
 )
-# encoder_decoder_attractor is used for EEND-OLA
-encoder_decoder_attractor_choices = ClassChoices(
-    "encoder_decoder_attractor",
-    classes=dict(
-        eda=EncoderDecoderAttractor,
-    ),
-    type_check=torch.nn.Module,
-    default="eda",
-)
 
 
 class DiarTask(AbsTask):
@@ -203,6 +200,8 @@
         frontend_choices,
         # --specaug and --specaug_conf
         specaug_choices,
+        # --profileaug and --profileaug_conf
+        profileaug_choices,
         # --normalize and --normalize_conf
         normalize_choices,
         # --label_aggregator and --label_aggregator_conf
@@ -342,13 +341,15 @@
         [Collection[Tuple[str, Dict[str, np.ndarray]]]],
         Tuple[List[str], Dict[str, torch.Tensor]],
     ]:
+        assert check_argument_types()
         # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
-        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+        return DiarCollateFn(float_pad_value=0.0, int_pad_value=-1)
 
     @classmethod
     def build_preprocess_fn(
             cls, args: argparse.Namespace, train: bool
     ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+        assert check_argument_types()
         if args.use_preprocessor:
             retval = CommonPreprocessor(
                 train=train,
@@ -378,6 +379,7 @@
             )
         else:
             retval = None
+        assert check_return_type(retval)
         return retval
 
     @classmethod
@@ -396,10 +398,47 @@
             cls, train: bool = True, inference: bool = False
     ) -> Tuple[str, ...]:
         retval = ()
+        assert check_return_type(retval)
         return retval
 
     @classmethod
+    def build_optimizers(
+            cls,
+            args: argparse.Namespace,
+            model: torch.nn.Module,
+    ) -> List[torch.optim.Optimizer]:
+        if cls.num_optimizers != 1:
+            raise RuntimeError(
+                "build_optimizers() must be overridden if num_optimizers != 1"
+            )
+        from funasr.tasks.abs_task import optim_classes
+        optim_class = optim_classes.get(args.optim)
+        if optim_class is None:
+            raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
+        else:
+            if (hasattr(model, "model_regularizer_weight") and
+                model.model_regularizer_weight > 0.0 and
+                hasattr(model, "get_regularize_parameters")
+            ):
+                to_regularize_parameters, normal_parameters = model.get_regularize_parameters()
+                logging.info(f"Set weight decay {model.model_regularizer_weight} for parameters: "
+                             f"{[name for name, value in to_regularize_parameters]}")
+                module_optim_config = [
+                    {"params": [value for name, value in to_regularize_parameters],
+                     "weight_decay": model.model_regularizer_weight},
+                    {"params": [value for name, value in normal_parameters],
+                     "weight_decay": 0.0}
+                ]
+                optim = optim_class(module_optim_config, **args.optim_conf)
+            else:
+                optim = optim_class(model.parameters(), **args.optim_conf)
+
+        optimizers = [optim]
+        return optimizers
+
+    @classmethod
     def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
         if isinstance(args.token_list, str):
             with open(args.token_list, encoding="utf-8") as f:
                 token_list = [line.rstrip() for line in f]
@@ -435,6 +474,13 @@
             specaug = specaug_class(**args.specaug_conf)
         else:
             specaug = None
+
+        # 2b. Data augmentation for Profiles
+        if hasattr(args, "profileaug") and args.profileaug is not None:
+            profileaug_class = profileaug_choices.get_class(args.profileaug)
+            profileaug = profileaug_class(**args.profileaug_conf)
+        else:
+            profileaug = None
 
         # 3. Normalization layer
         if args.normalize is not None:
@@ -483,6 +529,7 @@
             vocab_size=vocab_size,
             frontend=frontend,
             specaug=specaug,
+            profileaug=profileaug,
             normalize=normalize,
             label_aggregator=label_aggregator,
             encoder=encoder,
@@ -497,7 +544,9 @@
         # 10. Initialize
         if args.init is not None:
             initialize(model, args.init)
+            logging.info(f"Init model parameters with {args.init}.")
 
+        assert check_return_type(model)
         return model
 
     # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@@ -520,6 +569,7 @@
             device: Device type, "cpu", "cuda", or "cuda:N".
 
         """
+        assert check_argument_types()
         if config_file is None:
             assert model_file is not None, (
                 "The argument 'model_file' must be provided "
@@ -535,9 +585,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, FunASRModel):
+        if not isinstance(model, torch.nn.Module):
             raise RuntimeError(
-                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
+                f"model must inherit {torch.nn.Module.__name__}, but got {type(model)}"
             )
         model.to(device)
         model_dict = dict()
@@ -552,13 +602,13 @@
                 if ".bin" in model_name:
                     model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
                 else:
-                    model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
+                    model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
                 if os.path.exists(model_name_pth):
                     logging.info("model_file is load from pth: {}".format(model_name_pth))
                     model_dict = torch.load(model_name_pth, map_location=device)
                 else:
                     model_dict = cls.convert_tf2torch(model, model_file)
-                model.load_state_dict(model_dict)
+                # model.load_state_dict(model_dict)
             else:
                 model_dict = torch.load(model_file, map_location=device)
         model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
@@ -616,287 +666,3 @@
             var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update
-
-
-class EENDOLADiarTask(AbsTask):
-    # If you need more than 1 optimizer, change this value
-    num_optimizers: int = 1
-
-    # Add variable objects configurations
-    class_choices_list = [
-        # --frontend and --frontend_conf
-        frontend_choices,
-        # --specaug and --specaug_conf
-        model_choices,
-        # --encoder and --encoder_conf
-        encoder_choices,
-        # --speaker_encoder and --speaker_encoder_conf
-        encoder_decoder_attractor_choices,
-    ]
-
-    # If you need to modify train() or eval() procedures, change Trainer class here
-    trainer = Trainer
-
-    @classmethod
-    def add_task_arguments(cls, parser: argparse.ArgumentParser):
-        group = parser.add_argument_group(description="Task related")
-
-        # NOTE(kamo): add_arguments(..., required=True) can't be used
-        # to provide --print_config mode. Instead of it, do as
-        # required = parser.get_default("required")
-        # required += ["token_list"]
-
-        group.add_argument(
-            "--token_list",
-            type=str_or_none,
-            default=None,
-            help="A text mapping int-id to token",
-        )
-        group.add_argument(
-            "--split_with_space",
-            type=str2bool,
-            default=True,
-            help="whether to split text using <space>",
-        )
-        group.add_argument(
-            "--seg_dict_file",
-            type=str,
-            default=None,
-            help="seg_dict_file for text processing",
-        )
-        group.add_argument(
-            "--init",
-            type=lambda x: str_or_none(x.lower()),
-            default=None,
-            help="The initialization method",
-            choices=[
-                "chainer",
-                "xavier_uniform",
-                "xavier_normal",
-                "kaiming_uniform",
-                "kaiming_normal",
-                None,
-            ],
-        )
-
-        group.add_argument(
-            "--input_size",
-            type=int_or_none,
-            default=None,
-            help="The number of input dimension of the feature",
-        )
-
-        group = parser.add_argument_group(description="Preprocess related")
-        group.add_argument(
-            "--use_preprocessor",
-            type=str2bool,
-            default=True,
-            help="Apply preprocessing to data or not",
-        )
-        group.add_argument(
-            "--token_type",
-            type=str,
-            default="char",
-            choices=["char"],
-            help="The text will be tokenized in the specified level token",
-        )
-        parser.add_argument(
-            "--speech_volume_normalize",
-            type=float_or_none,
-            default=None,
-            help="Scale the maximum amplitude to the given value.",
-        )
-        parser.add_argument(
-            "--rir_scp",
-            type=str_or_none,
-            default=None,
-            help="The file path of rir scp file.",
-        )
-        parser.add_argument(
-            "--rir_apply_prob",
-            type=float,
-            default=1.0,
-            help="THe probability for applying RIR convolution.",
-        )
-        parser.add_argument(
-            "--cmvn_file",
-            type=str_or_none,
-            default=None,
-            help="The file path of noise scp file.",
-        )
-        parser.add_argument(
-            "--noise_scp",
-            type=str_or_none,
-            default=None,
-            help="The file path of noise scp file.",
-        )
-        parser.add_argument(
-            "--noise_apply_prob",
-            type=float,
-            default=1.0,
-            help="The probability applying Noise adding.",
-        )
-        parser.add_argument(
-            "--noise_db_range",
-            type=str,
-            default="13_15",
-            help="The range of noise decibel level.",
-        )
-
-        for class_choices in cls.class_choices_list:
-            # Append --<name> and --<name>_conf.
-            # e.g. --encoder and --encoder_conf
-            class_choices.add_arguments(group)
-
-    @classmethod
-    def build_collate_fn(
-            cls, args: argparse.Namespace, train: bool
-    ) -> Callable[
-        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
-        Tuple[List[str], Dict[str, torch.Tensor]],
-    ]:
-        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
-        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
-
-    @classmethod
-    def build_preprocess_fn(
-            cls, args: argparse.Namespace, train: bool
-    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
-        # if args.use_preprocessor:
-        #     retval = CommonPreprocessor(
-        #         train=train,
-        #         token_type=args.token_type,
-        #         token_list=args.token_list,
-        #         bpemodel=None,
-        #         non_linguistic_symbols=None,
-        #         text_cleaner=None,
-        #         g2p_type=None,
-        #         split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
-        #         seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
-        #         # NOTE(kamo): Check attribute existence for backward compatibility
-        #         rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
-        #         rir_apply_prob=args.rir_apply_prob
-        #         if hasattr(args, "rir_apply_prob")
-        #         else 1.0,
-        #         noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
-        #         noise_apply_prob=args.noise_apply_prob
-        #         if hasattr(args, "noise_apply_prob")
-        #         else 1.0,
-        #         noise_db_range=args.noise_db_range
-        #         if hasattr(args, "noise_db_range")
-        #         else "13_15",
-        #         speech_volume_normalize=args.speech_volume_normalize
-        #         if hasattr(args, "rir_scp")
-        #         else None,
-        #     )
-        # else:
-        #     retval = None
-        return None
-
-    @classmethod
-    def required_data_names(
-            cls, train: bool = True, inference: bool = False
-    ) -> Tuple[str, ...]:
-        if not inference:
-            retval = ("speech", )
-        else:
-            # Recognition mode
-            retval = ("speech", )
-        return retval
-
-    @classmethod
-    def optional_data_names(
-            cls, train: bool = True, inference: bool = False
-    ) -> Tuple[str, ...]:
-        retval = ()
-        return retval
-
-    @classmethod
-    def build_model(cls, args: argparse.Namespace):
-
-        # 1. frontend
-        if args.input_size is None or args.frontend == "wav_frontend_mel23":
-            # Extract features in the model
-            frontend_class = frontend_choices.get_class(args.frontend)
-            if args.frontend == 'wav_frontend':
-                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
-            else:
-                frontend = frontend_class(**args.frontend_conf)
-            input_size = frontend.output_size()
-        else:
-            # Give features from data-loader
-            args.frontend = None
-            args.frontend_conf = {}
-            frontend = None
-            input_size = args.input_size
-
-        # 2. Encoder
-        encoder_class = encoder_choices.get_class(args.encoder)
-        encoder = encoder_class(**args.encoder_conf)
-
-        # 3. EncoderDecoderAttractor
-        encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
-        encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
-
-        # 9. Build model
-        model_class = model_choices.get_class(args.model)
-        model = model_class(
-            frontend=frontend,
-            encoder=encoder,
-            encoder_decoder_attractor=encoder_decoder_attractor,
-            **args.model_conf,
-        )
-
-        # 10. Initialize
-        if args.init is not None:
-            initialize(model, args.init)
-
-        return model
-
-    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
-    @classmethod
-    def build_model_from_file(
-            cls,
-            config_file: Union[Path, str] = None,
-            model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
-    ):
-        """Build model from the files.
-
-        This method is used for inference or fine-tuning.
-
-        Args:
-            config_file: The yaml file saved when training.
-            model_file: The model file saved when training.
-            cmvn_file: The cmvn file for front-end
-            device: Device type, "cpu", "cuda", or "cuda:N".
-
-        """
-        if config_file is None:
-            assert model_file is not None, (
-                "The argument 'model_file' must be provided "
-                "if the argument 'config_file' is not specified."
-            )
-            config_file = Path(model_file).parent / "config.yaml"
-        else:
-            config_file = Path(config_file)
-
-        with config_file.open("r", encoding="utf-8") as f:
-            args = yaml.safe_load(f)
-        args = argparse.Namespace(**args)
-        model = cls.build_model(args)
-        if not isinstance(model, FunASRModel):
-            raise RuntimeError(
-                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
-            )
-        if model_file is not None:
-            if device == "cuda":
-                device = f"cuda:{torch.cuda.current_device()}"
-            checkpoint = torch.load(model_file, map_location=device)
-            if "state_dict" in checkpoint.keys():
-                model.load_state_dict(checkpoint["state_dict"])
-            else:
-                model.load_state_dict(checkpoint)
-        model.to(device)
-        return model, args

--
Gitblit v1.9.1