| | |
| | | d = ModelDownloader() |
| | | kwargs.update(**d.download_and_unpack(model_tag)) |
| | | |
| | | return Speech2Text(**kwargs) |
| | | return Speech2TextTransducer(**kwargs) |
| | | |
| | | |
| | | class Speech2TextSAASR: |
| | |
| | | contextual_paraformer=ContextualParaformer, |
| | | mfcca=MFCCA, |
| | | timestamp_prediction=TimestampPredictor, |
| | | rnnt=TransducerModel, |
| | | rnnt_unified=UnifiedTransducerModel, |
| | | ), |
| | | default="asr", |
| | | ) |
| | |
| | | token_list=token_list, |
| | | **args.model_conf, |
| | | ) |
| | | elif args.model == "rnnt": |
| | | elif args.model == "rnnt" or args.model == "rnnt_unified": |
| | | # 5. Decoder |
| | | encoder_output_size = encoder.output_size() |
| | | |
| | |
| | | **args.joint_network_conf, |
| | | ) |
| | | |
| | | model_class = model_choices.get_class(args.model) |
| | | # 7. Build model |
| | | if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training: |
| | | model = UnifiedTransducerModel( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | att_decoder=att_decoder, |
| | | joint_network=joint_network, |
| | | **args.model_conf, |
| | | ) |
| | | model = model_class( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | att_decoder=att_decoder, |
| | | joint_network=joint_network, |
| | | **args.model_conf, |
| | | ) |
| | | |
| | | else: |
| | | model = TransducerModel( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | att_decoder=att_decoder, |
| | | joint_network=joint_network, |
| | | **args.model_conf, |
| | | ) |
| | | else: |
| | | raise NotImplementedError("Not supported model: {}".format(args.model)) |
| | | |
| | |
| | | neatcontextual_paraformer=NeatContextualParaformer, |
| | | mfcca=MFCCA, |
| | | timestamp_prediction=TimestampPredictor, |
| | | rnnt=TransducerModel, |
| | | rnnt_unified=UnifiedTransducerModel, |
| | | ), |
| | | type_check=FunASRModel, |
| | | default="asr", |
| | |
| | | decoder_output_size = decoder.output_size |
| | | |
| | | if getattr(args, "decoder", None) is not None: |
| | | att_decoder_class = decoder_choices.get_class(args.att_decoder) |
| | | att_decoder_class = decoder_choices.get_class(args.decoder) |
| | | |
| | | att_decoder = att_decoder_class( |
| | | vocab_size=vocab_size, |
| | |
| | | ) |
| | | |
| | | # 7. Build model |
| | | try: |
| | | model_class = model_choices.get_class(args.model) |
| | | except AttributeError: |
| | | model_class = model_choices.get_class("asr") |
| | | |
| | | if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training: |
| | | model = UnifiedTransducerModel( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | att_decoder=att_decoder, |
| | | joint_network=joint_network, |
| | | **args.model_conf, |
| | | ) |
| | | |
| | | else: |
| | | model = TransducerModel( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | att_decoder=att_decoder, |
| | | joint_network=joint_network, |
| | | **args.model_conf, |
| | | ) |
| | | |
| | | model = model_class( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | att_decoder=att_decoder, |
| | | joint_network=joint_network, |
| | | **args.model_conf, |
| | | ) |
| | | # 8. Initialize model |
| | | if args.init is not None: |
| | | raise NotImplementedError( |