嘉渊
2023-05-18 17eaf419c05853a4ecb8dfd3a0e8ebf26a1dfb1b
Merge branch 'dev_infer' of https://github.com/alibaba/FunASR into dev_infer
3个文件已修改
49 ■■■■ 已修改文件
funasr/bin/asr_infer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/build_utils/build_asr_model.py 21 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_infer.py
@@ -1581,7 +1581,7 @@
            d = ModelDownloader()
            kwargs.update(**d.download_and_unpack(model_tag))
        
        return Speech2Text(**kwargs)
        return Speech2TextTransducer(**kwargs)
class Speech2TextSAASR:
funasr/build_utils/build_asr_model.py
@@ -87,6 +87,8 @@
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
        rnnt=TransducerModel,
        rnnt_unified=UnifiedTransducerModel,
    ),
    default="asr",
)
@@ -367,7 +369,7 @@
            token_list=token_list,
            **args.model_conf,
        )
    elif args.model == "rnnt":
    elif args.model == "rnnt" or args.model == "rnnt_unified":
        # 5. Decoder
        encoder_output_size = encoder.output_size()
@@ -396,9 +398,9 @@
            **args.joint_network_conf,
        )
        model_class = model_choices.get_class(args.model)
        # 7. Build model
        if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
            model = UnifiedTransducerModel(
        model = model_class(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
@@ -411,19 +413,6 @@
                **args.model_conf,
            )
        else:
            model = TransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )
    else:
        raise NotImplementedError("Not supported model: {}".format(args.model))
funasr/tasks/asr.py
@@ -132,6 +132,8 @@
        neatcontextual_paraformer=NeatContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
        rnnt=TransducerModel,
        rnnt_unified=UnifiedTransducerModel,
    ),
    type_check=FunASRModel,
    default="asr",
@@ -1453,7 +1455,7 @@
        decoder_output_size = decoder.output_size
        if getattr(args, "decoder", None) is not None:
            att_decoder_class = decoder_choices.get_class(args.att_decoder)
            att_decoder_class = decoder_choices.get_class(args.decoder)
            att_decoder = att_decoder_class(
                vocab_size=vocab_size,
@@ -1471,9 +1473,12 @@
        )
        # 7. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
            model = UnifiedTransducerModel(
        model = model_class(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
@@ -1485,21 +1490,6 @@
                joint_network=joint_network,
                **args.model_conf,
            )
        else:
            model = TransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )
        # 8. Initialize model
        if args.init is not None:
            raise NotImplementedError(