| | |
| | | from funasr.utils.prepare_data import prepare_data |
| | | from funasr.utils.types import int_or_none |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
| | | |
| | |
| | | ) |
| | | parser.add_argument( |
| | | "--dist_rank", |
| | | type=int, |
| | | default=None, |
| | | help="node rank for distributed training", |
| | | ) |
| | |
| | | help="Enable resuming if checkpoint is existing", |
| | | ) |
| | | parser.add_argument( |
| | | "--train_dtype", |
| | | default="float32", |
| | | choices=["float16", "float32", "float64"], |
| | | help="Data type for training.", |
| | | ) |
| | | parser.add_argument( |
| | | "--use_amp", |
| | | type=str2bool, |
| | | default=False, |
| | |
| | | help="Show the logs every the number iterations in each epochs at the " |
| | | "training phase. If None is given, it is decided according the number " |
| | | "of training samples automatically .", |
| | | ) |
| | | parser.add_argument( |
| | | "--use_tensorboard", |
| | | type=str2bool, |
| | | default=True, |
| | | help="Enable tensorboard logging", |
| | | ) |
| | | |
| | | # pretrained model related |
| | |
| | | help="whether to use dataloader for large dataset", |
| | | ) |
| | | parser.add_argument( |
| | | "--dataset_conf", |
| | | action=NestedDictAction, |
| | | default=dict(), |
| | | help=f"The keyword arguments for dataset", |
| | | ) |
| | | parser.add_argument( |
| | | "--train_data_file", |
| | | type=str, |
| | | default=None, |
| | |
| | | ) |
| | | parser.add_argument( |
| | | "--train_data_path_and_name_and_type", |
| | | type=str2triple_str, |
| | | action="append", |
| | | default=[], |
| | | help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ", |
| | | ) |
| | | parser.add_argument( |
| | | "--valid_data_path_and_name_and_type", |
| | | type=str2triple_str, |
| | | action="append", |
| | | default=[], |
| | | ) |
| | | parser.add_argument( |
| | | "--train_shape_file", |
| | | type=str, action="append", |
| | | type=str, |
| | | action="append", |
| | | default=[], |
| | | ) |
| | | parser.add_argument( |
| | |
| | | distributed_option.dist_rank, |
| | | distributed_option.local_rank)) |
| | | logging.info(pytorch_cudnn_version()) |
| | | logging.info("Args: {}".format(args)) |
| | | logging.info(model_summary(model)) |
| | | logging.info("Optimizer: {}".format(optimizers)) |
| | | logging.info("Scheduler: {}".format(schedulers)) |