From d6cc6896e4d55498d6d36331b5c661579906525f Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 20 四月 2023 16:33:30 +0800
Subject: [PATCH] update

---
 funasr/bin/train.py        |   10 +++++
 funasr/utils/build_args.py |   87 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 97 insertions(+), 0 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 8acd37c..c6f19b6 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,3 +1,4 @@
+import argparse
 import logging
 import os
 import sys
@@ -9,6 +10,7 @@
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
+from funasr.utils.build_args import build_args
 from funasr.utils.build_dataloader import build_dataloader
 from funasr.utils.build_distributed import build_distributed
 from funasr.utils.build_model import build_model
@@ -272,6 +274,12 @@
         action="append",
         default=[],
     )
+    parser.add_argument(
+        "--use_preprocessor",
+        type=str2bool,
+        default=True,
+        help="Apply preprocessing to data or not",
+    )
 
     # pai related
     parser.add_argument(
@@ -330,6 +338,8 @@
 if __name__ == '__main__':
     parser = get_parser()
     args = parser.parse_args()
+    task_args = build_args(args)
+    args = argparse.Namespace(**vars(args), **vars(task_args))
 
     # set random seed
     set_all_random_seed(args.seed)
diff --git a/funasr/utils/build_args.py b/funasr/utils/build_args.py
new file mode 100644
index 0000000..1baf2d6
--- /dev/null
+++ b/funasr/utils/build_args.py
@@ -0,0 +1,87 @@
+import argparse
+
+from funasr.models.ctc import CTC
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+
+def build_args(args):
+    parser = argparse.ArgumentParser("Task related config")
+    if args.task_name == "asr":
+        from funasr.utils.build_asr_model import class_choices_list
+        for class_choices in class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(parser)
+        parser.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
+        parser.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
+        parser.add_argument(
+            "--seg_dict_file",
+            type=str,
+            default=None,
+            help="seg_dict_file for text processing",
+        )
+        parser.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+        parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+        parser.add_argument(
+            "--ctc_conf",
+            action=NestedDictAction,
+            default=get_default_kwargs(CTC),
+            help="The keyword arguments for CTC class.",
+        )
+        parser.add_argument(
+            "--token_type",
+            type=str,
+            default="bpe",
+            choices=["bpe", "char", "word", "phn"],
+            help="The text will be tokenized " "in the specified level token",
+        )
+        parser.add_argument(
+            "--bpemodel",
+            type=str_or_none,
+            default=None,
+            help="The model file of sentencepiece",
+        )
+        parser.add_argument(
+            "--cmvn_file",
+            type=str_or_none,
+            default=None,
+            help="The file path of noise scp file.",
+        )
+
+    else:
+        raise NotImplementedError("Not supported task: {}".format(args.task_name))
+
+    args = parser.parse_args()
+    return args

--
Gitblit v1.9.1