From 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 10 Oct 2023 17:17:29 +0800
Subject: [PATCH] v0.8.0

---
 funasr/tasks/diar.py |  130 +++++++++++++++++++++++++++++++++----------
 1 file changed, 100 insertions(+), 30 deletions(-)

diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 73c51e3..084b971 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -13,17 +13,15 @@
 import numpy as np
 import torch
 import yaml
-from typeguard import check_argument_types
-from typeguard import check_return_type
 
-from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.collate_fn import DiarCollateFn
 from funasr.datasets.preprocessor import CommonPreprocessor
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.layers.global_mvn import GlobalMVN
 from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.layers.label_aggregation import LabelAggregate
+from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
 from funasr.models.ctc import CTC
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
 from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
 from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
 from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
@@ -52,9 +50,10 @@
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.specaug.specaug import SpecAug
 from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.models.specaug.abs_profileaug import AbsProfileAug
+from funasr.models.specaug.profileaug import ProfileAug
 from funasr.tasks.abs_task import AbsTask
 from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
 from funasr.utils.types import float_or_none
@@ -84,6 +83,15 @@
     default=None,
     optional=True,
 )
+profileaug_choices = ClassChoices(
+    name="profileaug",
+    classes=dict(
+        profileaug=ProfileAug,
+    ),
+    type_check=AbsProfileAug,
+    default=None,
+    optional=True,
+)
 normalize_choices = ClassChoices(
     "normalize",
     classes=dict(
@@ -97,7 +105,8 @@
 label_aggregator_choices = ClassChoices(
     "label_aggregator",
     classes=dict(
-        label_aggregator=LabelAggregate
+        label_aggregator=LabelAggregate,
+        label_aggregator_max_pool=LabelAggregateMaxPooling,
     ),
     type_check=torch.nn.Module,
     default=None,
@@ -108,7 +117,7 @@
     classes=dict(
         sond=DiarSondModel,
     ),
-    type_check=AbsESPnetModel,
+    type_check=torch.nn.Module,
     default="sond",
 )
 encoder_choices = ClassChoices(
@@ -122,6 +131,7 @@
         fsmn=FsmnEncoder,
         conv=ConvEncoder,
         resnet34=ResNet34Diar,
+        resnet34_sp_l2reg=ResNet34SpL2RegDiar,
         sanm_chunk_opt=SANMEncoderChunkOpt,
         data2vec_encoder=Data2VecEncoder,
         ecapa_tdnn=ECAPA_TDNN,
@@ -160,6 +170,7 @@
     classes=dict(
         dot=DotScorer,
         cosine=CosScorer,
+        conv=ConvEncoder,
     ),
     type_check=torch.nn.Module,
     default=None,
@@ -187,6 +198,8 @@
         frontend_choices,
         # --specaug and --specaug_conf
         specaug_choices,
+        # --profileaug and --profileaug_conf
+        profileaug_choices,
         # --normalize and --normalize_conf
         normalize_choices,
         # --label_aggregator and --label_aggregator_conf
@@ -326,15 +339,13 @@
         [Collection[Tuple[str, Dict[str, np.ndarray]]]],
         Tuple[List[str], Dict[str, torch.Tensor]],
     ]:
-        assert check_argument_types()
         # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
-        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+        return DiarCollateFn(float_pad_value=0.0, int_pad_value=-1)
 
     @classmethod
     def build_preprocess_fn(
             cls, args: argparse.Namespace, train: bool
     ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
-        assert check_argument_types()
         if args.use_preprocessor:
             retval = CommonPreprocessor(
                 train=train,
@@ -364,7 +375,6 @@
             )
         else:
             retval = None
-        assert check_return_type(retval)
         return retval
 
     @classmethod
@@ -383,12 +393,45 @@
             cls, train: bool = True, inference: bool = False
     ) -> Tuple[str, ...]:
         retval = ()
-        assert check_return_type(retval)
         return retval
 
     @classmethod
+    def build_optimizers(
+            cls,
+            args: argparse.Namespace,
+            model: torch.nn.Module,
+    ) -> List[torch.optim.Optimizer]:
+        if cls.num_optimizers != 1:
+            raise RuntimeError(
+                "build_optimizers() must be overridden if num_optimizers != 1"
+            )
+        from funasr.tasks.abs_task import optim_classes
+        optim_class = optim_classes.get(args.optim)
+        if optim_class is None:
+            raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
+        else:
+            if (hasattr(model, "model_regularizer_weight") and
+                model.model_regularizer_weight > 0.0 and
+                hasattr(model, "get_regularize_parameters")
+            ):
+                to_regularize_parameters, normal_parameters = model.get_regularize_parameters()
+                logging.info(f"Set weight decay {model.model_regularizer_weight} for parameters: "
+                             f"{[name for name, value in to_regularize_parameters]}")
+                module_optim_config = [
+                    {"params": [value for name, value in to_regularize_parameters],
+                     "weight_decay": model.model_regularizer_weight},
+                    {"params": [value for name, value in normal_parameters],
+                     "weight_decay": 0.0}
+                ]
+                optim = optim_class(module_optim_config, **args.optim_conf)
+            else:
+                optim = optim_class(model.parameters(), **args.optim_conf)
+
+        optimizers = [optim]
+        return optimizers
+
+    @classmethod
     def build_model(cls, args: argparse.Namespace):
-        assert check_argument_types()
         if isinstance(args.token_list, str):
             with open(args.token_list, encoding="utf-8") as f:
                 token_list = [line.rstrip() for line in f]
@@ -424,6 +467,13 @@
             specaug = specaug_class(**args.specaug_conf)
         else:
             specaug = None
+
+        # 2b. Data augmentation for speaker profiles
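+        # hasattr() keeps older configs without a profileaug entry working.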
+        if hasattr(args, "profileaug") and args.profileaug is not None:
+            profileaug_class = profileaug_choices.get_class(args.profileaug)
+            profileaug = profileaug_class(**args.profileaug_conf)
+        else:
+            profileaug = None
 
         # 3. Normalization layer
         if args.normalize is not None:
@@ -472,6 +522,7 @@
             vocab_size=vocab_size,
             frontend=frontend,
             specaug=specaug,
+            profileaug=profileaug,
             normalize=normalize,
             label_aggregator=label_aggregator,
             encoder=encoder,
@@ -486,8 +537,8 @@
         # 10. Initialize
         if args.init is not None:
             initialize(model, args.init)
+            logging.info(f"Init model parameters with {args.init}.")
 
-        assert check_return_type(model)
         return model
 
     # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@@ -497,7 +548,7 @@
             config_file: Union[Path, str] = None,
             model_file: Union[Path, str] = None,
             cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
+            device: Union[str, torch.device] = "cpu",
     ):
         """Build model from the files.
 
@@ -510,7 +561,6 @@
             device: Device type, "cpu", "cuda", or "cuda:N".
 
         """
-        assert check_argument_types()
         if config_file is None:
             assert model_file is not None, (
                 "The argument 'model_file' must be provided "
@@ -526,9 +576,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, torch.nn.Module):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {torch.nn.Module.__name__}, but got {type(model)}"
             )
         model.to(device)
         model_dict = dict()
@@ -549,15 +599,30 @@
                     model_dict = torch.load(model_name_pth, map_location=device)
                 else:
                     model_dict = cls.convert_tf2torch(model, model_file)
-                model.load_state_dict(model_dict)
+                # model.load_state_dict(model_dict)
             else:
                 model_dict = torch.load(model_file, map_location=device)
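+        # Align the checkpoint with the current model before loading, so checkpoints
+        # from older model versions remain usable.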
+        model_dict = cls.filter_model_dict(model_dict, model.state_dict())
         model.load_state_dict(model_dict)
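+        # If no .pth cache exists yet, save the state dict so later loads can skip
+        # the TF-to-torch conversion.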
         if model_name_pth is not None and not os.path.exists(model_name_pth):
             torch.save(model_dict, model_name_pth)
             logging.info("model_file is saved to pth: {}".format(model_name_pth))
 
         return model, args
+
+    @classmethod
+    def filter_model_dict(cls, src_dict: dict, dest_dict: dict):
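+        # Keep only checkpoint entries whose keys exist in the model's state dict;
+        # log the dropped keys and warn about parameters missing from the checkpoint.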
+        from collections import OrderedDict
+        new_dict = OrderedDict()
+        for key, value in src_dict.items():
+            if key in dest_dict:
+                new_dict[key] = value
+            else:
+                logging.info("{} is no longer needed in this model.".format(key))
+        for key, value in dest_dict.items():
+            if key not in new_dict:
+                logging.warning("{} is missing from the checkpoint.".format(key))
+        return new_dict
 
     @classmethod
     def convert_tf2torch(
@@ -571,19 +636,24 @@
         var_dict_torch = model.state_dict()
         var_dict_torch_update = dict()
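+        # Every sub-module below is optional; convert TF variables only for the
+        # modules that are actually instantiated.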
         # speech encoder
-        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.encoder is not None:
+            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # speaker encoder
-        var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.speaker_encoder is not None:
+            var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # cd scorer
-        var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.cd_scorer is not None:
+            var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # ci scorer
-        var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.ci_scorer is not None:
+            var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
         # decoder
-        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
+        if model.decoder is not None:
+            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+            var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update

--
Gitblit v1.9.1