| New file |
| | |
| | | import sys |
| | | |
| | | import torch |
| | | |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.build_distributed import build_distributed |
| | | from funasr.utils.types import str2bool |
| | | |
| | | |
def _int_or_none(value):
    """Convert a command-line token to ``int``, mapping "none"-like text to ``None``.

    Used as an argparse ``type=`` callable for integer options whose default
    is ``None``. argparse only runs ``type`` on string values, so the ``None``
    default itself is passed through untouched.

    Raises:
        ValueError: if ``value`` is neither an integer literal nor one of
            "none" / "null" / "nil" (argparse turns this into a usage error).
    """
    if value is None:
        return None
    if isinstance(value, str) and value.strip().lower() in ("none", "null", "nil"):
        return None
    return int(value)


def get_parser():
    """Build the argument parser shared by the FunASR training entry points.

    The options are grouped as: common, DDP, cuDNN, trainer, pretrained model,
    dataset, PAI/OSS and task selection.

    Returns:
        The populated ``config_argparse.ArgumentParser`` instance.
    """
    parser = config_argparse.ArgumentParser(
        description="FunASR Common Training Parser",
    )

    # common configuration
    parser.add_argument("--output_dir", help="model save path")
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")

    # ddp related
    parser.add_argument(
        "--dist_backend",
        default="nccl",
        type=str,
        help="distributed backend",
    )
    parser.add_argument(
        "--dist_init_method",
        type=str,
        default="env://",
        help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
        '"WORLD_SIZE", and "RANK" are referred.',
    )
    # The three rank/size options are parsed as integers so that numeric
    # comparisons such as ``dist_world_size > 1`` work on command-line values.
    parser.add_argument(
        "--dist_world_size",
        type=_int_or_none,
        default=None,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--dist_rank",
        type=_int_or_none,
        default=None,
        help="node rank for distributed training",
    )
    parser.add_argument(
        "--local_rank",
        type=_int_or_none,
        default=None,
        help="local rank for distributed training",
    )
    parser.add_argument(
        "--unused_parameters",
        type=str2bool,
        default=False,
        help="Whether to use the find_unused_parameters in "
        "torch.nn.parallel.DistributedDataParallel ",
    )

    # cudnn related
    # The defaults for the first two options mirror torch's current settings.
    parser.add_argument(
        "--cudnn_enabled",
        type=str2bool,
        default=torch.backends.cudnn.enabled,
        help="Enable CUDNN",
    )
    parser.add_argument(
        "--cudnn_benchmark",
        type=str2bool,
        default=torch.backends.cudnn.benchmark,
        help="Enable cudnn-benchmark mode",
    )
    parser.add_argument(
        "--cudnn_deterministic",
        type=str2bool,
        default=True,
        help="Enable cudnn-deterministic mode",
    )

    # trainer related
    parser.add_argument(
        "--max_epoch",
        type=int,
        default=40,
        help="The maximum number epoch to train",
    )
    parser.add_argument(
        "--max_update",
        type=int,
        default=sys.maxsize,
        help="The maximum number update step to train",
    )
    parser.add_argument(
        "--batch_interval",
        type=int,
        default=10000,
        help="The batch interval for saving model.",
    )
    parser.add_argument(
        "--patience",
        type=_int_or_none,
        default=None,
        help="Number of epochs to wait without improvement "
        "before stopping the training",
    )
    parser.add_argument(
        "--val_scheduler_criterion",
        type=str,
        nargs=2,
        default=("valid", "loss"),
        help="The criterion used for the value given to the lr scheduler. "
        'Give a pair referring the phase, "train" or "valid",'
        'and the criterion name. The mode specifying "min" or "max" can '
        "be changed by --scheduler_conf",
    )
    parser.add_argument(
        "--early_stopping_criterion",
        type=str,
        nargs=3,
        default=("valid", "loss", "min"),
        help="The criterion used for judging of early stopping. "
        'Give a pair referring the phase, "train" or "valid",'
        'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
    )
    # NOTE(review): no ``type=`` is given here, so values supplied on the
    # command line stay plain strings while the default is a list of
    # 3-tuples - confirm how consumers expect CLI values to be shaped.
    parser.add_argument(
        "--best_model_criterion",
        nargs="+",
        default=[
            ("train", "loss", "min"),
            ("valid", "loss", "min"),
            ("train", "acc", "max"),
            ("valid", "acc", "max"),
        ],
        help="The criterion used for judging of the best model. "
        'Give a pair referring the phase, "train" or "valid",'
        'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
    )
    parser.add_argument(
        "--keep_nbest_models",
        type=int,
        nargs="+",
        default=[10],
        help="Remove previous snapshots excluding the n-best scored epochs",
    )
    parser.add_argument(
        "--nbest_averaging_interval",
        type=int,
        default=0,
        help="The epoch interval to apply model averaging and save nbest models",
    )
    parser.add_argument(
        "--grad_clip",
        type=float,
        default=5.0,
        help="Gradient norm threshold to clip",
    )
    # Parsed as float so that "inf" is accepted, as the help text promises.
    parser.add_argument(
        "--grad_clip_type",
        type=float,
        default=2.0,
        help="The type of the used p-norm for gradient clip. Can be inf",
    )
    parser.add_argument(
        "--grad_noise",
        type=str2bool,
        default=False,
        help="The flag to switch to use noise injection to "
        "gradients during training",
    )
    parser.add_argument(
        "--accum_grad",
        type=int,
        default=1,
        help="The number of gradient accumulation",
    )
    parser.add_argument(
        "--resume",
        type=str2bool,
        default=False,
        help="Enable resuming if checkpoint is existing",
    )
    parser.add_argument(
        "--use_amp",
        type=str2bool,
        default=False,
        help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
    )
    parser.add_argument(
        "--log_interval",
        type=_int_or_none,
        default=None,
        help="Show the logs every the number iterations in each epochs at the "
        "training phase. If None is given, it is decided according the number "
        "of training samples automatically .",
    )

    # pretrained model related
    parser.add_argument(
        "--init_param",
        type=str,
        default=[],
        nargs="*",
        help="Specify the file path used for initialization of parameters. "
        "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
        "where file_path is the model file path, "
        "src_key specifies the key of model states to be used in the model file, "
        "dst_key specifies the attribute of the model to be initialized, "
        "and exclude_keys excludes keys of model states for the initialization."
        "e.g.\n"
        " # Load all parameters"
        " --init_param some/where/model.pb\n"
        " # Load only decoder parameters"
        " --init_param some/where/model.pb:decoder:decoder\n"
        " # Load only decoder parameters excluding decoder.embed"
        " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
        " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
    )
    parser.add_argument(
        "--ignore_init_mismatch",
        type=str2bool,
        default=False,
        help="Ignore size mismatch when loading pre-trained model",
    )
    parser.add_argument(
        "--freeze_param",
        type=str,
        default=[],
        nargs="*",
        help="Freeze parameters",
    )

    # dataset related
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="small",
        help="whether to use dataloader for large dataset",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="train_list for large dataset",
    )
    parser.add_argument(
        "--valid_data_file",
        type=str,
        default=None,
        help="valid_list for large dataset",
    )
    parser.add_argument(
        "--train_data_path_and_name_and_type",
        action="append",
        default=[],
        help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
    )
    parser.add_argument(
        "--valid_data_path_and_name_and_type",
        action="append",
        default=[],
    )

    # pai related
    parser.add_argument(
        "--use_pai",
        type=str2bool,
        default=False,
        help="flag to indicate whether training on PAI",
    )
    parser.add_argument(
        "--simple_ddp",
        type=str2bool,
        default=False,
    )
    parser.add_argument(
        "--num_worker_count",
        type=int,
        default=1,
        help="The number of machines on PAI.",
    )
    parser.add_argument(
        "--access_key_id",
        type=str,
        default=None,
        help="The username for oss.",
    )
    parser.add_argument(
        "--access_key_secret",
        type=str,
        default=None,
        help="The password for oss.",
    )
    parser.add_argument(
        "--endpoint",
        type=str,
        default=None,
        help="The endpoint for oss.",
    )
    parser.add_argument(
        "--bucket_name",
        type=str,
        default=None,
        help="The bucket name for oss.",
    )
    parser.add_argument(
        "--oss_bucket",
        default=None,
        help="oss bucket.",
    )

    # task related
    parser.add_argument("--task_name", help="for different task")

    return parser
