嘉渊
2023-04-25 b9d1425028e480aa2c8dbd3502207e443dcd2060
funasr/bin/train.py
@@ -23,6 +23,7 @@
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
@@ -59,16 +60,19 @@
    )
    parser.add_argument(
        "--dist_world_size",
        type=int,
        default=1,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--dist_rank",
        type=int,
        default=None,
        help="node rank for distributed training",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=None,
        help="local rank for distributed training",
    )
@@ -294,6 +298,12 @@
        help="whether to use dataloader for large dataset",
    )
    parser.add_argument(
        "--dataset_conf",
        action=NestedDictAction,
        default=dict(),
        help=f"The keyword arguments for dataset",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
@@ -307,18 +317,21 @@
    )
    parser.add_argument(
        "--train_data_path_and_name_and_type",
        type=str2triple_str,
        action="append",
        default=[],
        help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
    )
    parser.add_argument(
        "--valid_data_path_and_name_and_type",
        type=str2triple_str,
        action="append",
        default=[],
    )
    parser.add_argument(
        "--train_shape_file",
        type=str, action="append",
        type=str,
        action="append",
        default=[],
    )
    parser.add_argument(
@@ -462,9 +475,9 @@
if __name__ == '__main__':
    parser = get_parser()
    common_args, extra_task_params = parser.parse_known_args()
    args, extra_task_params = parser.parse_known_args()
    if extra_task_params:
        args = build_args(common_args, parser, extra_task_params)
        args = build_args(args, parser, extra_task_params)
    # set random seed
    set_all_random_seed(args.seed)