| | |
| | | odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf |
| | | ) |
| | | |
| | | if args.model == "asr": |
| | | model |
| | | |
| | | |
| | | if args.model in ["asr", "mfcca"]: |
| | | model_class = model_choices.get_class(args.model) |
| | | model = model_class( |
| | | vocab_size=vocab_size, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | ctc=ctc, |
| | | token_list=token_list, |
| | | **args.model_conf, |
| | | ) |
| | | elif args.model == "paraformer": |
| | | # predictor |
| | | predictor_class = predictor_choices.get_class(args.predictor) |
| | | predictor = predictor_class(**args.predictor_conf) |
| | | model_class = model_choices.get_class(args.model) |
| | | model = model_class( |
| | | vocab_size=vocab_size, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | ctc=ctc, |
| | | token_list=token_list, |
| | | predictor=predictor, |
| | | **args.model_conf, |
| | | ) |
| | | elif |