speech_asr
2023-04-20 200d1ede05e6bc41ef1da6debf7b86df84995fb5
update
1个文件已修改
60 ■■■■■ 已修改文件
funasr/utils/build_asr_model.py 60 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/build_asr_model.py
@@ -40,6 +40,7 @@
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
frontend_choices = ClassChoices(
@@ -195,6 +196,7 @@
        stride_conv_choices,
    ]
def build_asr_model(args):
    # token_list
    if args.token_list is not None:
@@ -270,6 +272,7 @@
        # predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
        model_class = model_choices.get_class(args.model)
        model = model_class(
            vocab_size=vocab_size,
@@ -283,4 +286,59 @@
            predictor=predictor,
            **args.model_conf,
        )
    elif
    elif args.model == "uniasr":
        # stride_conv
        stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
        stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
                                        odim=input_size + encoder.output_size())
        stride_conv_output_size = stride_conv.output_size()
        # encoder2
        encoder_class2 = encoder_choices2.get_class(args.encoder2)
        encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
        # decoder2
        decoder_class2 = decoder_choices2.get_class(args.decoder2)
        decoder2 = decoder_class2(
            vocab_size=vocab_size,
            encoder_output_size=encoder2.output_size(),
            **args.decoder2_conf,
        )
        # ctc2
        ctc2 = CTC(
            odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
        )
        # predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
        # predictor2
        predictor_class = predictor_choices2.get_class(args.predictor2)
        predictor2 = predictor_class(**args.predictor2_conf)
        model_class = model_choices.get_class(args.model)
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            ctc2=ctc2,
            encoder2=encoder2,
            decoder2=decoder2,
            predictor2=predictor2,
            stride_conv=stride_conv,
            **args.model_conf,
        )
    else:
        raise NotImplementedError("Not supported model: {}".format(args.model))
    if args.init is not None:
        initialize(model, args.init)