From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 19 二月 2024 22:21:50 +0800
Subject: [PATCH] aishell example

---
 funasr/bin/train.py |  513 +++++++++++++++++++-------------------------------------
 1 files changed, 177 insertions(+), 336 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index dbfebd7..d916509 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,350 +1,191 @@
-import logging
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+
 import os
 import sys
-
 import torch
+import hydra
+import logging
+import argparse
+from io import BytesIO
+import torch.distributed as dist
+from collections.abc import Sequence
+from omegaconf import DictConfig, OmegaConf
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.build_distributed import build_distributed
-from funasr.utils.prepare_data import prepare_data
-from funasr.utils.types import str2bool
+from funasr.register import tables
+from funasr.optimizers import optim_classes
+from funasr.train_utils.trainer import Trainer
+from funasr.schedulers import scheduler_classes
+from funasr.train_utils.initialize import initialize
+from funasr.download.download_from_hub import download_model
+from funasr.models.lora.utils import mark_only_lora_as_trainable
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
+from funasr.train_utils.load_pretrained_model import load_pretrained_model
+# from funasr.tokenizer.build_tokenizer import build_tokenizer
+# from funasr.tokenizer.token_id_converter import TokenIDConverter
+# from funasr.tokenizer.funtoken import build_tokenizer
 
 
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="FunASR Common Training Parser",
-    )
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+    if kwargs.get("debug", False):
+        import pdb; pdb.set_trace()
 
-    # common configuration
-    parser.add_argument("--output_dir", help="model save path")
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    assert "model" in kwargs
+    if "model_conf" not in kwargs:
+        logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
+    
 
