From b9d1425028e480aa2c8dbd3502207e443dcd2060 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 25 四月 2023 01:09:03 +0800
Subject: [PATCH] update

---
 funasr/bin/train.py |  102 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 files changed, 95 insertions(+), 7 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 1518071..d3ebaac 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -19,15 +19,17 @@
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
+from funasr.utils.nested_dict_action import NestedDictAction
 from funasr.utils.prepare_data import prepare_data
+from funasr.utils.types import int_or_none
 from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
 
 
 def get_parser():
-    parser = config_argparse.ArgumentParser(
+    parser = argparse.ArgumentParser(
         description="FunASR Common Training Parser",
     )
 
@@ -58,18 +60,51 @@
     )
     parser.add_argument(
         "--dist_world_size",
-        default=None,
+        type=int,
+        default=1,
         help="number of nodes for distributed training",
     )
     parser.add_argument(
         "--dist_rank",
+        type=int,
         default=None,
         help="node rank for distributed training",
     )
     parser.add_argument(
         "--local_rank",
+        type=int,
         default=None,
         help="local rank for distributed training",
+    )
+    parser.add_argument(
+        "--dist_master_addr",
+        default=None,
+        type=str_or_none,
+        help="The master address for distributed training. "
+             "This value is used when dist_init_method == 'env://'",
+    )
+    parser.add_argument(
+        "--dist_master_port",
+        default=None,
+        type=int_or_none,
+        help="The master port for distributed training"
+             "This value is used when dist_init_method == 'env://'",
+    )
+    parser.add_argument(
+        "--dist_launcher",
+        default=None,
+        type=str_or_none,
+        choices=["slurm", "mpi", None],
+        help="The launcher type for distributed training",
+    )
+    parser.add_argument(
+        "--multiprocessing_distributed",
+        default=True,
+        type=str2bool,
+        help="Use multi-processing distributed training to launch "
+             "N processes per node, which has N GPUs. This is the "
+             "fastest way to use PyTorch for either single node or "
+             "multi node data parallel training",
     )
     parser.add_argument(
         "--unused_parameters",
@@ -77,6 +112,12 @@
         default=False,
         help="Whether to use the find_unused_parameters in "
              "torch.nn.parallel.DistributedDataParallel ",
+    )
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
     )
 
     # cudnn related
@@ -257,6 +298,12 @@
         help="whether to use dataloader for large dataset",
     )
     parser.add_argument(
+        "--dataset_conf",
+        action=NestedDictAction,
+        default=dict(),
+        help=f"The keyword arguments for dataset",
+    )
+    parser.add_argument(
         "--train_data_file",
         type=str,
         default=None,
@@ -270,12 +317,26 @@
     )
     parser.add_argument(
         "--train_data_path_and_name_and_type",
+        type=str2triple_str,
         action="append",
         default=[],
         help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
     )
     parser.add_argument(
         "--valid_data_path_and_name_and_type",
+        type=str2triple_str,
+        action="append",
+        default=[],
+    )
+    parser.add_argument(
+        "--train_shape_file",
+        type=str,
+        action="append",
+        default=[],
+    )
+    parser.add_argument(
+        "--valid_shape_file",
+        type=str,
         action="append",
         default=[],
     )
@@ -284,6 +345,32 @@
         type=str2bool,
         default=True,
         help="Apply preprocessing to data or not",
+    )
+
+    # optimization related
+    parser.add_argument(
+        "--optim",
+        type=lambda x: x.lower(),
+        default="adam",
+        help="The optimizer type",
+    )
+    parser.add_argument(
+        "--optim_conf",
+        action=NestedDictAction,
+        default=dict(),
+        help="The keyword arguments for optimizer",
+    )
+    parser.add_argument(
+        "--scheduler",
+        type=lambda x: str_or_none(x.lower()),
+        default=None,
+        help="The lr scheduler type",
+    )
+    parser.add_argument(
+        "--scheduler_conf",
+        action=NestedDictAction,
+        default=dict(),
+        help="The keyword arguments for lr scheduler",
     )
 
     # most task related
@@ -388,9 +475,9 @@
 
 if __name__ == '__main__':
     parser = get_parser()
-    args = parser.parse_args()
-    task_args = build_args(args)
-    args = argparse.Namespace(**vars(args), **vars(task_args))
+    args, extra_task_params = parser.parse_known_args()
+    if extra_task_params:
+        args = build_args(args, parser, extra_task_params)
 
     # set random seed
     set_all_random_seed(args.seed)
@@ -399,7 +486,8 @@
     torch.backends.cudnn.deterministic = args.cudnn_deterministic
 
     # ddp init
-    args.distributed = args.dist_world_size > 1
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+    args.distributed = args.ngpu > 1 or args.dist_world_size > 1
     distributed_option = build_distributed(args)
 
     # for logging

--
Gitblit v1.9.1