From 27fddb4982855d80b850d66f019d20ec19d8d196 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 15 六月 2023 16:41:37 +0800
Subject: [PATCH] update repo
---
funasr/build_utils/build_model.py | 5
funasr/build_utils/build_args.py | 11 ++
funasr/bin/sv_infer.py | 3
funasr/build_utils/build_sv_model.py | 258 +++++++++++++++++++++++++++++++++++++++++++++++++++
funasr/build_utils/build_model_from_file.py | 19 +++
5 files changed, 291 insertions(+), 5 deletions(-)
diff --git a/funasr/bin/sv_infer.py b/funasr/bin/sv_infer.py
index fd0d666..6e861da 100755
--- a/funasr/bin/sv_infer.py
+++ b/funasr/bin/sv_infer.py
@@ -50,7 +50,8 @@
model_file=sv_model_file,
cmvn_file=None,
device=device,
- task_name="sv"
+ task_name="sv",
+ mode="sv",
)
logging.info("sv_model: {}".format(sv_model))
logging.info("model parameter number: {}".format(statistic_model_parameters(sv_model)))
diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
index 517c85b..cc43064 100644
--- a/funasr/build_utils/build_args.py
+++ b/funasr/build_utils/build_args.py
@@ -81,6 +81,17 @@
for class_choices in class_choices_list:
class_choices.add_arguments(task_parser)
+ elif args.task_name == "sv":
+ from funasr.build_utils.build_sv_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+ task_parser.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
index 13a6faa..be8f910 100644
--- a/funasr/build_utils/build_model.py
+++ b/funasr/build_utils/build_model.py
@@ -1,9 +1,10 @@
from funasr.build_utils.build_asr_model import build_asr_model
+from funasr.build_utils.build_diar_model import build_diar_model
from funasr.build_utils.build_lm_model import build_lm_model
from funasr.build_utils.build_pretrain_model import build_pretrain_model
from funasr.build_utils.build_punc_model import build_punc_model
+from funasr.build_utils.build_sv_model import build_sv_model
from funasr.build_utils.build_vad_model import build_vad_model
-from funasr.build_utils.build_diar_model import build_diar_model
def build_model(args):
@@ -19,6 +20,8 @@
model = build_vad_model(args)
elif args.task_name == "diar":
model = build_diar_model(args)
+ elif args.task_name == "sv":
+ model = build_sv_model(args)
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))
diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
index 2eadae4..53eafc1 100644
--- a/funasr/build_utils/build_model_from_file.py
+++ b/funasr/build_utils/build_model_from_file.py
@@ -87,7 +87,7 @@
ckpt,
mode,
):
- assert mode == "paraformer" or mode == "uniasr" or mode == "sond"
+ assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv"
logging.info("start convert tf model to torch model")
from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
var_dict_tf = load_tf_dict(ckpt)
@@ -128,7 +128,7 @@
# bias_encoder
var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
- else:
+ elif mode == "sond":
if model.encoder is not None:
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
@@ -148,8 +148,21 @@
if model.decoder is not None:
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
var_dict_torch_update.update(var_dict_torch_update_local)
+ else:
+ # speech encoder
+ var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # pooling layer
+ var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+ # decoder
+ var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
+ var_dict_torch_update.update(var_dict_torch_update_local)
+
+ return var_dict_torch_update
return var_dict_torch_update
+
def fileter_model_dict(src_dict: dict, dest_dict: dict):
from collections import OrderedDict
@@ -162,4 +175,4 @@
for key, value in dest_dict.items():
if key not in new_dict:
logging.warning("{} is missed in checkpoint.".format(key))
- return new_dict
\ No newline at end of file
+ return new_dict
diff --git a/funasr/build_utils/build_sv_model.py b/funasr/build_utils/build_sv_model.py
new file mode 100644
index 0000000..c0f1ae8
--- /dev/null
+++ b/funasr/build_utils/build_sv_model.py
@@ -0,0 +1,258 @@
+import logging
+
+import torch
+from typeguard import check_return_type
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.base_model import FunASRModel
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.sv_decoder import DenseDecoder
+from funasr.models.e2e_sv import ESPnetSVModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.pooling.statistic_pooling import StatisticPooling
+from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.postencoder.hugging_face_transformers_postencoder import (
+ HuggingFaceTransformersPostEncoder, # noqa: H301
+)
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.linear import LinearProjection
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(  # --frontend / --frontend_conf; built by build_sv_model only when args.input_size is None
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ ),
+ type_check=AbsFrontend,
+ default="default",
+)
+specaug_choices = ClassChoices(  # --specaug / --specaug_conf; built only when args.specaug is not None
+ name="specaug",
+ classes=dict(
+ specaug=SpecAug,
+ ),
+ type_check=AbsSpecAug,
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(  # --normalize / --normalize_conf; built only when args.normalize is not None
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ type_check=AbsNormalize,
+ default=None,
+ optional=True,
+)
+model_choices = ClassChoices(  # --model / --model_conf; build_sv_model falls back to "espnet" on AttributeError
+ "model",
+ classes=dict(
+ espnet=ESPnetSVModel,
+ ),
+ type_check=FunASRModel,
+ default="espnet",
+)
+preencoder_choices = ClassChoices(  # --preencoder / --preencoder_conf; built only when args has a non-None preencoder
+ name="preencoder",
+ classes=dict(
+ sinc=LightweightSincConvs,
+ linear=LinearProjection,
+ ),
+ type_check=AbsPreEncoder,
+ default=None,
+ optional=True,
+)
+encoder_choices = ClassChoices(  # --encoder / --encoder_conf; constructed with input_size plus encoder_conf
+ "encoder",
+ classes=dict(
+ resnet34=ResNet34,
+ resnet34_sp_l2reg=ResNet34_SP_L2Reg,
+ rnn=RNNEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="resnet34",
+)
+postencoder_choices = ClassChoices(  # --postencoder / --postencoder_conf; built only when args has a non-None postencoder
+ name="postencoder",
+ classes=dict(
+ hugging_face_transformers=HuggingFaceTransformersPostEncoder,
+ ),
+ type_check=AbsPostEncoder,
+ default=None,
+ optional=True,
+)
+pooling_choices = ClassChoices(  # registered as "pooling_type": build_sv_model reads args.pooling_type / args.pooling_type_conf
+ name="pooling_type",
+ classes=dict(
+ statistic=StatisticPooling,
+ ),
+ type_check=torch.nn.Module,
+ default="statistic",
+)
+decoder_choices = ClassChoices(  # --decoder / --decoder_conf; constructed with vocab_size and encoder_output_size
+ "decoder",
+ classes=dict(
+ dense=DenseDecoder,
+ ),
+ type_check=AbsDecoder,
+ default="dense",
+)
+
+class_choices_list = [  # build_args.py iterates this list and calls add_arguments(task_parser) on each entry
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --model and --model_conf
+ model_choices,
+ # --preencoder and --preencoder_conf
+ preencoder_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --postencoder and --postencoder_conf
+ postencoder_choices,
+ # --pooling_type and --pooling_type_conf
+ pooling_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+]
+
+
+def build_sv_model(args):  # assemble the SV model (frontend -> ... -> pooling -> decoder) from parsed args and return it
+ # 0. Resolve token_list: either a path to a UTF-8 text file (one entry per line) or an in-memory list/tuple
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+ vocab_size = len(token_list)  # number of entries; logged below as the speaker count
+ logging.info(f"Speaker number: {vocab_size}")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 4. Pre-encoder input block
+ # NOTE(kan-bayashi): Use getattr to keep the compatibility
+ if getattr(args, "preencoder", None) is not None:
+ preencoder_class = preencoder_choices.get_class(args.preencoder)
+ preencoder = preencoder_class(**args.preencoder_conf)
+ input_size = preencoder.output_size()  # the encoder consumes the pre-encoder output
+ else:
+ preencoder = None
+
+ # 5. Encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+ # 6. Post-encoder block
+ # NOTE(kan-bayashi): Use getattr to keep the compatibility
+ encoder_output_size = encoder.output_size()
+ if getattr(args, "postencoder", None) is not None:
+ postencoder_class = postencoder_choices.get_class(args.postencoder)
+ postencoder = postencoder_class(
+ input_size=encoder_output_size, **args.postencoder_conf
+ )
+ encoder_output_size = postencoder.output_size()
+ else:
+ postencoder = None
+
+ # 7. Pooling layer
+ pooling_class = pooling_choices.get_class(args.pooling_type)
+ pooling_dim = (2, 3)  # default; overridden by pooling_type_conf["pooling_dim"] when present
+ eps = 1e-12  # default; overridden by pooling_type_conf["eps"] when present
+ if hasattr(args, "pooling_type_conf"):
+ if "pooling_dim" in args.pooling_type_conf:
+ pooling_dim = args.pooling_type_conf["pooling_dim"]
+ if "eps" in args.pooling_type_conf:
+ eps = args.pooling_type_conf["eps"]
+ pooling_layer = pooling_class(
+ pooling_dim=pooling_dim,
+ eps=eps,
+ )
+ if args.pooling_type == "statistic":
+ encoder_output_size *= 2  # decoder input is doubled; presumably mean and std are concatenated - TODO confirm
+
+ # 8. Decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **args.decoder_conf,
+ )
+
+ # 9. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:  # e.g. args carries no "model" attribute; use the default choice
+ model_class = model_choices.get_class("espnet")
+ model = model_class(
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ pooling_layer=pooling_layer,
+ decoder=decoder,
+ **args.model_conf,
+ )
+
+ # FIXME(kamo): Should be done in model?
+ # 10. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ assert check_return_type(model)
+ return model
--
Gitblit v1.9.1