-    # ddp related
-    parser.add_argument(
-        "--dist_backend",
-        default="nccl",
-        type=str,
-        help="distributed backend",
-    )
-    parser.add_argument(
-        "--dist_init_method",
-        type=str,
-        default="env://",
-        help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
-             '"WORLD_SIZE", and "RANK" are referred.',
-    )
-    parser.add_argument(
-        "--dist_world_size",
-        default=None,
-        help="number of nodes for distributed training",
-    )
-    parser.add_argument(
-        "--dist_rank",
-        default=None,
-        help="node rank for distributed training",
-    )
-    parser.add_argument(
-        "--local_rank",
-        default=None,
-        help="local rank for distributed training",
-    )
-    parser.add_argument(
-        "--unused_parameters",
-        type=str2bool,
-        default=False,
-        help="Whether to use the find_unused_parameters in "
-             "torch.nn.parallel.DistributedDataParallel ",
-    )
-
-    # cudnn related
-    parser.add_argument(
-        "--cudnn_enabled",
-        type=str2bool,
-        default=torch.backends.cudnn.enabled,
-        help="Enable CUDNN",
-    )
-    parser.add_argument(
-        "--cudnn_benchmark",
-        type=str2bool,
-        default=torch.backends.cudnn.benchmark,
-        help="Enable cudnn-benchmark mode",
-    )
-    parser.add_argument(
-        "--cudnn_deterministic",
-        type=str2bool,
-        default=True,
-        help="Enable cudnn-deterministic mode",
-    )
-
-    # trainer related
-    parser.add_argument(
-        "--max_epoch",
-        type=int,
-        default=40,
-        help="The maximum number epoch to train",
-    )
-    parser.add_argument(
-        "--max_update",
-        type=int,
-        default=sys.maxsize,
-        help="The maximum number update step to train",
-    )
-    parser.add_argument(
-        "--batch_interval",
-        type=int,
-        default=10000,
-        help="The batch interval for saving model.",
-    )
-    parser.add_argument(
-        "--patience",
-        default=None,
-        help="Number of epochs to wait without improvement "
-             "before stopping the training",
-    )
-    parser.add_argument(
-        "--val_scheduler_criterion",
-        type=str,
-        nargs=2,
-        default=("valid", "loss"),
-        help="The criterion used for the value given to the lr scheduler. "
-             'Give a pair referring the phase, "train" or "valid",'
-             'and the criterion name. The mode specifying "min" or "max" can '
-             "be changed by --scheduler_conf",
-    )
-    parser.add_argument(
-        "--early_stopping_criterion",
-        type=str,
-        nargs=3,
-        default=("valid", "loss", "min"),
-        help="The criterion used for judging of early stopping. "
-             'Give a pair referring the phase, "train" or "valid",'
-             'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
-    )
-    parser.add_argument(
-        "--best_model_criterion",
-        nargs="+",
-        default=[
-            ("train", "loss", "min"),
-            ("valid", "loss", "min"),
-            ("train", "acc", "max"),
-            ("valid", "acc", "max"),
-        ],
-        help="The criterion used for judging of the best model. "
-             'Give a pair referring the phase, "train" or "valid",'
-             'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
-    )
-    parser.add_argument(
-        "--keep_nbest_models",
-        type=int,
-        nargs="+",
-        default=[10],
-        help="Remove previous snapshots excluding the n-best scored epochs",
-    )
-    parser.add_argument(
-        "--nbest_averaging_interval",
-        type=int,
-        default=0,
-        help="The epoch interval to apply model averaging and save nbest models",
-    )
-    parser.add_argument(
-        "--grad_clip",
-        type=float,
-        default=5.0,
-        help="Gradient norm threshold to clip",
-    )
-    parser.add_argument(
-        "--grad_clip_type",
-        type=float,
-        default=2.0,
-        help="The type of the used p-norm for gradient clip. Can be inf",
-    )
-    parser.add_argument(
-        "--grad_noise",
-        type=str2bool,
-        default=False,
-        help="The flag to switch to use noise injection to "
-             "gradients during training",
-    )
-    parser.add_argument(
-        "--accum_grad",
-        type=int,
-        default=1,
-        help="The number of gradient accumulation",
-    )
-    parser.add_argument(
-        "--resume",
-        type=str2bool,
-        default=False,
-        help="Enable resuming if checkpoint is existing",
-    )
-    parser.add_argument(
-        "--use_amp",
-        type=str2bool,
-        default=False,
-        help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
-    )
-    parser.add_argument(
-        "--log_interval",
-        default=None,
-        help="Show the logs every the number iterations in each epochs at the "
-             "training phase. If None is given, it is decided according the number "
-             "of training samples automatically .",
-    )
-
-    # pretrained model related
-    parser.add_argument(
-        "--init_param",
-        type=str,
-        default=[],
-        nargs="*",
-        help="Specify the file path used for initialization of parameters. "
-             "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
-             "where file_path is the model file path, "
-             "src_key specifies the key of model states to be used in the model file, "
-             "dst_key specifies the attribute of the model to be initialized, "
-             "and exclude_keys excludes keys of model states for the initialization."
-             "e.g.\n"
-             "  # Load all parameters"
-             "  --init_param some/where/model.pb\n"
-             "  # Load only decoder parameters"
-             "  --init_param some/where/model.pb:decoder:decoder\n"
-             "  # Load only decoder parameters excluding decoder.embed"
-             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
-             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
-    )
-    parser.add_argument(
-        "--ignore_init_mismatch",
-        type=str2bool,
-        default=False,
-        help="Ignore size mismatch when loading pre-trained model",
-    )
-    parser.add_argument(
-        "--freeze_param",
-        type=str,
-        default=[],
-        nargs="*",
-        help="Freeze parameters",
-    )
-
-    # dataset related
-    parser.add_argument(
-        "--dataset_type",
-        type=str,
-        default="small",
-        help="whether to use dataloader for large dataset",
-    )
-    parser.add_argument(
-        "--train_data_file",
-        type=str,
-        default=None,
-        help="train_list for large dataset",
-    )
-    parser.add_argument(
-        "--valid_data_file",
-        type=str,
-        default=None,
-        help="valid_list for large dataset",
-    )
-    parser.add_argument(
-        "--train_data_path_and_name_and_type",
-        action="append",
-        default=[],
-        help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
-    )
-    parser.add_argument(
-        "--valid_data_path_and_name_and_type",
-        action="append",
-        default=[],
-    )
-
-    # pai related
-    parser.add_argument(
-        "--use_pai",
-        type=str2bool,
-        default=False,
-        help="flag to indicate whether training on PAI",
-    )
-    parser.add_argument(
-        "--simple_ddp",
-        type=str2bool,
-        default=False,
-    )
-    parser.add_argument(
-        "--num_worker_count",
-        type=int,
-        default=1,
-        help="The number of machines on PAI.",
-    )
-    parser.add_argument(
-        "--access_key_id",
-        type=str,
-        default=None,
-        help="The username for oss.",
-    )
-    parser.add_argument(
-        "--access_key_secret",
-        type=str,
-        default=None,
-        help="The password for oss.",
-    )
-    parser.add_argument(
-        "--endpoint",
-        type=str,
-        default=None,
-        help="The endpoint for oss.",
-    )
-    parser.add_argument(
-        "--bucket_name",
-        type=str,
-        default=None,
-        help="The bucket name for oss.",
-    )
-    parser.add_argument(
-        "--oss_bucket",
-        default=None,
-        help="oss bucket.",
-    )
-
-    # task related
-    parser.add_argument("--task_name", help="for different task")
-
-    return parser
+    main(**kwargs)
 
 
-if __name__ == '__main__':
-    parser = get_parser()
-    args = parser.parse_args()
+def main(**kwargs):
+    print(kwargs)
+    # set random seed
+    tables.print()
+    set_all_random_seed(kwargs.get("seed", 0))
+    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
+    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
+    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
+    
+    local_rank = int(os.environ.get('LOCAL_RANK', 0))
+    # Check if we are using DDP or FSDP
+    use_ddp = 'WORLD_SIZE' in os.environ and int(os.environ["WORLD_SIZE"]) > 1
+    use_fsdp = kwargs.get("use_fsdp", None)
+    if use_ddp or use_fsdp:
+        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
+        torch.cuda.set_device(local_rank)
+    
+    # save config.yaml
+    if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
+        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+        OmegaConf.save(config=kwargs, f=yaml_file)
+        logging.info("config.yaml is saved to: %s", yaml_file)
 
