嘉渊
2023-04-24 6427c834dfd97b1f05c6659cdc7ccf010bf82fe1
funasr/bin/train.py
@@ -19,6 +19,7 @@
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
@@ -285,12 +286,49 @@
        default=[],
    )
    parser.add_argument(
        "--train_shape_file",
        type=str, action="append",
        default=[],
    )
    parser.add_argument(
        "--valid_shape_file",
        type=str,
        action="append",
        default=[],
    )
    parser.add_argument(
        "--use_preprocessor",
        type=str2bool,
        default=True,
        help="Apply preprocessing to data or not",
    )
    # optimization related
    parser.add_argument(
        "--optim",
        type=lambda x: x.lower(),
        default="adam",
        help="The optimizer type",
    )
    parser.add_argument(
        "--optim_conf",
        action=NestedDictAction,
        default=dict(),
        help="The keyword arguments for optimizer",
    )
    parser.add_argument(
        "--scheduler",
        type=lambda x: str_or_none(x.lower()),
        default=None,
        help="The lr scheduler type",
    )
    parser.add_argument(
        "--scheduler_conf",
        action=NestedDictAction,
        default=dict(),
        help="The keyword arguments for lr scheduler",
    )
    # most task related
    parser.add_argument(
        "--init",