speech_asr
2023-04-20 eac9f111b502e4581b14dc718731bf7dc1c7d5f6
funasr/utils/build_pretrain_model.py
@@ -57,39 +57,39 @@
def build_pretrain_model(args):
    # Build a pretraining model from the options carried on `args`.
    #
    # Reads from args: input_size, frontend, frontend_conf, specaug,
    # specaug_conf, normalize, normalize_conf, encoder, encoder_conf,
    # model_name, init (and `model` in the error message of the else branch).
    # Writes to args: when args.input_size is given, args.frontend is reset to
    # None and args.frontend_conf to {}.
    # Returns: the constructed model object.
    # Raises: NotImplementedError when args.model_name is not "data2vec".
    #
    # NOTE(review): this text looks like a flattened diff — the component
    # construction below appears twice (once at function level, once inside
    # the "data2vec" branch) and a hunk header sits inside the model call.
    # Presumably only one copy of each section belongs in the real file;
    # confirm against the repository before relying on this listing.

    # frontend: only built when no explicit input size is supplied; the encoder
    # input size is then taken from the frontend's reported output size.
    if args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # An explicit input size disables the frontend; the frontend options
        # on args are cleared as well (caller-visible mutation).
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size
    # data augmentation for spectrogram (optional; None when not configured)
    if args.specaug is not None:
        specaug_class = specaug_choices.get_class(args.specaug)
        specaug = specaug_class(**args.specaug_conf)
    else:
        specaug = None
    # normalization layer (optional; None when not configured)
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None
    # encoder: always built, sized by input_size resolved above
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(
        input_size=input_size,
        **args.encoder_conf,
    )
    if args.model_name == "data2vec":
        # NOTE(review): everything from here down to the encoder construction
        # repeats the function-level section above statement for statement,
        # so each component is instantiated a second time and the first
        # instances are discarded.
        # frontend
        if args.input_size is None:
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size
        # data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None
        # normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None
        # encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(
            input_size=input_size,
            **args.encoder_conf,
        )
        # The model class is looked up by the literal name "data2vec" and
        # assembled from the components built above.
        model_class = model_choices.get_class("data2vec")
        model = model_class(
            frontend=frontend,
            # NOTE(review): the next line is a diff hunk header, not Python;
            # source lines between the two hunks are not visible here, so the
            # full argument list of this call cannot be confirmed (specaug is
            # built above but is not passed in the visible lines).
@@ -97,9 +97,11 @@
            normalize=normalize,
            encoder=encoder,
        )
    else:
        # NOTE(review): the branch condition tests args.model_name, but the
        # message formats args.model — confirm that attribute exists on args,
        # otherwise this raises AttributeError instead of NotImplementedError.
        raise NotImplementedError("Not supported model: {}".format(args.model))
        # NOTE(review): the three lines below follow the unconditional raise
        # and can never execute.
        # 7. Initialize
        if args.init is not None:
            initialize(model, args.init)
    # initialize: apply the scheme named by args.init, when one is given
    if args.init is not None:
        initialize(model, args.init)
        # NOTE(review): redundant with the return below; both paths return
        # the same object.
        return model
    return model