游雁
2023-11-23 dc682db808eb5f425f0dbed4c5e7feb0a334955f
funasr/bin/train.py
@@ -1,4 +1,6 @@
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import argparse
import logging
@@ -15,15 +17,18 @@
from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
from funasr.modules.lora.utils import mark_only_lora_as_trainable
def get_parser():
@@ -58,18 +63,51 @@
    )
    parser.add_argument(
        "--dist_world_size",
        type=int,
        default=1,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--dist_rank",
        type=int,
        default=None,
        help="node rank for distributed training",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=None,
        help="local rank for distributed training",
    )
    parser.add_argument(
        "--dist_master_addr",
        default=None,
        type=str_or_none,
        help="The master address for distributed training. "
             "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--dist_master_port",
        default=None,
        type=int_or_none,
        help="The master port for distributed training"
             "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--dist_launcher",
        default=None,
        type=str_or_none,
        choices=["slurm", "mpi", None],
        help="The launcher type for distributed training",
    )
    parser.add_argument(
        "--multiprocessing_distributed",
        default=True,
        type=str2bool,
        help="Use multi-processing distributed training to launch "
             "N processes per node, which has N GPUs. This is the "
             "fastest way to use PyTorch for either single node or "
             "multi node data parallel training",
    )
    parser.add_argument(
        "--unused_parameters",
@@ -126,6 +164,7 @@
    )
    parser.add_argument(
        "--patience",
        type=int_or_none,
        default=None,
        help="Number of epochs to wait without improvement "
             "before stopping the training",
@@ -207,6 +246,12 @@
        help="Enable resuming if checkpoint is existing",
    )
    parser.add_argument(
        "--train_dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type for training.",
    )
    parser.add_argument(
        "--use_amp",
        type=str2bool,
        default=False,
@@ -219,13 +264,19 @@
             "training phase. If None is given, it is decided according the number "
             "of training samples automatically .",
    )
    parser.add_argument(
        "--use_tensorboard",
        type=str2bool,
        default=True,
        help="Enable tensorboard logging",
    )
    # pretrained model related
    parser.add_argument(
        "--init_param",
        type=str,
        action="append",
        default=[],
        nargs="*",
        help="Specify the file path used for initialization of parameters. "
             "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
             "where file_path is the model file path, "
@@ -251,7 +302,7 @@
        "--freeze_param",
        type=str,
        default=[],
        nargs="*",
        action="append",
        help="Freeze parameters",
    )
@@ -263,38 +314,41 @@
        help="whether to use dataloader for large dataset",
    )
    parser.add_argument(
        "--train_data_file",
        "--dataset_conf",
        action=NestedDictAction,
        default=dict(),
        help=f"The keyword arguments for dataset",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default=None,
        help="train_list for large dataset",
        help="root path of data",
    )
    parser.add_argument(
        "--valid_data_file",
        "--train_set",
        type=str,
        default="train",
        help="train dataset",
    )
    parser.add_argument(
        "--valid_set",
        type=str,
        default="validation",
        help="dev dataset",
    )
    parser.add_argument(
        "--data_file_names",
        type=str,
        default="wav.scp,text",
        help="input data files",
    )
    parser.add_argument(
        "--speed_perturb",
        type=float,
        nargs="+",
        default=None,
        help="valid_list for large dataset",
    )
    parser.add_argument(
        "--train_data_path_and_name_and_type",
        action="append",
        default=[],
        help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
    )
    parser.add_argument(
        "--valid_data_path_and_name_and_type",
        action="append",
        default=[],
    )
    parser.add_argument(
        "--train_shape_file",
        type=str, action="append",
        default=[],
    )
    parser.add_argument(
        "--valid_shape_file",
        type=str,
        action="append",
        default=[],
        help="speed perturb",
    )
    parser.add_argument(
        "--use_preprocessor",
@@ -425,6 +479,18 @@
        default=None,
        help="oss bucket.",
    )
    parser.add_argument(
        "--enable_lora",
        type=str2bool,
        default=False,
        help="Apply lora for finetuning.",
    )
    parser.add_argument(
        "--lora_bias",
        type=str,
        default="none",
        help="lora bias.",
    )
    return parser
@@ -434,7 +500,6 @@
    args, extra_task_params = parser.parse_known_args()
    if extra_task_params:
        args = build_args(args, parser, extra_task_params)
        # args = argparse.Namespace(**vars(args), **vars(task_args))
    # set random seed
    set_all_random_seed(args.seed)
@@ -465,6 +530,18 @@
    prepare_data(args, distributed_option)
    model = build_model(args)
    model = model.to(
        dtype=getattr(torch, args.train_dtype),
        device="cuda" if args.ngpu > 0 else "cpu",
    )
    if args.enable_lora:
        mark_only_lora_as_trainable(model, args.lora_bias)
    for t in args.freeze_param:
        for k, p in model.named_parameters():
            if k.startswith(t + ".") or k == t:
                logging.info(f"Setting {k}.requires_grad = False")
                p.requires_grad = False
    optimizers = build_optimizer(args, model=model)
    schedulers = build_scheduler(args, optimizers)
@@ -472,6 +549,7 @@
                                                                   distributed_option.dist_rank,
                                                                   distributed_option.local_rank))
    logging.info(pytorch_cudnn_version())
    logging.info("Args: {}".format(args))
    logging.info(model_summary(model))
    logging.info("Optimizer: {}".format(optimizers))
    logging.info("Scheduler: {}".format(schedulers))
@@ -488,6 +566,18 @@
            else:
                yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
    for p in args.init_param:
        logging.info(f"Loading pretrained params from {p}")
        load_pretrained_model(
            model=model,
            init_param=p,
            ignore_init_mismatch=args.ignore_init_mismatch,
            map_location=f"cuda:{torch.cuda.current_device()}"
            if args.ngpu > 0
            else "cpu",
            oss_bucket=args.oss_bucket,
        )
    # dataloader for training/validation
    train_dataloader, valid_dataloader = build_dataloader(args)