speech_asr
2023-04-20 200d1ede05e6bc41ef1da6debf7b86df84995fb5
update
1 file changed
106 lines changed
funasr/utils/build_asr_model.py — 106 lines changed (patch | view | raw | blame | history)
funasr/utils/build_asr_model.py
@@ -40,6 +40,7 @@
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
frontend_choices = ClassChoices(
@@ -171,29 +172,30 @@
    optional=True,
)
# ClassChoices registries exposed by this module. Per the comment above each
# entry, every registry backs a `--<name>` / `--<name>_conf` option pair.
# Each registry is listed exactly once; a second copy would duplicate its
# option pair.
class_choices_list = [
    # --frontend and --frontend_conf
    frontend_choices,
    # --specaug and --specaug_conf
    specaug_choices,
    # --normalize and --normalize_conf
    normalize_choices,
    # --model and --model_conf
    model_choices,
    # --encoder and --encoder_conf
    encoder_choices,
    # --decoder and --decoder_conf
    decoder_choices,
    # --predictor and --predictor_conf
    predictor_choices,
    # --encoder2 and --encoder2_conf
    encoder_choices2,
    # --decoder2 and --decoder2_conf
    decoder_choices2,
    # --predictor2 and --predictor2_conf
    predictor_choices2,
    # --stride_conv and --stride_conv_conf
    stride_conv_choices,
]
def build_asr_model(args):
    # token_list
@@ -270,6 +272,7 @@
        # predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
        model_class = model_choices.get_class(args.model)
        model = model_class(
            vocab_size=vocab_size,
@@ -283,4 +286,59 @@
            predictor=predictor,
            **args.model_conf,
        )
    elif
    elif args.model == "uniasr":
        # stride_conv
        stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
        stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
                                        odim=input_size + encoder.output_size())
        stride_conv_output_size = stride_conv.output_size()
        # encoder2
        encoder_class2 = encoder_choices2.get_class(args.encoder2)
        encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
        # decoder2
        decoder_class2 = decoder_choices2.get_class(args.decoder2)
        decoder2 = decoder_class2(
            vocab_size=vocab_size,
            encoder_output_size=encoder2.output_size(),
            **args.decoder2_conf,
        )
        # ctc2
        ctc2 = CTC(
            odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
        )
        # predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
        # predictor2
        predictor_class = predictor_choices2.get_class(args.predictor2)
        predictor2 = predictor_class(**args.predictor2_conf)
        model_class = model_choices.get_class(args.model)
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            ctc2=ctc2,
            encoder2=encoder2,
            decoder2=decoder2,
            predictor2=predictor2,
            stride_conv=stride_conv,
            **args.model_conf,
        )
    else:
        raise NotImplementedError("Not supported model: {}".format(args.model))
    if args.init is not None:
        initialize(model, args.init)