aky15
2023-04-12 28a19dbc4e85d3b8a4ec2ef7483bba64d422b43f
funasr/tasks/asr_transducer.py
@@ -21,15 +21,13 @@
    LightweightConvolutionTransformerDecoder,
    TransformerDecoder,
)
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
from funasr.models_transducer.encoder.encoder import Encoder
from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
from funasr.models_transducer.joint_network import JointNetwork
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
from funasr.models.e2e_transducer import TransducerModel
from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
from funasr.models.joint_network import JointNetwork
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
@@ -75,7 +73,6 @@
        "encoder",
        classes=dict(
                encoder=Encoder,
                sanm_chunk_opt=SANMEncoderChunkOpt,
        ),
        default="encoder",
)
@@ -138,6 +135,12 @@
            help="Integer-string mapper for tokens.",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="whether to split text using <space>",
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
@@ -152,7 +155,7 @@
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(ESPnetASRTransducerModel),
            default=get_default_kwargs(TransducerModel),
            help="The keyword arguments for the model class.",
        )
        # group.add_argument(
@@ -289,6 +292,7 @@
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
@@ -347,7 +351,7 @@
        return retval
    @classmethod
    def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
        """Required data depending on task mode.
        Args:
            cls: ASRTransducerTask object.
@@ -433,22 +437,8 @@
        # 7. Build model
        if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
            model = UniASRTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )
        elif encoder.unified_model_training:
            model = ESPnetASRUnifiedTransducerModel(
        if encoder.unified_model_training:
            model = UnifiedTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
@@ -462,7 +452,7 @@
            )
        else:
            model = ESPnetASRTransducerModel(
            model = TransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,