嘉渊
2023-05-18 17eaf419c05853a4ecb8dfd3a0e8ebf26a1dfb1b
funasr/tasks/asr.py
@@ -132,6 +132,8 @@
        neatcontextual_paraformer=NeatContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
        rnnt=TransducerModel,
        rnnt_unified=UnifiedTransducerModel,
    ),
    type_check=FunASRModel,
    default="asr",
@@ -1453,7 +1455,7 @@
        decoder_output_size = decoder.output_size
        if getattr(args, "decoder", None) is not None:
            att_decoder_class = decoder_choices.get_class(args.att_decoder)
            att_decoder_class = decoder_choices.get_class(args.decoder)
            att_decoder = att_decoder_class(
                vocab_size=vocab_size,
@@ -1471,35 +1473,23 @@
        )
        # 7. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
            model = UnifiedTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )
        else:
            model = TransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )
        model = model_class(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
            decoder=decoder,
            att_decoder=att_decoder,
            joint_network=joint_network,
            **args.model_conf,
        )
        # 8. Initialize model
        if args.init is not None:
            raise NotImplementedError(