speech_asr
2023-04-20 eac9f111b502e4581b14dc718731bf7dc1c7d5f6
funasr/utils/build_args.py
@@ -79,7 +79,62 @@
            default=None,
            help="The file path of noise scp file.",
        )
    elif args.task_name == "pretrain":
        # Pre-training task: register the model-component options first, then
        # the pretrain-specific command-line flags below.
        # The import is local to this branch, so the pretrain model builder is
        # only loaded when this task is selected.
        from funasr.utils.build_pretrain_model import class_choices_list
        for class_choices in class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(parser)
        # Parameter initialization scheme. The raw string is lower-cased before
        # being handed to str_or_none, so the choices below match regardless of
        # the case the user typed.
        # NOTE(review): presumably str_or_none maps the text "none" to None,
        # which is why None is listed among the choices -- confirm.
        parser.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        # Input feature dimension; defaults to None when the flag is omitted.
        parser.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimension of the feature",
        )
        # Kind of input features; free-form string, not restricted via choices.
        parser.add_argument(
            "--feats_type",
            type=str,
            default='fbank',
            help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
        )
        # Kept as a plain string here; any parsing of the value happens elsewhere.
        # NOTE(review): the default "13_15" looks like a "<low>_<high>" pair --
        # confirm the expected format against the consumer of this option.
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )
        # Loss weighting: by default only masked frames contribute to the
        # predictive loss (masked weight 1.0, unmasked weight 0.0).
        parser.add_argument(
            "--pred_masked_weight",
            type=float,
            default=1.0,
            help="weight for predictive loss for masked frames",
        )
        parser.add_argument(
            "--pred_nomask_weight",
            type=float,
            default=0.0,
            help="weight for predictive loss for unmasked frames",
        )
        # NOTE(review): the help text speaks of "weights" (plural) but the
        # option accepts a single float -- confirm whether a list was intended.
        parser.add_argument(
            "--loss_weights",
            type=float,
            default=0.0,
            help="weights for additional loss terms (not first one)",
        )
    else:
        # Any task name not handled by a branch above is rejected explicitly.
        raise NotImplementedError("Not supported task: {}".format(args.task_name))