嘉渊
2023-04-24 87da7393041a9afaa42dc8c4b0fd3bcff624c182
funasr/bin/train.py
@@ -59,16 +59,19 @@
    )
    parser.add_argument(
        "--dist_world_size",
        type=int,
        default=1,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--dist_rank",
        type=int,
        default=None,
        help="node rank for distributed training",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=None,
        help="local rank for distributed training",
    )
@@ -85,6 +88,22 @@
        type=int_or_none,
        help="The master port for distributed training"
             "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--dist_launcher",
        default=None,
        type=str_or_none,
        choices=["slurm", "mpi", None],
        help="The launcher type for distributed training",
    )
    parser.add_argument(
        "--multiprocessing_distributed",
        default=True,
        type=str2bool,
        help="Use multi-processing distributed training to launch "
             "N processes per node, which has N GPUs. This is the "
             "fastest way to use PyTorch for either single node or "
             "multi node data parallel training",
    )
    parser.add_argument(
        "--unused_parameters",
@@ -449,7 +468,6 @@
    args, extra_task_params = parser.parse_known_args()
    if extra_task_params:
        args = build_args(args, parser, extra_task_params)
        # args = argparse.Namespace(**vars(args), **vars(task_args))
    # set random seed
    set_all_random_seed(args.seed)