| | |
| | | from funasr.models.e2e_asr_mfcca import MFCCA |
| | | |
| | | from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel |
| | | from funasr.models.e2e_asr_bat import BATModel |
| | | |
| | | from funasr.models.e2e_sa_asr import SAASRModel |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer |
| | |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.frontend.windowing import SlidingWindow |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3 |
| | | from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.models.specaug.specaug import SpecAugLFR |
| | | from funasr.modules.subsampling import Conv1dSubsampling |
| | |
| | | rnnt=TransducerModel, |
| | | rnnt_unified=UnifiedTransducerModel, |
| | | sa_asr=SAASRModel, |
| | | |
| | | bat=BATModel, |
| | | ), |
| | | default="asr", |
| | | ) |
| | |
| | | ctc_predictor=None, |
| | | cif_predictor_v2=CifPredictorV2, |
| | | cif_predictor_v3=CifPredictorV3, |
| | | bat_predictor=BATPredictor, |
| | | ), |
| | | default="cif_predictor", |
| | | optional=True, |
| | |
| | | encoder = encoder_class(input_size=input_size, **args.encoder_conf) |
| | | |
| | | # decoder |
| | | decoder_class = decoder_choices.get_class(args.decoder) |
| | | decoder = decoder_class( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder.output_size(), |
| | | **args.decoder_conf, |
| | | ) |
| | | if hasattr(args, "decoder") and args.decoder is not None: |
| | | decoder_class = decoder_choices.get_class(args.decoder) |
| | | decoder = decoder_class( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder.output_size(), |
| | | **args.decoder_conf, |
| | | ) |
| | | else: |
| | | decoder = None |
| | | |
| | | # ctc |
| | | ctc = CTC( |
| | |
| | | joint_network=joint_network, |
| | | **args.model_conf, |
| | | ) |
| | | elif args.model == "bat": |
| | | # 5. Decoder |
| | | encoder_output_size = encoder.output_size() |
| | | |
| | | rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder) |
| | | decoder = rnnt_decoder_class( |
| | | vocab_size, |
| | | **args.rnnt_decoder_conf, |
| | | ) |
| | | decoder_output_size = decoder.output_size |
| | | |
| | | if getattr(args, "decoder", None) is not None: |
| | | att_decoder_class = decoder_choices.get_class(args.decoder) |
| | | |
| | | att_decoder = att_decoder_class( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder_output_size, |
| | | **args.decoder_conf, |
| | | ) |
| | | else: |
| | | att_decoder = None |
| | | # 6. Joint Network |
| | | joint_network = JointNetwork( |
| | | vocab_size, |
| | | encoder_output_size, |
| | | decoder_output_size, |
| | | **args.joint_network_conf, |
| | | ) |
| | | |
| | | predictor_class = predictor_choices.get_class(args.predictor) |
| | | predictor = predictor_class(**args.predictor_conf) |
| | | |
| | | model_class = model_choices.get_class(args.model) |
| | | # 7. Build model |
| | | model = model_class( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |
| | | frontend=frontend, |
| | | specaug=specaug, |
| | | normalize=normalize, |
| | | encoder=encoder, |
| | | decoder=decoder, |
| | | att_decoder=att_decoder, |
| | | joint_network=joint_network, |
| | | predictor=predictor, |
| | | **args.model_conf, |
| | | ) |
| | | elif args.model == "sa_asr": |
| | | asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder) |
| | | asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf) |