old mode 100644
new mode 100755
| | |
| | | #!/usr/bin/env python3 |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import argparse |
| | | import logging |
| | | import os |
| | | import sys |
| | | from io import BytesIO |
| | | |
| | | import torch |
| | | |
| | | from funasr.build_utils.build_args import build_args |
| | | from funasr.build_utils.build_dataloader import build_dataloader |
| | | from funasr.build_utils.build_distributed import build_distributed |
| | | from funasr.build_utils.build_model import build_model |
| | | from funasr.build_utils.build_optimizer import build_optimizer |
| | | from funasr.build_utils.build_scheduler import build_scheduler |
| | | from funasr.build_utils.build_trainer import build_trainer |
| | | from funasr.tokenizer.phoneme_tokenizer import g2p_choices |
| | | from funasr.torch_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.torch_utils.model_summary import model_summary |
| | | from funasr.torch_utils.pytorch_version import pytorch_cudnn_version |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.build_distributed import build_distributed |
| | | from funasr.utils.nested_dict_action import NestedDictAction |
| | | from funasr.utils.prepare_data import prepare_data |
| | | from funasr.utils.types import int_or_none |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
| | | from funasr.modules.lora.utils import mark_only_lora_as_trainable |
| | | |
| | | |
| | | def get_parser(): |
| | | parser = config_argparse.ArgumentParser( |
| | | parser = argparse.ArgumentParser( |
| | | description="FunASR Common Training Parser", |
| | | ) |
| | | |
| | |
| | | help="The number of gpus. 0 indicates CPU mode", |
| | | ) |
| | | parser.add_argument("--seed", type=int, default=0, help="Random seed") |
| | | parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks") |
| | | |
| | | # ddp related |
| | | parser.add_argument( |
| | |
| | | ) |
| | | parser.add_argument( |
| | | "--dist_world_size", |
| | | default=None, |
| | | type=int, |
| | | default=1, |
| | | help="number of nodes for distributed training", |
| | | ) |
| | | parser.add_argument( |
| | | "--dist_rank", |
| | | type=int, |
| | | default=None, |
| | | help="node rank for distributed training", |
| | | ) |
| | | parser.add_argument( |
| | | "--local_rank", |
| | | type=int, |
| | | default=None, |
| | | help="local rank for distributed training", |
| | | ) |
| | | parser.add_argument( |
| | | "--dist_master_addr", |
| | | default=None, |
| | | type=str_or_none, |
| | | help="The master address for distributed training. " |
| | | "This value is used when dist_init_method == 'env://'", |
| | | ) |
| | | parser.add_argument( |
| | | "--dist_master_port", |
| | | default=None, |
| | | type=int_or_none, |
| | | help="The master port for distributed training" |
| | | "This value is used when dist_init_method == 'env://'", |
| | | ) |
| | | parser.add_argument( |
| | | "--dist_launcher", |
| | | default=None, |
| | | type=str_or_none, |
| | | choices=["slurm", "mpi", None], |
| | | help="The launcher type for distributed training", |
| | | ) |
| | | parser.add_argument( |
| | | "--multiprocessing_distributed", |
| | | default=True, |
| | | type=str2bool, |
| | | help="Use multi-processing distributed training to launch " |
| | | "N processes per node, which has N GPUs. This is the " |
| | | "fastest way to use PyTorch for either single node or " |
| | | "multi node data parallel training", |
| | | ) |
| | | parser.add_argument( |
| | | "--unused_parameters", |
| | |
| | | default=False, |
| | | help="Whether to use the find_unused_parameters in " |
| | | "torch.nn.parallel.DistributedDataParallel ", |
| | | ) |
| | | parser.add_argument( |
| | | "--gpu_id", |
| | | type=int, |
| | | default=0, |
| | | help="local gpu id.", |
| | | ) |
| | | |
| | | # cudnn related |
| | |
| | | ) |
| | | parser.add_argument( |
| | | "--patience", |
| | | type=int_or_none, |
| | | default=None, |
| | | help="Number of epochs to wait without improvement " |
| | | "before stopping the training", |
| | |
| | | help="Enable resuming if checkpoint is existing", |
| | | ) |
| | | parser.add_argument( |
| | | "--train_dtype", |
| | | default="float32", |
| | | choices=["float16", "float32", "float64"], |
| | | help="Data type for training.", |
| | | ) |
| | | parser.add_argument( |
| | | "--use_amp", |
| | | type=str2bool, |
| | | default=False, |
| | |
| | | "training phase. If None is given, it is decided according the number " |
| | | "of training samples automatically .", |
| | | ) |
| | | parser.add_argument( |
| | | "--use_tensorboard", |
| | | type=str2bool, |
| | | default=True, |
| | | help="Enable tensorboard logging", |
| | | ) |
| | | |
| | | # pretrained model related |
| | | parser.add_argument( |
| | | "--init_param", |
| | | type=str, |
| | | action="append", |
| | | default=[], |
| | | nargs="*", |
| | | help="Specify the file path used for initialization of parameters. " |
| | | "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', " |
| | | "where file_path is the model file path, " |
| | |
| | | "--freeze_param", |
| | | type=str, |
| | | default=[], |
| | | nargs="*", |
| | | action="append", |
| | | help="Freeze parameters", |
| | | ) |
| | | |
| | |
| | | help="whether to use dataloader for large dataset", |
| | | ) |
| | | parser.add_argument( |
| | | "--train_data_file", |
| | | "--dataset_conf", |
| | | action=NestedDictAction, |
| | | default=dict(), |
| | | help=f"The keyword arguments for dataset", |
| | | ) |
| | | parser.add_argument( |
| | | "--data_dir", |
| | | type=str, |
| | | default=None, |
| | | help="train_list for large dataset", |
| | | help="root path of data", |
| | | ) |
| | | parser.add_argument( |
| | | "--valid_data_file", |
| | | "--train_set", |
| | | type=str, |
| | | default="train", |
| | | help="train dataset", |
| | | ) |
| | | parser.add_argument( |
| | | "--valid_set", |
| | | type=str, |
| | | default="validation", |
| | | help="dev dataset", |
| | | ) |
| | | parser.add_argument( |
| | | "--data_file_names", |
| | | type=str, |
| | | default="wav.scp,text", |
| | | help="input data files", |
| | | ) |
| | | parser.add_argument( |
| | | "--speed_perturb", |
| | | type=float, |
| | | nargs="+", |
| | | default=None, |
| | | help="valid_list for large dataset", |
| | | help="speed perturb", |
| | | ) |
| | | parser.add_argument( |
| | | "--train_data_path_and_name_and_type", |
| | | action="append", |
| | | default=[], |
| | | help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ", |
| | | "--use_preprocessor", |
| | | type=str2bool, |
| | | default=True, |
| | | help="Apply preprocessing to data or not", |
| | | ) |
| | | |
| | | # optimization related |
| | | parser.add_argument( |
| | | "--optim", |
| | | type=lambda x: x.lower(), |
| | | default="adam", |
| | | help="The optimizer type", |
| | | ) |
| | | parser.add_argument( |
| | | "--valid_data_path_and_name_and_type", |
| | | action="append", |
| | | default=[], |
| | | "--optim_conf", |
| | | action=NestedDictAction, |
| | | default=dict(), |
| | | help="The keyword arguments for optimizer", |
| | | ) |
| | | parser.add_argument( |
| | | "--scheduler", |
| | | type=lambda x: str_or_none(x.lower()), |
| | | default=None, |
| | | help="The lr scheduler type", |
| | | ) |
| | | parser.add_argument( |
| | | "--scheduler_conf", |
| | | action=NestedDictAction, |
| | | default=dict(), |
| | | help="The keyword arguments for lr scheduler", |
| | | ) |
| | | |
| | | # most task related |
| | | parser.add_argument( |
| | | "--init", |
| | | type=lambda x: str_or_none(x.lower()), |
| | | default=None, |
| | | help="The initialization method", |
| | | choices=[ |
| | | "chainer", |
| | | "xavier_uniform", |
| | | "xavier_normal", |
| | | "kaiming_uniform", |
| | | "kaiming_normal", |
| | | None, |
| | | ], |
| | | ) |
| | | parser.add_argument( |
| | | "--token_list", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="A text mapping int-id to token", |
| | | ) |
| | | parser.add_argument( |
| | | "--token_type", |
| | | type=str, |
| | | default="bpe", |
| | | choices=["bpe", "char", "word"], |
| | | help="", |
| | | ) |
| | | parser.add_argument( |
| | | "--bpemodel", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The model file fo sentencepiece", |
| | | ) |
| | | parser.add_argument( |
| | | "--cleaner", |
| | | type=str_or_none, |
| | | choices=[None, "tacotron", "jaconv", "vietnamese"], |
| | | default=None, |
| | | help="Apply text cleaning", |
| | | ) |
| | | parser.add_argument( |
| | | "--g2p", |
| | | type=str_or_none, |
| | | choices=g2p_choices, |
| | | default=None, |
| | | help="Specify g2p method if --token_type=phn", |
| | | ) |
| | | |
| | | # pai related |
| | |
| | | default=None, |
| | | help="oss bucket.", |
| | | ) |
| | | |
| | | # task related |
| | | parser.add_argument("--task_name", help="for different task") |
| | | parser.add_argument( |
| | | "--enable_lora", |
| | | type=str2bool, |
| | | default=False, |
| | | help="Apply lora for finetuning.", |
| | | ) |
| | | parser.add_argument( |
| | | "--lora_bias", |
| | | type=str, |
| | | default="none", |
| | | help="lora bias.", |
| | | ) |
| | | |
| | | return parser |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | parser = get_parser() |
| | | args = parser.parse_args() |
| | | args, extra_task_params = parser.parse_known_args() |
| | | if extra_task_params: |
| | | args = build_args(args, parser, extra_task_params) |
| | | |
| | | # set random seed |
| | | set_all_random_seed(args.seed) |
| | | torch.backends.cudnn.enabled = args.cudnn_enabled |
| | | torch.backends.cudnn.benchmark = args.cudnn_benchmark |
| | | torch.backends.cudnn.deterministic = args.cudnn_deterministic |
| | | |
| | | # ddp init |
| | | args.distributed = args.dist_world_size > 1 |
| | | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) |
| | | args.distributed = args.ngpu > 1 or args.dist_world_size > 1 |
| | | distributed_option = build_distributed(args) |
| | | |
| | | # for logging |
| | | if not distributed_option.distributed or distributed_option.dist_rank == 0: |
| | | logging.basicConfig( |
| | | level="INFO", |
| | |
| | | format=f"[{os.uname()[1].split('.')[0]}]" |
| | | f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size, |
| | | distributed_option.dist_rank, |
| | | distributed_option.local_rank)) |
| | | |
| | | # prepare files for dataloader |
| | | prepare_data(args, distributed_option) |
| | | |
| | | set_all_random_seed(args.seed) |
| | | torch.backends.cudnn.enabled = args.cudnn_enabled |
| | | torch.backends.cudnn.benchmark = args.cudnn_benchmark |
| | | torch.backends.cudnn.deterministic = args.cudnn_deterministic |
| | | model = build_model(args) |
| | | model = model.to( |
| | | dtype=getattr(torch, args.train_dtype), |
| | | device="cuda" if args.ngpu > 0 else "cpu", |
| | | ) |
| | | if args.enable_lora: |
| | | mark_only_lora_as_trainable(model, args.lora_bias) |
| | | for t in args.freeze_param: |
| | | for k, p in model.named_parameters(): |
| | | if k.startswith(t + ".") or k == t: |
| | | logging.info(f"Setting {k}.requires_grad = False") |
| | | p.requires_grad = False |
| | | |
| | | optimizers = build_optimizer(args, model=model) |
| | | schedulers = build_scheduler(args, optimizers) |
| | | |
| | | logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size, |
| | | distributed_option.dist_rank, |
| | | distributed_option.local_rank)) |
| | | logging.info(pytorch_cudnn_version()) |
| | | logging.info("Args: {}".format(args)) |
| | | logging.info(model_summary(model)) |
| | | logging.info("Optimizer: {}".format(optimizers)) |
| | | logging.info("Scheduler: {}".format(schedulers)) |
| | | |
| | | # dump args to config.yaml |
| | | if not distributed_option.distributed or distributed_option.dist_rank == 0: |
| | | os.makedirs(args.output_dir, exist_ok=True) |
| | | with open(os.path.join(args.output_dir, "config.yaml"), "w") as f: |
| | | logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml")) |
| | | if args.use_pai: |
| | | buffer = BytesIO() |
| | | torch.save({"config": vars(args)}, buffer) |
| | | args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue()) |
| | | else: |
| | | yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False) |
| | | |
| | | for p in args.init_param: |
| | | logging.info(f"Loading pretrained params from {p}") |
| | | load_pretrained_model( |
| | | model=model, |
| | | init_param=p, |
| | | ignore_init_mismatch=args.ignore_init_mismatch, |
| | | map_location=f"cuda:{torch.cuda.current_device()}" |
| | | if args.ngpu > 0 |
| | | else "cpu", |
| | | oss_bucket=args.oss_bucket, |
| | | ) |
| | | |
| | | # dataloader for training/validation |
| | | train_dataloader, valid_dataloader = build_dataloader(args) |
| | | |
| | | # Trainer, including model, optimizers, etc. |
| | | trainer = build_trainer( |
| | | args=args, |
| | | model=model, |
| | | optimizers=optimizers, |
| | | schedulers=schedulers, |
| | | train_dataloader=train_dataloader, |
| | | valid_dataloader=valid_dataloader, |
| | | distributed_option=distributed_option |
| | | ) |
| | | |
| | | trainer.run() |