游雁
2024-01-16 bb97d3ed19ee3a219e67b9568d662df489aa2823
funasr/bin/train.py
old mode 100755 new mode 100644
@@ -1,566 +1,182 @@
#!/usr/bin/env python3
import argparse
import logging
import os
import sys
from io import BytesIO
import torch
import hydra
import logging
import argparse
from io import BytesIO
import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from funasr.build_utils.build_args import build_args
from funasr.build_utils.build_dataloader import build_dataloader
from funasr.build_utils.build_distributed import build_distributed
from funasr.build_utils.build_model import build_model
from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.prepare_data import prepare_data
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
from funasr.register import tables
from funasr.optimizers import optim_classes
from funasr.train_utils.trainer import Trainer
from funasr.schedulers import scheduler_classes
from funasr.train_utils.initialize import initialize
from funasr.download.download_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
# from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter
# from funasr.tokenizer.funtoken import build_tokenizer
def get_parser():
    """Create the command-line parser for FunASR training.

    Returns:
        argparse.ArgumentParser: a parser whose description is
        "FunASR Common Training Parser". No options are registered here.

    NOTE(review): a long run of ``parser.add_argument(...)`` statements and a
    ``return parser`` appear further down in this file, inside another
    function body. They look like they were meant to live here -- confirm
    when the file is untangled.
    """
    parser = argparse.ArgumentParser(
        description="FunASR Common Training Parser",
    )
    # The script entry point calls ``parser.parse_known_args()`` on the value
    # returned by this function, so the parser has to be handed back.
    return parser
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    """Hydra entry point for training.

    Args:
        kwargs: Hydra/OmegaConf configuration. It must contain a "model" key;
            when "model_conf" is absent, the configuration is replaced by the
            value returned from ``download_model``.

    NOTE(review): the ``parser.add_argument(...)`` statements in this body use
    a name ``parser`` that is never assigned inside this function, and the
    ``return parser`` near the end makes the final ``main(**kwargs)`` call
    unreachable as written. These statements look like the body of
    ``get_parser`` mixed into this function -- confirm and untangle before
    relying on this block.
    """
    # Optional interactive debugging, enabled by a truthy "debug" config entry.
    if kwargs.get("debug", False):
        import pdb; pdb.set_trace()
    # common configuration
    # NOTE(review): `parser` is not a local of this function (see docstring).
    parser.add_argument("--output_dir", help="model save path")
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
    # A model name is mandatory; without an explicit "model_conf" the
    # configuration is replaced by whatever download_model returns.
    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
    # ddp related
    parser.add_argument(
        "--dist_backend",
        default="nccl",
        type=str,
        help="distributed backend",
    )
    parser.add_argument(
        "--dist_init_method",
        type=str,
        default="env://",
        help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
             '"WORLD_SIZE", and "RANK" are referred.',
    )
    parser.add_argument(
        "--dist_world_size",
        type=int,
        default=1,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--dist_rank",
        type=int,
        default=None,
        help="node rank for distributed training",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=None,
        help="local rank for distributed training",
    )
    parser.add_argument(
        "--dist_master_addr",
        default=None,
        type=str_or_none,
        help="The master address for distributed training. "
             "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--dist_master_port",
        default=None,
        type=int_or_none,
        help="The master port for distributed training"
             "This value is used when dist_init_method == 'env://'",
    )
    parser.add_argument(
        "--dist_launcher",
        default=None,
        type=str_or_none,
        choices=["slurm", "mpi", None],
        help="The launcher type for distributed training",
    )
    parser.add_argument(
        "--multiprocessing_distributed",
        default=True,
        type=str2bool,
        help="Use multi-processing distributed training to launch "
             "N processes per node, which has N GPUs. This is the "
             "fastest way to use PyTorch for either single node or "
             "multi node data parallel training",
    )
    parser.add_argument(
        "--unused_parameters",
        type=str2bool,
        default=False,
        help="Whether to use the find_unused_parameters in "
             "torch.nn.parallel.DistributedDataParallel ",
    )
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )
    # cudnn related
    # The defaults below are read from torch.backends.cudnn when the
    # add_argument call is evaluated.
    parser.add_argument(
        "--cudnn_enabled",
        type=str2bool,
        default=torch.backends.cudnn.enabled,
        help="Enable CUDNN",
    )
    parser.add_argument(
        "--cudnn_benchmark",
        type=str2bool,
        default=torch.backends.cudnn.benchmark,
        help="Enable cudnn-benchmark mode",
    )
    parser.add_argument(
        "--cudnn_deterministic",
        type=str2bool,
        default=True,
        help="Enable cudnn-deterministic mode",
    )
    # trainer related
    parser.add_argument(
        "--max_epoch",
        type=int,
        default=40,
        help="The maximum number epoch to train",
    )
    parser.add_argument(
        "--max_update",
        type=int,
        default=sys.maxsize,
        help="The maximum number update step to train",
    )
    parser.add_argument(
        "--batch_interval",
        type=int,
        default=10000,
        help="The batch interval for saving model.",
    )
    parser.add_argument(
        "--patience",
        type=int_or_none,
        default=None,
        help="Number of epochs to wait without improvement "
             "before stopping the training",
    )
    parser.add_argument(
        "--val_scheduler_criterion",
        type=str,
        nargs=2,
        default=("valid", "loss"),
        help="The criterion used for the value given to the lr scheduler. "
             'Give a pair referring the phase, "train" or "valid",'
             'and the criterion name. The mode specifying "min" or "max" can '
             "be changed by --scheduler_conf",
    )
    parser.add_argument(
        "--early_stopping_criterion",
        type=str,
        nargs=3,
        default=("valid", "loss", "min"),
        help="The criterion used for judging of early stopping. "
             'Give a pair referring the phase, "train" or "valid",'
             'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
    )
    parser.add_argument(
        "--best_model_criterion",
        nargs="+",
        default=[
            ("train", "loss", "min"),
            ("valid", "loss", "min"),
            ("train", "acc", "max"),
            ("valid", "acc", "max"),
        ],
        help="The criterion used for judging of the best model. "
             'Give a pair referring the phase, "train" or "valid",'
             'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
    )
    parser.add_argument(
        "--keep_nbest_models",
        type=int,
        nargs="+",
        default=[10],
        help="Remove previous snapshots excluding the n-best scored epochs",
    )
    parser.add_argument(
        "--nbest_averaging_interval",
        type=int,
        default=0,
        help="The epoch interval to apply model averaging and save nbest models",
    )
    parser.add_argument(
        "--grad_clip",
        type=float,
        default=5.0,
        help="Gradient norm threshold to clip",
    )
    parser.add_argument(
        "--grad_clip_type",
        type=float,
        default=2.0,
        help="The type of the used p-norm for gradient clip. Can be inf",
    )
    parser.add_argument(
        "--grad_noise",
        type=str2bool,
        default=False,
        help="The flag to switch to use noise injection to "
             "gradients during training",
    )
    parser.add_argument(
        "--accum_grad",
        type=int,
        default=1,
        help="The number of gradient accumulation",
    )
    parser.add_argument(
        "--resume",
        type=str2bool,
        default=False,
        help="Enable resuming if checkpoint is existing",
    )
    parser.add_argument(
        "--train_dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type for training.",
    )
    parser.add_argument(
        "--use_amp",
        type=str2bool,
        default=False,
        help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
    )
    # No `type=` is given for --log_interval, so a value supplied on the
    # command line stays a string.
    parser.add_argument(
        "--log_interval",
        default=None,
        help="Show the logs every the number iterations in each epochs at the "
             "training phase. If None is given, it is decided according the number "
             "of training samples automatically .",
    )
    parser.add_argument(
        "--use_tensorboard",
        type=str2bool,
        default=True,
        help="Enable tensorboard logging",
    )
    # pretrained model related
    parser.add_argument(
        "--init_param",
        type=str,
        default=[],
        nargs="*",
        help="Specify the file path used for initialization of parameters. "
             "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
             "where file_path is the model file path, "
             "src_key specifies the key of model states to be used in the model file, "
             "dst_key specifies the attribute of the model to be initialized, "
             "and exclude_keys excludes keys of model states for the initialization."
             "e.g.\n"
             "  # Load all parameters"
             "  --init_param some/where/model.pb\n"
             "  # Load only decoder parameters"
             "  --init_param some/where/model.pb:decoder:decoder\n"
             "  # Load only decoder parameters excluding decoder.embed"
             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
    )
    parser.add_argument(
        "--ignore_init_mismatch",
        type=str2bool,
        default=False,
        help="Ignore size mismatch when loading pre-trained model",
    )
    parser.add_argument(
        "--freeze_param",
        type=str,
        default=[],
        nargs="*",
        help="Freeze parameters",
    )
    # dataset related
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="small",
        help="whether to use dataloader for large dataset",
    )
    parser.add_argument(
        "--dataset_conf",
        action=NestedDictAction,
        default=dict(),
        help=f"The keyword arguments for dataset",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default=None,
        help="root path of data",
    )
    parser.add_argument(
        "--train_set",
        type=str,
        default="train",
        help="train dataset",
    )
    parser.add_argument(
        "--valid_set",
        type=str,
        default="validation",
        help="dev dataset",
    )
    parser.add_argument(
        "--speed_perturb",
        type=float,
        nargs="+",
        default=None,
        help="speed perturb",
    )
    parser.add_argument(
        "--use_preprocessor",
        type=str2bool,
        default=True,
        help="Apply preprocessing to data or not",
    )
    # optimization related
    # The optimizer and scheduler names are lower-cased while parsing.
    parser.add_argument(
        "--optim",
        type=lambda x: x.lower(),
        default="adam",
        help="The optimizer type",
    )
    parser.add_argument(
        "--optim_conf",
        action=NestedDictAction,
        default=dict(),
        help="The keyword arguments for optimizer",
    )
    parser.add_argument(
        "--scheduler",
        type=lambda x: str_or_none(x.lower()),
        default=None,
        help="The lr scheduler type",
    )
    parser.add_argument(
        "--scheduler_conf",
        action=NestedDictAction,
        default=dict(),
        help="The keyword arguments for lr scheduler",
    )
    # most task related
    parser.add_argument(
        "--init",
        type=lambda x: str_or_none(x.lower()),
        default=None,
        help="The initialization method",
        choices=[
            "chainer",
            "xavier_uniform",
            "xavier_normal",
            "kaiming_uniform",
            "kaiming_normal",
            None,
        ],
    )
    parser.add_argument(
        "--token_list",
        type=str_or_none,
        default=None,
        help="A text mapping int-id to token",
    )
    parser.add_argument(
        "--token_type",
        type=str,
        default="bpe",
        choices=["bpe", "char", "word"],
        help="",
    )
    parser.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model file fo sentencepiece",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese"],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_choices,
        default=None,
        help="Specify g2p method if --token_type=phn",
    )
    # pai related
    parser.add_argument(
        "--use_pai",
        type=str2bool,
        default=False,
        help="flag to indicate whether training on PAI",
    )
    parser.add_argument(
        "--simple_ddp",
        type=str2bool,
        default=False,
    )
    parser.add_argument(
        "--num_worker_count",
        type=int,
        default=1,
        help="The number of machines on PAI.",
    )
    parser.add_argument(
        "--access_key_id",
        type=str,
        default=None,
        help="The username for oss.",
    )
    parser.add_argument(
        "--access_key_secret",
        type=str,
        default=None,
        help="The password for oss.",
    )
    parser.add_argument(
        "--endpoint",
        type=str,
        default=None,
        help="The endpoint for oss.",
    )
    parser.add_argument(
        "--bucket_name",
        type=str,
        default=None,
        help="The bucket name for oss.",
    )
    parser.add_argument(
        "--oss_bucket",
        default=None,
        help="oss bucket.",
    )
    # NOTE(review): returning here means the main(**kwargs) call on the next
    # line can never execute as the code is written.
    return parser
    main(**kwargs)
