From 38de2af5bf9976d2f14f087d9a0d31991daf6783 Mon Sep 17 00:00:00 2001
From: Zhihao Du <neo.dzh@alibaba-inc.com>
Date: 星期四, 16 三月 2023 19:41:34 +0800
Subject: [PATCH] Merge branch 'main' into dev_dzh
---
funasr/tasks/diar.py | 331 ++++++++++++++++++++++++++++++++++++++++++++++++++++--
1 files changed, 315 insertions(+), 16 deletions(-)
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index bf3ae41..096a5c8 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -20,19 +20,19 @@
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.layers.label_aggregation import LabelAggregate
-from funasr.models.ctc import CTC
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
-from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
-from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
-from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
-from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
-from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
+from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_sond import DiarSondModel
+from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
+from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
+from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
+from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
+from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
@@ -41,17 +41,13 @@
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
- HuggingFaceTransformersPostEncoder, # noqa: H301
-)
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.preencoder.linear import LinearProjection
-from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
+from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -70,6 +66,7 @@
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
+ wav_frontend_mel23=WavFrontendMel23,
),
type_check=AbsFrontend,
default="default",
@@ -107,6 +104,7 @@
"model",
classes=dict(
sond=DiarSondModel,
+ eend_ola=DiarEENDOLAModel,
),
type_check=AbsESPnetModel,
default="sond",
@@ -126,6 +124,7 @@
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
ecapa_tdnn=ECAPA_TDNN,
+ eend_ola_transformer=EENDOLATransformerEncoder,
),
type_check=torch.nn.Module,
default="resnet34",
@@ -176,6 +175,15 @@
),
type_check=torch.nn.Module,
default="fsmn",
+)
+# encoder_decoder_attractor is used for EEND-OLA
+encoder_decoder_attractor_choices = ClassChoices(
+ "encoder_decoder_attractor",
+ classes=dict(
+ eda=EncoderDecoderAttractor,
+ ),
+ type_check=torch.nn.Module,
+ default="eda",
)
@@ -545,7 +553,7 @@
if ".bin" in model_name:
model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
else:
- model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
+ model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
if os.path.exists(model_name_pth):
logging.info("model_file is load from pth: {}".format(model_name_pth))
model_dict = torch.load(model_name_pth, map_location=device)
@@ -609,3 +617,294 @@
var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
+
+
+class EENDOLADiarTask(AbsTask):
+ # If you need more than 1 optimizer, change this value
+ num_optimizers: int = 1
+
+ # Add variable objects configurations
+ class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ model_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --speaker_encoder and --speaker_encoder_conf
+ encoder_decoder_attractor_choices,
+ ]
+
+ # If you need to modify train() or eval() procedures, change Trainer class here
+ trainer = Trainer
+
+ @classmethod
+ def add_task_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(description="Task related")
+
+ # NOTE(kamo): add_arguments(..., required=True) can't be used
+ # to provide --print_config mode. Instead of it, do as
+ # required = parser.get_default("required")
+ # required += ["token_list"]
+
+ group.add_argument(
+ "--token_list",
+ type=str_or_none,
+ default=None,
+ help="A text mapping int-id to token",
+ )
+ group.add_argument(
+ "--split_with_space",
+ type=str2bool,
+ default=True,
+ help="whether to split text using <space>",
+ )
+ group.add_argument(
+ "--seg_dict_file",
+ type=str,
+ default=None,
+ help="seg_dict_file for text processing",
+ )
+ group.add_argument(
+ "--init",
+ type=lambda x: str_or_none(x.lower()),
+ default=None,
+ help="The initialization method",
+ choices=[
+ "chainer",
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ None,
+ ],
+ )
+
+ group.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
+ group = parser.add_argument_group(description="Preprocess related")
+ group.add_argument(
+ "--use_preprocessor",
+ type=str2bool,
+ default=True,
+ help="Apply preprocessing to data or not",
+ )
+ group.add_argument(
+ "--token_type",
+ type=str,
+ default="char",
+ choices=["char"],
+ help="The text will be tokenized in the specified level token",
+ )
+ parser.add_argument(
+ "--speech_volume_normalize",
+ type=float_or_none,
+ default=None,
+ help="Scale the maximum amplitude to the given value.",
+ )
+ parser.add_argument(
+ "--rir_scp",
+ type=str_or_none,
+ default=None,
+ help="The file path of rir scp file.",
+ )
+ parser.add_argument(
+ "--rir_apply_prob",
+ type=float,
+ default=1.0,
+ help="THe probability for applying RIR convolution.",
+ )
+ parser.add_argument(
+ "--cmvn_file",
+ type=str_or_none,
+ default=None,
+ help="The file path of noise scp file.",
+ )
+ parser.add_argument(
+ "--noise_scp",
+ type=str_or_none,
+ default=None,
+ help="The file path of noise scp file.",
+ )
+ parser.add_argument(
+ "--noise_apply_prob",
+ type=float,
+ default=1.0,
+ help="The probability applying Noise adding.",
+ )
+ parser.add_argument(
+ "--noise_db_range",
+ type=str,
+ default="13_15",
+ help="The range of noise decibel level.",
+ )
+
+ for class_choices in cls.class_choices_list:
+ # Append --<name> and --<name>_conf.
+ # e.g. --encoder and --encoder_conf
+ class_choices.add_arguments(group)
+
+ @classmethod
+ def build_collate_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Callable[
+ [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+ Tuple[List[str], Dict[str, torch.Tensor]],
+ ]:
+ assert check_argument_types()
+ # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
+ return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+ @classmethod
+ def build_preprocess_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+ assert check_argument_types()
+ # if args.use_preprocessor:
+ # retval = CommonPreprocessor(
+ # train=train,
+ # token_type=args.token_type,
+ # token_list=args.token_list,
+ # bpemodel=None,
+ # non_linguistic_symbols=None,
+ # text_cleaner=None,
+ # g2p_type=None,
+ # split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+ # seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+ # # NOTE(kamo): Check attribute existence for backward compatibility
+ # rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+ # rir_apply_prob=args.rir_apply_prob
+ # if hasattr(args, "rir_apply_prob")
+ # else 1.0,
+ # noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+ # noise_apply_prob=args.noise_apply_prob
+ # if hasattr(args, "noise_apply_prob")
+ # else 1.0,
+ # noise_db_range=args.noise_db_range
+ # if hasattr(args, "noise_db_range")
+ # else "13_15",
+ # speech_volume_normalize=args.speech_volume_normalize
+ # if hasattr(args, "rir_scp")
+ # else None,
+ # )
+ # else:
+ # retval = None
+ # assert check_return_type(retval)
+ return None
+
+ @classmethod
+ def required_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ if not inference:
+ retval = ("speech", )
+ else:
+ # Recognition mode
+ retval = ("speech", )
+ return retval
+
+ @classmethod
+ def optional_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ retval = ()
+ assert check_return_type(retval)
+ return retval
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace):
+ assert check_argument_types()
+
+ # 1. frontend
+ if args.input_size is None or args.frontend == "wav_frontend_mel23":
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(**args.encoder_conf)
+
+ # 3. EncoderDecoderAttractor
+ encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
+ encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
+
+ # 9. Build model
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ frontend=frontend,
+ encoder=encoder,
+ encoder_decoder_attractor=encoder_decoder_attractor,
+ **args.model_conf,
+ )
+
+ # 10. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ assert check_return_type(model)
+ return model
+
+ # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
+ @classmethod
+ def build_model_from_file(
+ cls,
+ config_file: Union[Path, str] = None,
+ model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ device: str = "cpu",
+ ):
+ """Build model from the files.
+
+ This method is used for inference or fine-tuning.
+
+ Args:
+ config_file: The yaml file saved when training.
+ model_file: The model file saved when training.
+ cmvn_file: The cmvn file for front-end
+ device: Device type, "cpu", "cuda", or "cuda:N".
+
+ """
+ assert check_argument_types()
+ if config_file is None:
+ assert model_file is not None, (
+ "The argument 'model_file' must be provided "
+ "if the argument 'config_file' is not specified."
+ )
+ config_file = Path(model_file).parent / "config.yaml"
+ else:
+ config_file = Path(config_file)
+
+ with config_file.open("r", encoding="utf-8") as f:
+ args = yaml.safe_load(f)
+ args = argparse.Namespace(**args)
+ model = cls.build_model(args)
+ if not isinstance(model, AbsESPnetModel):
+ raise RuntimeError(
+ f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ )
+ if model_file is not None:
+ if device == "cuda":
+ device = f"cuda:{torch.cuda.current_device()}"
+ checkpoint = torch.load(model_file, map_location=device)
+ if "state_dict" in checkpoint.keys():
+ model.load_state_dict(checkpoint["state_dict"])
+ else:
+ model.load_state_dict(checkpoint)
+ model.to(device)
+ return model, args
--
Gitblit v1.9.1