| New file |
| | |
| | | import logging |
| | | |
| | | import torch |
| | | |
| | | from funasr.layers.global_mvn import GlobalMVN |
| | | from funasr.layers.label_aggregation import LabelAggregate |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel |
| | | from funasr.models.e2e_diar_sond import DiarSondModel |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | | from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN |
| | | from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer |
| | | from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder |
| | | from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder |
| | | from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar |
| | | from funasr.models.encoder.rnn_encoder import RNNEncoder |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt |
| | | from funasr.models.encoder.transformer_encoder import TransformerEncoder |
| | | from funasr.models.frontend.default import DefaultFrontend |
| | | from funasr.models.frontend.fused import FusedFrontends |
| | | from funasr.models.frontend.s3prl import S3prlFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.frontend.wav_frontend import WavFrontendMel23 |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder |
| | | from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor |
| | | from funasr.torch_utils.initialize import initialize |
| | | from funasr.train.class_choices import ClassChoices |
| | | |
# Registries of the classes selectable for each configurable component of a
# diarization model. Every registry is collected in ``class_choices_list``
# at the bottom, where each one backs a ``--<name>`` / ``--<name>_conf``
# option pair.

# Feature-extraction frontends.
frontend_choices = ClassChoices(
    name="frontend",
    classes={
        "default": DefaultFrontend,
        "sliding_window": SlidingWindow,
        "s3prl": S3prlFrontend,
        "fused": FusedFrontends,
        "wav_frontend": WavFrontend,
        "wav_frontend_mel23": WavFrontendMel23,
    },
    default="default",
)

# Spectrogram augmentation (optional).
specaug_choices = ClassChoices(
    name="specaug",
    classes={
        "specaug": SpecAug,
        "specaug_lfr": SpecAugLFR,
    },
    default=None,
    optional=True,
)

# Feature normalization (optional).
normalize_choices = ClassChoices(
    name="normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    default=None,
    optional=True,
)

# Label aggregation (optional).
label_aggregator_choices = ClassChoices(
    name="label_aggregator",
    classes={
        "label_aggregator": LabelAggregate,
    },
    default=None,
    optional=True,
)

# Top-level diarization models.
model_choices = ClassChoices(
    name="model",
    classes={
        "sond": DiarSondModel,
        "eend_ola": DiarEENDOLAModel,
    },
    default="sond",
)

# Main (speech) encoders.
encoder_choices = ClassChoices(
    name="encoder",
    classes={
        "conformer": ConformerEncoder,
        "transformer": TransformerEncoder,
        "rnn": RNNEncoder,
        "sanm": SANMEncoder,
        "san": SelfAttentionEncoder,
        "fsmn": FsmnEncoder,
        "conv": ConvEncoder,
        "resnet34": ResNet34Diar,
        "resnet34_sp_l2reg": ResNet34SpL2RegDiar,
        "sanm_chunk_opt": SANMEncoderChunkOpt,
        "data2vec_encoder": Data2VecEncoder,
        "ecapa_tdnn": ECAPA_TDNN,
        "eend_ola_transformer": EENDOLATransformerEncoder,
    },
    default="resnet34",
)

# Speaker encoders (optional).
speaker_encoder_choices = ClassChoices(
    name="speaker_encoder",
    classes={
        "conformer": ConformerEncoder,
        "transformer": TransformerEncoder,
        "rnn": RNNEncoder,
        "sanm": SANMEncoder,
        "san": SelfAttentionEncoder,
        "fsmn": FsmnEncoder,
        "conv": ConvEncoder,
        "sanm_chunk_opt": SANMEncoderChunkOpt,
        "data2vec_encoder": Data2VecEncoder,
    },
    default=None,
    optional=True,
)

# "cd" scorer (optional).
cd_scorer_choices = ClassChoices(
    name="cd_scorer",
    classes={
        "san": SelfAttentionEncoder,
    },
    default=None,
    optional=True,
)

# "ci" scorer (optional).
ci_scorer_choices = ClassChoices(
    name="ci_scorer",
    classes={
        "dot": DotScorer,
        "cosine": CosScorer,
        "conv": ConvEncoder,
    },
    type_check=torch.nn.Module,
    default=None,
    optional=True,
)

# decoder is used for output (e.g. post_net in SOND)
decoder_choices = ClassChoices(
    name="decoder",
    classes={
        "rnn": RNNEncoder,
        "fsmn": FsmnEncoder,
    },
    type_check=torch.nn.Module,
    default="fsmn",
)

# encoder_decoder_attractor is used for EEND-OLA
encoder_decoder_attractor_choices = ClassChoices(
    name="encoder_decoder_attractor",
    classes={
        "eda": EncoderDecoderAttractor,
    },
    type_check=torch.nn.Module,
    default="eda",
)