if __name__ == '__main__':
    # Script mode: obtain the parser and split the command line into the
    # options it recognises (args) and the leftover tokens
    # (extra_task_params). All three names become module-level globals.
    parser = get_parser()
    args, extra_task_params = parser.parse_known_args()
    # Leftover tokens are handed to build_args together with the parser and
    # the namespace parsed so far; its return value replaces args.
    if extra_task_params:
        args = build_args(args, parser, extra_task_params)
def main(**kwargs):
    # Run one training job from the keyword configuration: seed and cuDNN
    # setup, optional torch.distributed initialisation, tokenizer / frontend /
    # model construction, optimizer and scheduler, dataset and dataloader, and
    # finally Trainer.run().
    #
    # NOTE(review): this body mixes two styles of configuration access -- the
    # `kwargs` mapping and an `args` object that is neither a parameter nor
    # assigned inside this function. Several steps (seeding, config dump,
    # pretrained-parameter loading, model building) therefore appear twice.
    # The `trainer = build_trainer(` call near the end is never closed, so
    # this block does not parse as written. Confirm which variant is intended
    # before changing any logic here.
    # preprocess_config(kwargs)
    # import pdb; pdb.set_trace()
    # set random seed
    # NOTE(review): `args` is not defined in this function; it would have to
    # exist as a module-level name for the next four statements to run.
    set_all_random_seed(args.seed)
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark
    torch.backends.cudnn.deterministic = args.cudnn_deterministic
    tables.print()
    # Same seed / cuDNN setup again, this time read from kwargs with defaults.
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    # Check if we are using DDP or FSDP
    # DDP is inferred from the WORLD_SIZE environment variable (> 1); FSDP is
    # opted into through the "use_fsdp" config entry.
    use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", None)
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
        torch.cuda.set_device(local_rank)
    # save config.yaml
    # `and` binds tighter than `or`: write from global rank 0 when a process
    # group is in use, otherwise when LOCAL_RANK is 0.
    if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        logging.info("config.yaml is saved to: %s", yaml_file)
    # ddp init
    # NOTE(review): `args`-based distributed setup; CUDA_VISIBLE_DEVICES is
    # set here after torch.cuda.set_device may already have been called above.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    args.distributed = args.ngpu > 1 or args.dist_world_size > 1
    distributed_option = build_distributed(args)
    # build tokenizer if tokenizer is not None; the config entry is replaced
    # by the constructed instance.
    tokenizer = kwargs.get("tokenizer", None)
    if tokenizer is not None:
        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
        kwargs["tokenizer"] = tokenizer
    # build frontend if frontend is not None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()
    # import pdb;
    # pdb.set_trace()
    # build model
    # NOTE(review): `tokenizer` is still None when kwargs has no "tokenizer"
    # entry, in which case `.token_list` raises AttributeError.
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
    # for logging
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        logging.basicConfig(
            level="INFO",
            format=f"[{os.uname()[1].split('.')[0]}]"
                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    # init_param
    # A single value is wrapped into a one-element tuple so the loop below
    # can treat both forms alike.
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        if not isinstance(init_param, (list, tuple)):
            init_param = (init_param,)
        logging.info("init_param is not None: %s", init_param)
        for p in init_param:
            logging.info(f"Loading pretrained params from {p}")
            load_pretrained_model(
                model=model,
                path=p,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                oss_bucket=kwargs.get("oss_bucket", None),
                scope_map=kwargs.get("scope_map", None),
                excludes=kwargs.get("excludes", None),
            )
    else:
        # NOTE(review): as written, this branch runs when no init_param is
        # configured: it requests ERROR-level logging and then initialises
        # the model weights (default method "kaiming_normal").
        logging.basicConfig(
            level="ERROR",
            format=f"[{os.uname()[1].split('.')[0]}]"
                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        initialize(model, kwargs.get("init", "kaiming_normal"))
    # prepare files for dataloader
    prepare_data(args, distributed_option)
    # NOTE(review): this rebinds `model`, discarding the instance created
    # from tables.model_classes above.
    model = build_model(args)
    model = model.to(
        dtype=getattr(torch, args.train_dtype),
        device="cuda" if args.ngpu > 0 else "cpu",
    )
    optimizers = build_optimizer(args, model=model)
    schedulers = build_scheduler(args, optimizers)
    # freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        # NOTE(review): eval() executes the configured string as Python code;
        # only safe if the configuration is trusted.
        freeze_param = eval(freeze_param)
        # NOTE(review): every Sequence (str, list, tuple) is wrapped in a
        # tuple here; for a list/tuple value the loop below would then do
        # `t + "."` on a non-string. The check looks inverted -- confirm.
        if isinstance(freeze_param, Sequence):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        # A parameter is frozen when its name equals the prefix or lives
        # underneath it ("<prefix>.").
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False
    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                   distributed_option.dist_rank,
                                                                   distributed_option.local_rank))
    logging.info(pytorch_cudnn_version())
    logging.info("Args: {}".format(args))
    logging.info(model_summary(model))
    logging.info("Optimizer: {}".format(optimizers))
    logging.info("Scheduler: {}".format(schedulers))
    # Wrap or place the model: DDP takes precedence over FSDP; otherwise the
    # model is moved to the configured device (default "cuda").
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
    elif use_fsdp:
        model = FSDP(model).cuda(local_rank)
    else:
        model = model.to(device=kwargs.get("device", "cuda"))
    # optim
    # NOTE(review): kwargs.get("optim_conf") returns None when the key is
    # missing, and `**None` raises TypeError; same for "scheduler_conf" below.
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
    # scheduler
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
    # dump args to config.yaml
    # NOTE(review): second config dump, this time from `args`; with use_pai
    # the config is serialised with torch.save and uploaded via oss_bucket.
    if not distributed_option.distributed or distributed_option.dist_rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
            logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
            if args.use_pai:
                buffer = BytesIO()
                torch.save({"config": vars(args)}, buffer)
                args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
            else:
                yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
    # import pdb;
    # pdb.set_trace()
    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
    # NOTE(review): second pretrained-parameter loading loop, driven by
    # `args.init_param` and using different keyword names (`init_param=`,
    # `map_location=`) than the kwargs-based call earlier in this function.
    for p in args.init_param:
        logging.info(f"Loading pretrained params from {p}")
        load_pretrained_model(
            model=model,
            init_param=p,
            ignore_init_mismatch=args.ignore_init_mismatch,
            map_location=f"cuda:{torch.cuda.current_device()}"
            if args.ngpu > 0
            else "cpu",
            oss_bucket=args.oss_bucket,
        )
    # dataloader
    # The sampler class is looked up by name in the registry; the whole
    # dataset_conf mapping is forwarded to its constructor.
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
    if batch_sampler is not None:
        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                collate_fn=dataset_tr.collator,
                                                batch_sampler=batch_sampler,
                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                                pin_memory=True)
    # dataloader for training/validation
    train_dataloader, valid_dataloader = build_dataloader(args)
    # Trainer, including model, optimizers, etc.
    # NOTE(review): the build_trainer( call below is never closed and a comma
    # is missing after `distributed_option=distributed_option`, so the
    # following statement is a syntax error as written. It looks like two
    # alternative trainer constructions merged together -- confirm which one
    # is intended.
    trainer = build_trainer(
        args=args,
    trainer = Trainer(
        model=model,
        optimizers=optimizers,
        schedulers=schedulers,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        distributed_option=distributed_option
        optim=optim,
        scheduler=scheduler,
        dataloader_train=dataloader_tr,
        dataloader_val=None,
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        **kwargs.get("train_conf"),
    )
    trainer.run()
    # Tear down the process group created at the top of this function.
    if use_ddp or use_fsdp:
        torch.distributed.destroy_process_group()
if __name__ == "__main__":
    # Script entry point: hand control to the Hydra-decorated main_hydra(),
    # which is called here without explicit arguments.
    main_hydra()