| | |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.modules.subsampling import Conv1dSubsampling |
| | | from funasr.torch_utils.initialize import initialize |
| | | from funasr.train.class_choices import ClassChoices |
| | | |
| | | frontend_choices = ClassChoices( |
| | |
| | | optional=True, |
| | | ) |
| | | class_choices_list = [ |
| | | # --frontend and --frontend_conf |
| | | frontend_choices, |
| | | # --specaug and --specaug_conf |
| | | specaug_choices, |
| | | # --normalize and --normalize_conf |
| | | normalize_choices, |
| | | # --model and --model_conf |
| | | model_choices, |
| | | # --encoder and --encoder_conf |
| | | encoder_choices, |
| | | # --decoder and --decoder_conf |
| | | decoder_choices, |
| | | # --predictor and --predictor_conf |
| | | predictor_choices, |
| | | # --encoder2 and --encoder2_conf |
| | | encoder_choices2, |
| | | # --decoder2 and --decoder2_conf |
| | | decoder_choices2, |
| | | # --predictor2 and --predictor2_conf |
| | | predictor_choices2, |
| | | # --stride_conv and --stride_conv_conf |
| | | stride_conv_choices, |
| | | ] |
| | | # --frontend and --frontend_conf |
| | | frontend_choices, |
| | | # --specaug and --specaug_conf |
| | | specaug_choices, |
| | | # --normalize and --normalize_conf |
| | | normalize_choices, |
| | | # --model and --model_conf |
| | | model_choices, |
| | | # --encoder and --encoder_conf |
| | | encoder_choices, |
| | | # --decoder and --decoder_conf |
| | | decoder_choices, |
| | | # --predictor and --predictor_conf |
| | | predictor_choices, |
| | | # --encoder2 and --encoder2_conf |
| | | encoder_choices2, |
| | | # --decoder2 and --decoder2_conf |
| | | decoder_choices2, |
| | | # --predictor2 and --predictor2_conf |
| | | predictor_choices2, |
| | | # --stride_conv and --stride_conv_conf |
| | | stride_conv_choices, |
| | | ] |
| | | |
| | | |
| | | def build_asr_model(args): |
| | | # token_list |
| | |
| | | # predictor |
| | | predictor_class = predictor_choices.get_class(args.predictor) |
| | | predictor = predictor_class(**args.predictor_conf) |
| | | |
| | | model_class = model_choices.get_class(args.model) |
| | | model = model_class( |
| | | vocab_size=vocab_size, |
| | |
| | | predictor=predictor, |
| | | **args.model_conf, |
| | | ) |
| | | elif |
| | | elif args.model == "uniasr": |
| | | # stride_conv |
| | | stride_conv_class = stride_conv_choices.get_class(args.stride_conv) |
| | | stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(), |
| | | odim=input_size + encoder.output_size()) |
| | | stride_conv_output_size = stride_conv.output_size() |
| | | |
| | | # encoder2 |
| | | encoder_class2 = encoder_choices2.get_class(args.encoder2) |
| | | encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf) |
| | | |
| | | # decoder2 |
| | | decoder_class2 = decoder_choices2.get_class(args.decoder2) |
| | | decoder2 = decoder_class2( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder2.output_size(), |
| | | **args.decoder2_conf, |
| | | ) |
| | | |
| | | # ctc2 |
| | | ctc2 = CTC( |
| | | odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf |
| | | ) |
| | | |
| | | # predictor |
| | | predictor_class = predictor_choices.get_class(args.predictor) |
| | | predictor = predictor_class(**args.predictor_conf) |
| | | |
| | | # predictor2 |
| | | predictor_class = predictor_choices2.get_class(args.predictor2) |
| | | predictor2 = predictor_class(**args.predictor2_conf) |
| | | |
| | | model_class = model_choices.get_class(args.model) |
| | | model = model_class( |
| | | vocab_size=vocab_size, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | ctc=ctc, |
| | | token_list=token_list, |
| | | predictor=predictor, |
| | | ctc2=ctc2, |
| | | encoder2=encoder2, |
| | | decoder2=decoder2, |
| | | predictor2=predictor2, |
| | | stride_conv=stride_conv, |
| | | **args.model_conf, |
| | | ) |
| | | |
| | | else: |
| | | raise NotImplementedError("Not supported model: {}".format(args.model)) |
| | | |
| | | if args.init is not None: |
| | | initialize(model, args.init) |