嘉渊
2023-04-24 996b951365905df0314b0b611d0176ae6df5d178
funasr/bin/train.py
@@ -19,7 +19,9 @@
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
@@ -57,7 +59,7 @@
    )
    parser.add_argument(
        "--dist_world_size",
        default=None,
        default=1,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
@@ -69,6 +71,20 @@
        "--local_rank",
        default=None,
        help="local rank for distributed training",
    )
    parser.add_argument(
        "--dist_master_addr",
        default=None,
        type=str_or_none,
        help="The master address for distributed training. "
             "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--dist_master_port",
        default=None,
        type=int_or_none,
        help="The master port for distributed training"
             "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--unused_parameters",
@@ -302,6 +318,32 @@
        help="Apply preprocessing to data or not",
    )
    # optimization related
    parser.add_argument(
        "--optim",
        type=lambda x: x.lower(),
        default="adam",
        help="The optimizer type",
    )
    parser.add_argument(
        "--optim_conf",
        action=NestedDictAction,
        default=dict(),
        help="The keyword arguments for optimizer",
    )
    parser.add_argument(
        "--scheduler",
        type=lambda x: str_or_none(x.lower()),
        default=None,
        help="The lr scheduler type",
    )
    parser.add_argument(
        "--scheduler_conf",
        action=NestedDictAction,
        default=dict(),
        help="The keyword arguments for lr scheduler",
    )
    # most task related
    parser.add_argument(
        "--init",
@@ -417,7 +459,7 @@
    # ddp init
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    args.distributed = args.dist_world_size > 1
    args.distributed = args.ngpu > 1 or args.dist_world_size > 1
    distributed_option = build_distributed(args)
    # for logging