From eac9f111b502e4581b14dc718731bf7dc1c7d5f6 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 20 四月 2023 16:59:26 +0800
Subject: [PATCH] update

---
 funasr/utils/build_args.py |   57 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 56 insertions(+), 1 deletions(-)

diff --git a/funasr/utils/build_args.py b/funasr/utils/build_args.py
index 1baf2d6..f57f495 100644
--- a/funasr/utils/build_args.py
+++ b/funasr/utils/build_args.py
@@ -79,7 +79,62 @@
             default=None,
             help="The file path of noise scp file.",
         )
-
+    elif args.task_name == "pretrain":
+        from funasr.utils.build_pretrain_model import class_choices_list
+        for class_choices in class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(parser)
+        parser.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+        parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+        parser.add_argument(
+            "--feats_type",
+            type=str,
+            default='fbank',
+            help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of noise decibel level.",
+        )
+        parser.add_argument(
+            "--pred_masked_weight",
+            type=float,
+            default=1.0,
+            help="weight for predictive loss for masked frames",
+        )
+        parser.add_argument(
+            "--pred_nomask_weight",
+            type=float,
+            default=0.0,
+            help="weight for predictive loss for unmasked frames",
+        )
+        parser.add_argument(
+            "--loss_weights",
+            type=float,
+            default=0.0,
+            help="weights for additional loss terms (not first one)",
+        )
     else:
         raise NotImplementedError("Not supported task: {}".format(args.task_name))
 

--
Gitblit v1.9.1