-    # ddp init
-    args.distributed = args.dist_world_size > 1
-    distributed_option = build_distributed(args)
-    if not distributed_option.distributed or distributed_option.dist_rank == 0:
-        logging.basicConfig(
-            level="INFO",
-            format=f"[{os.uname()[1].split('.')[0]}]"
-                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-        )
+    tokenizer = kwargs.get("tokenizer", None)
+    if tokenizer is not None:
+        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+        kwargs["tokenizer"] = tokenizer
+    
+    # build frontend if frontend is none None
+    frontend = kwargs.get("frontend", None)
+    if frontend is not None:
+        frontend_class = tables.frontend_classes.get(frontend)
+        frontend = frontend_class(**kwargs["frontend_conf"])
+        kwargs["frontend"] = frontend
+        kwargs["input_size"] = frontend.output_size()
+
+
+    # build model
+    model_class = tables.model_classes.get(kwargs["model"])
+    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+
+
+
+    # init_param
+    init_param = kwargs.get("init_param", None)
+    if init_param is not None:
+        if not isinstance(init_param, (list, tuple)):
+            init_param = (init_param,)
+        logging.info("init_param is not None: %s", init_param)
+        for p in init_param:
+            logging.info(f"Loading pretrained params from {p}")
+            load_pretrained_model(
+                model=model,
+                path=p,
+                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+                oss_bucket=kwargs.get("oss_bucket", None),
+                scope_map=kwargs.get("scope_map", None),
+                excludes=kwargs.get("excludes", None),
+            )
     else:
-        logging.basicConfig(
-            level="ERROR",
-            format=f"[{os.uname()[1].split('.')[0]}]"
-                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-        )
-    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
-                                                                   distributed_option.dist_rank,
-                                                                   distributed_option.local_rank))
+        initialize(model, kwargs.get("init", "kaiming_normal"))
 
-    # prepare files for dataloader
-    prepare_data(args, distributed_option)
 
-    set_all_random_seed(args.seed)
-    torch.backends.cudnn.enabled = args.cudnn_enabled
-    torch.backends.cudnn.benchmark = args.cudnn_benchmark
-    torch.backends.cudnn.deterministic = args.cudnn_deterministic
+    # freeze_param
+    freeze_param = kwargs.get("freeze_param", None)
+    if freeze_param is not None:
+        freeze_param = eval(freeze_param)
+        if isinstance(freeze_param, Sequence):
+            freeze_param = (freeze_param,)
+        logging.info("freeze_param is not None: %s", freeze_param)
+        for t in freeze_param:
+            for k, p in model.named_parameters():
+                if k.startswith(t + ".") or k == t:
+                    logging.info(f"Setting {k}.requires_grad = False")
+                    p.requires_grad = False
+    
+
+    if use_ddp:
+        model = model.cuda(local_rank)
+        model = DDP(model, device_ids=[local_rank],
+                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
+    elif use_fsdp:
+        model = FSDP(model).cuda(local_rank)
+    else:
+        model = model.to(device=kwargs.get("device", "cuda"))
+        
+        
+    # optim
+    optim = kwargs.get("optim", "adam")
+    assert optim in optim_classes
+    optim_class = optim_classes.get(optim)
+    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
+    
+    # scheduler
+    scheduler = kwargs.get("scheduler", "warmuplr")
+    assert scheduler in scheduler_classes
+    scheduler_class = scheduler_classes.get(scheduler)
+    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
+
+
+    # dataset
+    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
+    dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=True, **kwargs.get("dataset_conf"))
+    dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, is_training=False, **kwargs.get("dataset_conf"))
+
+    # dataloader
+    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
+    batch_sampler_val = None
+    if batch_sampler is not None:
+        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
+        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
+    dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
+                                                collate_fn=dataset_tr.collator,
+                                                batch_sampler=batch_sampler,
+                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
+                                                pin_memory=True)
+    
+    dataloader_val = torch.utils.data.DataLoader(dataset_val,
+                                                collate_fn=dataset_val.collator,
+                                                batch_sampler=batch_sampler_val,
+                                                num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
+                                                pin_memory=True)
+    trainer = Trainer(
+        model=model,
+        optim=optim,
+        scheduler=scheduler,
+        dataloader_train=dataloader_tr,
+        dataloader_val=dataloader_val,
+        local_rank=local_rank,
+        use_ddp=use_ddp,
+        use_fsdp=use_fsdp,
+        output_dir=kwargs.get("output_dir", "./exp"),
+        resume=kwargs.get("resume", True),
+        **kwargs.get("train_conf"),
+    )
+    trainer.run()
+    
+    if use_ddp or use_fsdp:
+        torch.distributed.destroy_process_group()
+
+    
+
+if __name__ == "__main__":
+    main_hydra()
\ No newline at end of file

--
Gitblit v1.9.1