# All registries, in the order they are exposed as options.
class_choices_list = [
    frontend_choices,  # --frontend and --frontend_conf
    specaug_choices,  # --specaug and --specaug_conf
    normalize_choices,  # --normalize and --normalize_conf
    label_aggregator_choices,  # --label_aggregator and --label_aggregator_conf
    model_choices,  # --model and --model_conf
    encoder_choices,  # --encoder and --encoder_conf
    speaker_encoder_choices,  # --speaker_encoder and --speaker_encoder_conf
    cd_scorer_choices,  # --cd_scorer and --cd_scorer_conf
    ci_scorer_choices,  # --ci_scorer and --ci_scorer_conf
    decoder_choices,  # --decoder and --decoder_conf
    # --encoder_decoder_attractor and --encoder_decoder_attractor_conf
    encoder_decoder_attractor_choices,
]
| | | |
| | | |
def _load_token_list(token_list):
    """Return ``token_list`` as a list of tokens, or ``None`` if it is unset.

    A list/tuple is copied as-is; anything else is treated as a path to a
    text file holding one token per line.
    """
    if token_list is None:
        return None
    if isinstance(token_list, (list, tuple)):
        return list(token_list)
    # Token files may contain non-ASCII symbols, so do not depend on the
    # platform's default encoding.
    with open(token_list, encoding="utf-8") as f:
        return [line.rstrip() for line in f]


def _build_optional(choices, args, name):
    """Instantiate the optional component ``name`` chosen in ``args``.

    Returns ``None`` when ``args.<name>`` is missing or ``None``; otherwise
    builds the selected class with ``args.<name>_conf`` as keyword arguments.
    """
    selected = getattr(args, name, None)
    if selected is None:
        return None
    component_class = choices.get_class(selected)
    return component_class(**getattr(args, f"{name}_conf"))


def build_diar_model(args):
    """Build a speaker-diarization model from parsed arguments.

    Args:
        args: Namespace-like object. Attributes read here:
            ``token_list`` (path to a token file, a list/tuple of tokens, or
            ``None``), ``input_size``, ``frontend``/``frontend_conf``,
            ``cmvn_file`` (only for the ``wav_frontend`` frontend),
            ``encoder``/``encoder_conf``, ``model_name``, ``model``/
            ``model_conf`` and ``init``.
            For ``model_name == "sond"``: ``specaug``, ``normalize``,
            ``speaker_encoder``, ``ci_scorer``, ``cd_scorer``,
            ``label_aggregator`` (all optional), ``decoder`` and the matching
            ``*_conf`` dictionaries.
            For ``model_name == "eend_ola"``: ``encoder_decoder_attractor``
            and ``encoder_decoder_attractor_conf``.

    Returns:
        The constructed model, initialized with ``args.init`` when that is
        not ``None``.

    Raises:
        NotImplementedError: If ``args.model_name`` is neither ``"sond"`` nor
            ``"eend_ola"``.

    Side effects:
        ``args.token_list`` is replaced by the loaded list of tokens. When
        ``args.input_size`` is given, ``args.frontend`` is set to ``None`` and
        ``args.frontend_conf`` to ``{}``.
    """
    # token_list
    token_list = _load_token_list(args.token_list)
    if token_list is not None:
        # Store the tokens themselves (not the path) back on args.
        args.token_list = list(token_list)
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")
    else:
        vocab_size = None

    # frontend
    if args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
        if args.frontend == 'wav_frontend':
            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
        else:
            frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # Features are supplied directly, so no frontend is built.
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size

    # encoder
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(input_size=input_size, **args.encoder_conf)

    # NOTE(review): the branch is chosen by ``args.model_name`` while the
    # class is looked up with ``args.model`` -- confirm callers always keep
    # the two consistent.
    if args.model_name == "sond":
        # Components are created in a fixed order (specaug, normalize,
        # speaker encoder, ci scorer, cd scorer, decoder, label aggregator).
        # data augmentation for spectrogram
        specaug = _build_optional(specaug_choices, args, "specaug")

        # normalization layer
        normalize = _build_optional(normalize_choices, args, "normalize")

        # speaker encoder
        speaker_encoder = _build_optional(speaker_encoder_choices, args, "speaker_encoder")

        # ci scorer
        ci_scorer = _build_optional(ci_scorer_choices, args, "ci_scorer")

        # cd scorer
        cd_scorer = _build_optional(cd_scorer_choices, args, "cd_scorer")

        # decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **args.decoder_conf,
        )

        # label aggregator
        label_aggregator = _build_optional(label_aggregator_choices, args, "label_aggregator")

        model_class = model_choices.get_class(args.model)
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            label_aggregator=label_aggregator,
            encoder=encoder,
            speaker_encoder=speaker_encoder,
            ci_scorer=ci_scorer,
            cd_scorer=cd_scorer,
            decoder=decoder,
            token_list=token_list,
            **args.model_conf,
        )

    elif args.model_name == "eend_ola":
        # encoder-decoder attractor
        encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(
            args.encoder_decoder_attractor
        )
        encoder_decoder_attractor = encoder_decoder_attractor_class(
            **args.encoder_decoder_attractor_conf
        )

        # build model
        model_class = model_choices.get_class(args.model)
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            encoder_decoder_attractor=encoder_decoder_attractor,
            **args.model_conf,
        )

    else:
        raise NotImplementedError("Not supported model: {}".format(args.model))

    # initialize parameters
    if args.init is not None:
        initialize(model, args.init)

    return model