| | |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.datasets.collate_fn import CommonCollateFn |
| | | from funasr.datasets.collate_fn import DiarCollateFn |
| | | from funasr.datasets.preprocessor import CommonPreprocessor |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.layers.global_mvn import GlobalMVN |
| | | from funasr.layers.utterance_mvn import UtteranceMVN |
| | | from funasr.layers.label_aggregation import LabelAggregate |
| | | from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar |
| | | from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN |
| | |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.models.specaug.abs_profileaug import AbsProfileAug |
| | | from funasr.models.specaug.profileaug import ProfileAug |
| | | from funasr.tasks.abs_task import AbsTask |
| | | from funasr.torch_utils.initialize import initialize |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.train.class_choices import ClassChoices |
| | | from funasr.train.trainer import Trainer |
| | | from funasr.utils.types import float_or_none |
| | |
| | | default=None, |
| | | optional=True, |
| | | ) |
| | | profileaug_choices = ClassChoices( |
| | | name="profileaug", |
| | | classes=dict( |
| | | profileaug=ProfileAug, |
| | | ), |
| | | type_check=AbsProfileAug, |
| | | default=None, |
| | | optional=True, |
| | | ) |
| | | normalize_choices = ClassChoices( |
| | | "normalize", |
| | | classes=dict( |
| | |
| | | label_aggregator_choices = ClassChoices( |
| | | "label_aggregator", |
| | | classes=dict( |
| | | label_aggregator=LabelAggregate |
| | | label_aggregator=LabelAggregate, |
| | | label_aggregator_max_pool=LabelAggregateMaxPooling, |
| | | ), |
| | | type_check=torch.nn.Module, |
| | | default=None, |
| | |
| | | classes=dict( |
| | | sond=DiarSondModel, |
| | | ), |
| | | type_check=AbsESPnetModel, |
| | | type_check=torch.nn.Module, |
| | | default="sond", |
| | | ) |
| | | encoder_choices = ClassChoices( |
| | |
| | | frontend_choices, |
| | | # --specaug and --specaug_conf |
| | | specaug_choices, |
| | | # --profileaug and --profileaug_conf |
| | | profileaug_choices, |
| | | # --normalize and --normalize_conf |
| | | normalize_choices, |
| | | # --label_aggregator and --label_aggregator_conf |
| | |
| | | ]: |
| | | assert check_argument_types() |
| | | # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol |
| | | return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) |
| | | return DiarCollateFn(float_pad_value=0.0, int_pad_value=-1) |
| | | |
| | | @classmethod |
| | | def build_preprocess_fn( |
| | |
| | | return retval |
| | | |
| | | @classmethod |
| | | def build_optimizers( |
| | | cls, |
| | | args: argparse.Namespace, |
| | | model: torch.nn.Module, |
| | | ) -> List[torch.optim.Optimizer]: |
| | | if cls.num_optimizers != 1: |
| | | raise RuntimeError( |
| | | "build_optimizers() must be overridden if num_optimizers != 1" |
| | | ) |
| | | from funasr.tasks.abs_task import optim_classes |
| | | optim_class = optim_classes.get(args.optim) |
| | | if optim_class is None: |
| | | raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}") |
| | | else: |
| | | if (hasattr(model, "model_regularizer_weight") and |
| | | model.model_regularizer_weight > 0.0 and |
| | | hasattr(model, "get_regularize_parameters") |
| | | ): |
| | | to_regularize_parameters, normal_parameters = model.get_regularize_parameters() |
| | | logging.info(f"Set weight decay {model.model_regularizer_weight} for parameters: " |
| | | f"{[name for name, value in to_regularize_parameters]}") |
| | | module_optim_config = [ |
| | | {"params": [value for name, value in to_regularize_parameters], |
| | | "weight_decay": model.model_regularizer_weight}, |
| | | {"params": [value for name, value in normal_parameters], |
| | | "weight_decay": 0.0} |
| | | ] |
| | | optim = optim_class(module_optim_config, **args.optim_conf) |
| | | else: |
| | | optim = optim_class(model.parameters(), **args.optim_conf) |
| | | |
| | | optimizers = [optim] |
| | | return optimizers |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | |
| | | specaug = specaug_class(**args.specaug_conf) |
| | | else: |
| | | specaug = None |
| | | |
| | | # 2b. Data augmentation for Profiles |
| | | if hasattr(args, "profileaug") and args.profileaug is not None: |
| | | profileaug_class = profileaug_choices.get_class(args.profileaug) |
| | | profileaug = profileaug_class(**args.profileaug_conf) |
| | | else: |
| | | profileaug = None |
| | | |
| | | # 3. Normalization layer |
| | | if args.normalize is not None: |
| | |
| | | vocab_size=vocab_size, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | profileaug=profileaug, |
| | | normalize=normalize, |
| | | label_aggregator=label_aggregator, |
| | | encoder=encoder, |
| | |
| | | # 10. Initialize |
| | | if args.init is not None: |
| | | initialize(model, args.init) |
| | | logging.info(f"Init model parameters with {args.init}.") |
| | | |
| | | assert check_return_type(model) |
| | | return model |
| | |
| | | args["cmvn_file"] = cmvn_file |
| | | args = argparse.Namespace(**args) |
| | | model = cls.build_model(args) |
| | | if not isinstance(model, AbsESPnetModel): |
| | | if not isinstance(model, torch.nn.Module): |
| | | raise RuntimeError( |
| | | f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
| | | f"model must inherit {torch.nn.Module.__name__}, but got {type(model)}" |
| | | ) |
| | | model.to(device) |
| | | model_dict = dict() |
| | |
| | | model_dict = torch.load(model_name_pth, map_location=device) |
| | | else: |
| | | model_dict = cls.convert_tf2torch(model, model_file) |
| | | model.load_state_dict(model_dict) |
| | | # model.load_state_dict(model_dict) |
| | | else: |
| | | model_dict = torch.load(model_file, map_location=device) |
| | | model_dict = cls.fileter_model_dict(model_dict, model.state_dict()) |