| | | |
| | | |
if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()

    # --dist_world_size defaults to None and, without an integer ``type=``,
    # arrives from the command line as a string. Comparing either of those
    # with ``> 1`` raises TypeError, so treat None as "not distributed" and
    # coerce anything else to int before comparing.
    world_size = args.dist_world_size
    args.distributed = world_size is not None and int(world_size) > 1
    distributed_option = build_distributed(args)
| | | |
| | | # |
| | | |
| | | |
| | |
| | | import torch.nn |
| | | import torch.optim |
| | | import yaml |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from torch.utils.data import DataLoader |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | |
| | | from funasr.iterators.multiple_iter_factory import MultipleIterFactory |
| | | from funasr.iterators.sequence_iter_factory import SequenceIterFactory |
| | | from funasr.main_funcs.collect_stats import collect_stats |
| | | from funasr.optimizers.sgd import SGD |
| | | from funasr.optimizers.fairseq_adam import FairseqAdam |
| | | from funasr.optimizers.sgd import SGD |
| | | from funasr.samplers.build_batch_sampler import BATCH_TYPES |
| | | from funasr.samplers.build_batch_sampler import build_batch_sampler |
| | | from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler |
| | | from funasr.schedulers.noam_lr import NoamLR |
| | | from funasr.schedulers.warmup_lr import WarmupLR |
| | | from funasr.schedulers.tri_stage_scheduler import TriStageLR |
| | | from funasr.schedulers.warmup_lr import WarmupLR |
| | | from funasr.torch_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.torch_utils.model_summary import model_summary |
| | | from funasr.torch_utils.pytorch_version import pytorch_cudnn_version |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.train.class_choices import ClassChoices |
| | | from funasr.train.distributed_utils import DistributedOption |
| | | from funasr.train.trainer import Trainer |
| New file |
| | |
| | | import logging |
| | | import os |
| | | |
| | | import torch |
| | | |
| | | from funasr.train.distributed_utils import DistributedOption |
| | | from funasr.utils.build_dataclass import build_dataclass |
| | | |
| | | |
def build_distributed(args):
    """Create the distributed option object and configure root logging.

    A ``DistributedOption`` is built from ``args`` and initialised along one
    of three paths chosen by ``args.use_pai`` and ``args.simple_ddp``. In the
    simple-DDP path ``args.ngpu`` is overwritten with the torch.distributed
    world size. All existing root log handlers are then removed and logging
    is reconfigured: INFO level when the run is not distributed or this
    process has ``dist_rank == 0``, ERROR level otherwise.

    Args:
        args: parsed arguments; ``use_pai`` and ``simple_ddp`` are read, plus
            ``distributed`` when ``simple_ddp`` is set and ``use_pai`` is not.

    Returns:
        The initialised ``DistributedOption`` instance.
    """
    option = build_dataclass(DistributedOption, args)

    if args.use_pai:
        option.init_options_pai()
        option.init_torch_distributed_pai(args)
    elif args.simple_ddp:
        # Simple DDP only initialises when the run is actually distributed.
        if args.distributed:
            option.init_torch_distributed_pai(args)
            args.ngpu = torch.distributed.get_world_size()
    else:
        option.init_torch_distributed(args)

    # Drop every handler already attached to the root logger so that the
    # basicConfig call below takes effect.
    for stale_handler in list(logging.root.handlers):
        logging.root.removeHandler(stale_handler)

    is_primary = not option.distributed or option.dist_rank == 0
    short_hostname = os.uname()[1].split('.')[0]
    log_format = (
        f"[{short_hostname}]"
        f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    )
    logging.basicConfig(
        level="INFO" if is_primary else "ERROR",
        format=log_format,
    )

    summary = "world size: {}, rank: {}, local_rank: {}".format(
        option.dist_world_size,
        option.dist_rank,
        option.local_rank,
    )
    logging.info(summary)
    return option