From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 13 Sep 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/tasks/diar.py | 84 +++++++++++++++++++++++++++++++++---------
 1 file changed, 66 insertions(+), 18 deletions(-)
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index bf3ae41..084b971 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -13,15 +13,13 @@
import numpy as np
import torch
import yaml
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.collate_fn import DiarCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.layers.label_aggregation import LabelAggregate
+from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
from funasr.models.ctc import CTC
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
@@ -52,9 +50,10 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.models.specaug.abs_profileaug import AbsProfileAug
+from funasr.models.specaug.profileaug import ProfileAug
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -84,6 +83,15 @@
default=None,
optional=True,
)
+profileaug_choices = ClassChoices(
+ name="profileaug",
+ classes=dict(
+ profileaug=ProfileAug,
+ ),
+ type_check=AbsProfileAug,
+ default=None,
+ optional=True,
+)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
@@ -97,7 +105,8 @@
label_aggregator_choices = ClassChoices(
"label_aggregator",
classes=dict(
- label_aggregator=LabelAggregate
+ label_aggregator=LabelAggregate,
+ label_aggregator_max_pool=LabelAggregateMaxPooling,
),
type_check=torch.nn.Module,
default=None,
@@ -108,7 +117,7 @@
classes=dict(
sond=DiarSondModel,
),
- type_check=AbsESPnetModel,
+ type_check=torch.nn.Module,
default="sond",
)
encoder_choices = ClassChoices(
@@ -189,6 +198,8 @@
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
+ # --profileaug and --profileaug_conf
+ profileaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --label_aggregator and --label_aggregator_conf
@@ -328,15 +339,13 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
- return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+ return DiarCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
@@ -366,7 +375,6 @@
)
else:
retval = None
- assert check_return_type(retval)
return retval
@classmethod
@@ -385,12 +393,45 @@
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
- assert check_return_type(retval)
return retval
@classmethod
+ def build_optimizers(
+ cls,
+ args: argparse.Namespace,
+ model: torch.nn.Module,
+ ) -> List[torch.optim.Optimizer]:
+ if cls.num_optimizers != 1:
+ raise RuntimeError(
+ "build_optimizers() must be overridden if num_optimizers != 1"
+ )
+ from funasr.tasks.abs_task import optim_classes
+ optim_class = optim_classes.get(args.optim)
+ if optim_class is None:
+ raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
+ else:
+ if (hasattr(model, "model_regularizer_weight") and
+ model.model_regularizer_weight > 0.0 and
+ hasattr(model, "get_regularize_parameters")
+ ):
+ to_regularize_parameters, normal_parameters = model.get_regularize_parameters()
+ logging.info(f"Set weight decay {model.model_regularizer_weight} for parameters: "
+ f"{[name for name, value in to_regularize_parameters]}")
+ module_optim_config = [
+ {"params": [value for name, value in to_regularize_parameters],
+ "weight_decay": model.model_regularizer_weight},
+ {"params": [value for name, value in normal_parameters],
+ "weight_decay": 0.0}
+ ]
+ optim = optim_class(module_optim_config, **args.optim_conf)
+ else:
+ optim = optim_class(model.parameters(), **args.optim_conf)
+
+ optimizers = [optim]
+ return optimizers
+
+ @classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -426,6 +467,13 @@
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
+
+ # 2b. Data augmentation for Profiles
+ if hasattr(args, "profileaug") and args.profileaug is not None:
+ profileaug_class = profileaug_choices.get_class(args.profileaug)
+ profileaug = profileaug_class(**args.profileaug_conf)
+ else:
+ profileaug = None
# 3. Normalization layer
if args.normalize is not None:
@@ -474,6 +522,7 @@
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
+ profileaug=profileaug,
normalize=normalize,
label_aggregator=label_aggregator,
encoder=encoder,
@@ -488,8 +537,8 @@
# 10. Initialize
if args.init is not None:
initialize(model, args.init)
+ logging.info(f"Init model parameters with {args.init}.")
- assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@@ -512,7 +561,6 @@
device: Device type, "cpu", "cuda", or "cuda:N".
"""
- assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
@@ -528,9 +576,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, torch.nn.Module):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {torch.nn.Module.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@@ -551,7 +599,7 @@
model_dict = torch.load(model_name_pth, map_location=device)
else:
model_dict = cls.convert_tf2torch(model, model_file)
- model.load_state_dict(model_dict)
+ # model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
--
Gitblit v1.9.1