嘉渊
2023-04-25 4b7dff9c7147c8ab8b66dedceee3d2b8ee485f10
funasr/bin/train.py
@@ -21,6 +21,7 @@
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
@@ -58,18 +59,51 @@
    )
    parser.add_argument(
        "--dist_world_size",
        default=None,
        type=int,
        default=1,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--dist_rank",
        type=int,
        default=None,
        help="node rank for distributed training",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=None,
        help="local rank for distributed training",
    )
    parser.add_argument(
        "--dist_master_addr",
        default=None,
        type=str_or_none,
        help="The master address for distributed training. "
             "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--dist_master_port",
        default=None,
        type=int_or_none,
        help="The master port for distributed training"
             "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--dist_launcher",
        default=None,
        type=str_or_none,
        choices=["slurm", "mpi", None],
        help="The launcher type for distributed training",
    )
    parser.add_argument(
        "--multiprocessing_distributed",
        default=True,
        type=str2bool,
        help="Use multi-processing distributed training to launch "
             "N processes per node, which has N GPUs. This is the "
             "fastest way to use PyTorch for either single node or "
             "multi node data parallel training",
    )
    parser.add_argument(
        "--unused_parameters",
@@ -207,6 +241,12 @@
        help="Enable resuming if checkpoint is existing",
    )
    parser.add_argument(
        "--train_dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type for training.",
    )
    parser.add_argument(
        "--use_amp",
        type=str2bool,
        default=False,
@@ -218,6 +258,12 @@
        help="Show the logs every the number iterations in each epochs at the "
             "training phase. If None is given, it is decided according the number "
             "of training samples automatically .",
    )
    parser.add_argument(
        "--use_tensorboard",
        type=str2bool,
        default=True,
        help="Enable tensorboard logging",
    )
    # pretrained model related
@@ -263,39 +309,30 @@
        help="whether to use dataloader for large dataset",
    )
    parser.add_argument(
        "--train_data_file",
        "--dataset_conf",
        action=NestedDictAction,
        default=dict(),
        help=f"The keyword arguments for dataset",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default=None,
        help="train_list for large dataset",
        help="root path of data",
    )
    parser.add_argument(
        "--valid_data_file",
        "--train_set",
        type=str,
        default=None,
        help="valid_list for large dataset",
        default="train",
        help="train dataset",
    )
    parser.add_argument(
        "--train_data_path_and_name_and_type",
        action="append",
        default=[],
        help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
    )
    parser.add_argument(
        "--valid_data_path_and_name_and_type",
        action="append",
        default=[],
    )
    parser.add_argument(
        "--train_shape_file",
        type=str, action="append",
        default=[],
    )
    parser.add_argument(
        "--valid_shape_file",
        "--valid_set",
        type=str,
        action="append",
        default=[],
        default="validation",
        help="dev dataset",
    )
    parser.add_argument(
        "--use_preprocessor",
        type=str2bool,
@@ -434,7 +471,6 @@
    args, extra_task_params = parser.parse_known_args()
    if extra_task_params:
        args = build_args(args, parser, extra_task_params)
        # args = argparse.Namespace(**vars(args), **vars(task_args))
    # set random seed
    set_all_random_seed(args.seed)
@@ -444,7 +480,7 @@
    # ddp init
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    args.distributed = args.dist_world_size > 1
    args.distributed = args.ngpu > 1 or args.dist_world_size > 1
    distributed_option = build_distributed(args)
    # for logging
@@ -465,6 +501,10 @@
    prepare_data(args, distributed_option)
    model = build_model(args)
    model = model.to(
        dtype=getattr(torch, args.train_dtype),
        device="cuda" if args.ngpu > 0 else "cpu",
    )
    optimizers = build_optimizer(args, model=model)
    schedulers = build_scheduler(args, optimizers)
@@ -472,6 +512,7 @@
                                                                   distributed_option.dist_rank,
                                                                   distributed_option.local_rank))
    logging.info(pytorch_cudnn_version())
    logging.info("Args: {}".format(args))
    logging.info(model_summary(model))
    logging.info("Optimizer: {}".format(optimizers))
    logging.info("Scheduler: {}".format(schedulers))