| | |
| | | import torch |
| | | |
| | | from funasr.layers.global_mvn import GlobalMVN |
| | | from funasr.layers.label_aggregation import LabelAggregate |
| | | from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel |
| | | from funasr.models.e2e_diar_sond import DiarSondModel |
| | |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.models.specaug.abs_profileaug import AbsProfileAug |
| | | from funasr.models.specaug.profileaug import ProfileAug |
| | | from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder |
| | | from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor |
| | | from funasr.torch_utils.initialize import initialize |
| | |
| | | default=None, |
| | | optional=True, |
| | | ) |
| | | profileaug_choices = ClassChoices( |
| | | name="profileaug", |
| | | classes=dict( |
| | | profileaug=ProfileAug, |
| | | ), |
| | | type_check=AbsProfileAug, |
| | | default=None, |
| | | optional=True, |
| | | ) |
| | | normalize_choices = ClassChoices( |
| | | "normalize", |
| | | classes=dict( |
| | |
| | | label_aggregator_choices = ClassChoices( |
| | | "label_aggregator", |
| | | classes=dict( |
| | | label_aggregator=LabelAggregate |
| | | label_aggregator=LabelAggregate, |
| | | label_aggregator_max_pool=LabelAggregateMaxPooling, |
| | | ), |
| | | default=None, |
| | | optional=True, |
| | |
| | | frontend_choices, |
| | | # --specaug and --specaug_conf |
| | | specaug_choices, |
| | | # --profileaug and --profileaug_conf |
| | | profileaug_choices, |
| | | # --normalize and --normalize_conf |
| | | normalize_choices, |
| | | # --label_aggregator and --label_aggregator_conf |
| | |
| | | |
| | | def build_diar_model(args): |
| | | # token_list |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | | token_list = [line.rstrip() for line in f] |
| | | if args.token_list is not None: |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | | token_list = [line.rstrip() for line in f] |
| | | |
| | | # Overwriting token_list to keep it as "portable". |
| | | args.token_list = list(token_list) |
| | | elif isinstance(args.token_list, (tuple, list)): |
| | | token_list = list(args.token_list) |
| | | # Overwriting token_list to keep it as "portable". |
| | | args.token_list = list(token_list) |
| | | elif isinstance(args.token_list, (tuple, list)): |
| | | token_list = list(args.token_list) |
| | | else: |
| | | raise RuntimeError("token_list must be str or list") |
| | | vocab_size = len(token_list) |
| | | logging.info(f"Vocabulary size: {vocab_size}") |
| | | else: |
| | | raise RuntimeError("token_list must be str or list") |
| | | vocab_size = len(token_list) |
| | | logging.info(f"Vocabulary size: {vocab_size}") |
| | | token_list = None |
| | | vocab_size = None |
| | | |
| | | # frontend |
| | | if args.input_size is None: |
| | |
| | | frontend = None |
| | | input_size = args.input_size |
| | | |
| | | # encoder |
| | | encoder_class = encoder_choices.get_class(args.encoder) |
| | | encoder = encoder_class(input_size=input_size, **args.encoder_conf) |
| | | if args.model == "sond": |
| | | # encoder |
| | | encoder_class = encoder_choices.get_class(args.encoder) |
| | | encoder = encoder_class(input_size=input_size ,**args.encoder_conf) |
| | | |
| | | if args.model_name == "sond": |
| | | # data augmentation for spectrogram |
| | | if args.specaug is not None: |
| | | specaug_class = specaug_choices.get_class(args.specaug) |
| | | specaug = specaug_class(**args.specaug_conf) |
| | | else: |
| | | specaug = None |
| | | |
| | | # Data augmentation for Profiles |
| | | if hasattr(args, "profileaug") and args.profileaug is not None: |
| | | profileaug_class = profileaug_choices.get_class(args.profileaug) |
| | | profileaug = profileaug_class(**args.profileaug_conf) |
| | | else: |
| | | profileaug = None |
| | | |
| | | # normalization layer |
| | | if args.normalize is not None: |
| | |
| | | |
| | | # decoder |
| | | decoder_class = decoder_choices.get_class(args.decoder) |
| | | decoder = decoder_class( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder.output_size(), |
| | | **args.decoder_conf, |
| | | ) |
| | | decoder = decoder_class(**args.decoder_conf) |
| | | |
| | | # logger aggregator |
| | | if getattr(args, "label_aggregator", None) is not None: |
| | |
| | | vocab_size=vocab_size, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | profileaug=profileaug, |
| | | normalize=normalize, |
| | | label_aggregator=label_aggregator, |
| | | encoder=encoder, |
| | |
| | | **args.model_conf, |
| | | ) |
| | | |
| | | elif args.model_name == "eend_ola": |
| | | elif args.model == "eend_ola": |
| | | # encoder |
| | | encoder_class = encoder_choices.get_class(args.encoder) |
| | | encoder = encoder_class(**args.encoder_conf) |
| | | |
| | | # encoder-decoder attractor |
| | | encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor) |
| | | encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf) |