From 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 10 十月 2023 17:17:29 +0800
Subject: [PATCH] v0.8.0
---
funasr/tasks/diar.py | 130 +++++++++++++++++++++++++++++++++----------
1 files changed, 100 insertions(+), 30 deletions(-)
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 73c51e3..084b971 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -13,17 +13,15 @@
import numpy as np
import torch
import yaml
-from typeguard import check_argument_types
-from typeguard import check_return_type
-from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.collate_fn import DiarCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.layers.label_aggregation import LabelAggregate
+from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
from funasr.models.ctc import CTC
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
@@ -52,9 +50,10 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.models.specaug.abs_profileaug import AbsProfileAug
+from funasr.models.specaug.profileaug import ProfileAug
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -84,6 +83,15 @@
default=None,
optional=True,
)
+profileaug_choices = ClassChoices(
+ name="profileaug",
+ classes=dict(
+ profileaug=ProfileAug,
+ ),
+ type_check=AbsProfileAug,
+ default=None,
+ optional=True,
+)
normalize_choices = ClassChoices(
"normalize",
classes=dict(
@@ -97,7 +105,8 @@
label_aggregator_choices = ClassChoices(
"label_aggregator",
classes=dict(
- label_aggregator=LabelAggregate
+ label_aggregator=LabelAggregate,
+ label_aggregator_max_pool=LabelAggregateMaxPooling,
),
type_check=torch.nn.Module,
default=None,
@@ -108,7 +117,7 @@
classes=dict(
sond=DiarSondModel,
),
- type_check=AbsESPnetModel,
+ type_check=torch.nn.Module,
default="sond",
)
encoder_choices = ClassChoices(
@@ -122,6 +131,7 @@
fsmn=FsmnEncoder,
conv=ConvEncoder,
resnet34=ResNet34Diar,
+ resnet34_sp_l2reg=ResNet34SpL2RegDiar,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
ecapa_tdnn=ECAPA_TDNN,
@@ -160,6 +170,7 @@
classes=dict(
dot=DotScorer,
cosine=CosScorer,
+ conv=ConvEncoder,
),
type_check=torch.nn.Module,
default=None,
@@ -187,6 +198,8 @@
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
+ # --profileaug and --profileaug_conf
+ profileaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --label_aggregator and --label_aggregator_conf
@@ -326,15 +339,13 @@
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
- assert check_argument_types()
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
- return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+ return DiarCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
@@ -364,7 +375,6 @@
)
else:
retval = None
- assert check_return_type(retval)
return retval
@classmethod
@@ -383,12 +393,45 @@
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ()
- assert check_return_type(retval)
return retval
@classmethod
+ def build_optimizers(
+ cls,
+ args: argparse.Namespace,
+ model: torch.nn.Module,
+ ) -> List[torch.optim.Optimizer]:
+ if cls.num_optimizers != 1:
+ raise RuntimeError(
+ "build_optimizers() must be overridden if num_optimizers != 1"
+ )
+ from funasr.tasks.abs_task import optim_classes
+ optim_class = optim_classes.get(args.optim)
+ if optim_class is None:
+ raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
+ else:
+ if (hasattr(model, "model_regularizer_weight") and
+ model.model_regularizer_weight > 0.0 and
+ hasattr(model, "get_regularize_parameters")
+ ):
+ to_regularize_parameters, normal_parameters = model.get_regularize_parameters()
+ logging.info(f"Set weight decay {model.model_regularizer_weight} for parameters: "
+ f"{[name for name, value in to_regularize_parameters]}")
+ module_optim_config = [
+ {"params": [value for name, value in to_regularize_parameters],
+ "weight_decay": model.model_regularizer_weight},
+ {"params": [value for name, value in normal_parameters],
+ "weight_decay": 0.0}
+ ]
+ optim = optim_class(module_optim_config, **args.optim_conf)
+ else:
+ optim = optim_class(model.parameters(), **args.optim_conf)
+
+ optimizers = [optim]
+ return optimizers
+
+ @classmethod
def build_model(cls, args: argparse.Namespace):
- assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
@@ -424,6 +467,13 @@
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
+
+ # 2b. Data augmentation for Profiles
+ if hasattr(args, "profileaug") and args.profileaug is not None:
+ profileaug_class = profileaug_choices.get_class(args.profileaug)
+ profileaug = profileaug_class(**args.profileaug_conf)
+ else:
+ profileaug = None
# 3. Normalization layer
if args.normalize is not None:
@@ -472,6 +522,7 @@
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
+ profileaug=profileaug,
normalize=normalize,
label_aggregator=label_aggregator,
encoder=encoder,
@@ -486,8 +537,8 @@
# 10. Initialize
if args.init is not None:
initialize(model, args.init)
+ logging.info(f"Init model parameters with {args.init}.")
- assert check_return_type(model)
return model
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@@ -497,7 +548,7 @@
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
- device: str = "cpu",
+ device: Union[str, torch.device] = "cpu",
):
"""Build model from the files.
@@ -510,7 +561,6 @@
device: Device type, "cpu", "cuda", or "cuda:N".
"""
- assert check_argument_types()
if config_file is None:
assert model_file is not None, (
"The argument 'model_file' must be provided "
@@ -526,9 +576,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, torch.nn.Module):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {torch.nn.Module.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@@ -549,15 +599,30 @@
model_dict = torch.load(model_name_pth, map_location=device)
else:
model_dict = cls.convert_tf2torch(model, model_file)
- model.load_state_dict(model_dict)
+ # model.load_state_dict(model_dict)
else:
model_dict = torch.load(model_file, map_location=device)
+ model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
model.load_state_dict(model_dict)
if model_name_pth is not None and not os.path.exists(model_name_pth):
torch.save(model_dict, model_name_pth)
logging.info("model_file is saved to pth: {}".format(model_name_pth))
return model, args
+
+ @classmethod
+ def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):
+ from collections import OrderedDict
+ new_dict = OrderedDict()
+ for key, value in src_dict.items():
+ if key in dest_dict:
+ new_dict[key] = value
+ else:
+ logging.info("{} is no longer needed in this model.".format(key))
+ for key, value in dest_dict.items():
+ if key not in new_dict:
+ logging.warning("{} is missed in checkpoint.".format(key))
+ return new_dict
@classmethod
def convert_tf2torch(
@@ -571,19 +636,24 @@
var_dict_torch = model.state_dict()
var_dict_torch_update = dict()
# speech encoder
- var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.encoder is not None:
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# speaker encoder
- var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.speaker_encoder is not None:
+ var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# cd scorer
- var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.cd_scorer is not None:
+ var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# ci scorer
- var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.ci_scorer is not None:
+ var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
# decoder
- var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
- var_dict_torch_update.update(var_dict_torch_update_local)
+ if model.decoder is not None:
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
--
Gitblit v1.9.1