嘉渊
2023-04-24 df5f263e5fe3d7961b1aeb3589012400a9905a8f
funasr/build_utils/build_args.py
@@ -13,15 +13,7 @@
    if args.task_name == "asr":
        from funasr.build_utils.build_asr_model import class_choices_list
        for class_choices in class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(parser)
        parser.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        parser.add_argument(
            "--split_with_space",
            type=str2bool,
@@ -33,20 +25,6 @@
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        parser.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        parser.add_argument(
            "--input_size",
@@ -61,80 +39,49 @@
            help="The keyword arguments for CTC class.",
        )
        parser.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized " "in the specified level token",
        )
        parser.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of noise scp file.",
        )
    elif args.task_name == "pretrain":
        from funasr.build_utils.build_pretrain_model import class_choices_list
        for class_choices in class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(parser)
        parser.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        parser.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )
    elif args.task_name == "lm":
        from funasr.build_utils.build_lm_model import class_choices_list
        for class_choices in class_choices_list:
            class_choices.add_arguments(parser)
    elif args.task_name == "punc":
        from funasr.build_utils.build_punc_model import class_choices_list
        for class_choices in class_choices_list:
            class_choices.add_arguments(parser)
    elif args.task_name == "vad":
        from funasr.build_utils.build_vad_model import class_choices_list
        for class_choices in class_choices_list:
            class_choices.add_arguments(parser)
        parser.add_argument(
            "--feats_type",
            type=str,
            default='fbank',
            help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )
        parser.add_argument(
            "--pred_masked_weight",
            type=float,
            default=1.0,
            help="weight for predictive loss for masked frames",
        )
        parser.add_argument(
            "--pred_nomask_weight",
            type=float,
            default=0.0,
            help="weight for predictive loss for unmasked frames",
        )
        parser.add_argument(
            "--loss_weights",
            type=float,
            default=0.0,
            help="weights for additional loss terms (not first one)",
        )
    elif args.task_name == "diar":
        from funasr.build_utils.build_diar_model import class_choices_list
        for class_choices in class_choices_list:
            class_choices.add_arguments(parser)
    else:
        raise NotImplementedError("Not supported task: {}".format(args.task_name))