From 54931dd4e1a099d7d6f144c4e12e5453deb3aa26 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Wed, 28 Jun 2023 10:41:57 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main

---
 funasr/bin/train.py |  288 +++++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 259 insertions(+), 29 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
old mode 100644
new mode 100755
index dbfebd7..1dc3fb5
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,18 +1,37 @@
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import argparse
 import logging
 import os
 import sys
+from io import BytesIO
 
 import torch
 
+from funasr.build_utils.build_args import build_args
+from funasr.build_utils.build_dataloader import build_dataloader
+from funasr.build_utils.build_distributed import build_distributed
+from funasr.build_utils.build_model import build_model
+from funasr.build_utils.build_optimizer import build_optimizer
+from funasr.build_utils.build_scheduler import build_scheduler
+from funasr.build_utils.build_trainer import build_trainer
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.load_pretrained_model import load_pretrained_model
+from funasr.torch_utils.model_summary import model_summary
+from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils import config_argparse
-from funasr.utils.build_distributed import build_distributed
+from funasr.utils.nested_dict_action import NestedDictAction
 from funasr.utils.prepare_data import prepare_data
+from funasr.utils.types import int_or_none
 from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
 
 
 def get_parser():
-    parser = config_argparse.ArgumentParser(
+    parser = argparse.ArgumentParser(
         description="FunASR Common Training Parser",
     )
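
A note on this change: config_argparse.ArgumentParser resolved --config
files by itself, whereas plain argparse simply leaves unknown options
alone. That is what lets __main__ below split task-specific options off
with parse_known_args and hand them to build_args. A minimal standalone
sketch of the mechanism (the option names here are hypothetical, and the
build_args behaviour is an assumption inferred from its call signature):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ngpu", type=int, default=1)
    # Unknown options come back untouched instead of raising an error;
    # build_args would then register and re-parse them per task.
    args, extra = parser.parse_known_args(["--ngpu", "1", "--encoder", "conformer"])
    assert extra == ["--encoder", "conformer"]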
 
@@ -25,6 +44,7 @@
         help="The number of gpus. 0 indicates CPU mode",
     )
     parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
 
     # ddp related
     parser.add_argument(
@@ -42,18 +62,51 @@
     )
     parser.add_argument(
         "--dist_world_size",
-        default=None,
+        type=int,
+        default=1,
         help="number of nodes for distributed training",
     )
     parser.add_argument(
         "--dist_rank",
+        type=int,
         default=None,
         help="node rank for distributed training",
     )
     parser.add_argument(
         "--local_rank",
+        type=int,
         default=None,
         help="local rank for distributed training",
+    )
+    parser.add_argument(
+        "--dist_master_addr",
+        default=None,
+        type=str_or_none,
+        help="The master address for distributed training. "
+             "This value is used when dist_init_method == 'env://'",
+    )
+    parser.add_argument(
+        "--dist_master_port",
+        default=None,
+        type=int_or_none,
+        help="The master port for distributed training"
+             "This value is used when dist_init_method == 'env://'",
+    )
+    parser.add_argument(
+        "--dist_launcher",
+        default=None,
+        type=str_or_none,
+        choices=["slurm", "mpi", None],
+        help="The launcher type for distributed training",
+    )
+    parser.add_argument(
+        "--multiprocessing_distributed",
+        default=True,
+        type=str2bool,
+        help="Use multi-processing distributed training to launch "
+             "N processes per node, which has N GPUs. This is the "
+             "fastest way to use PyTorch for either single node or "
+             "multi node data parallel training",
     )
     parser.add_argument(
         "--unused_parameters",
@@ -61,6 +114,12 @@
         default=False,
         help="Whether to use the find_unused_parameters in "
              "torch.nn.parallel.DistributedDataParallel ",
+    )
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
     )
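
For orientation: dist_master_addr and dist_master_port feed the env://
rendezvous of torch.distributed. A minimal sketch of what
build_distributed presumably does with them (its internals are not part
of this patch, so treat the details as an assumption):

    import os
    import torch.distributed as dist

    # Hypothetical single-process values, for illustration only.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # init_method defaults to "env://", which reads the variables above.
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    dist.destroy_process_group()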
 
     # cudnn related
@@ -104,6 +163,7 @@
     )
     parser.add_argument(
         "--patience",
+        type=int_or_none,
         default=None,
         help="Number of epochs to wait without improvement "
              "before stopping the training",
@@ -185,6 +245,12 @@
         help="Enable resuming if checkpoint is existing",
     )
     parser.add_argument(
+        "--train_dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type for training.",
+    )
+    parser.add_argument(
         "--use_amp",
         type=str2bool,
         default=False,
@@ -197,13 +263,19 @@
              "training phase. If None is given, it is decided according the number "
              "of training samples automatically .",
     )
+    parser.add_argument(
+        "--use_tensorboard",
+        type=str2bool,
+        default=True,
+        help="Enable tensorboard logging",
+    )
 
     # pretrained model related
     parser.add_argument(
         "--init_param",
         type=str,
+        action="append",
         default=[],
-        nargs="*",
         help="Specify the file path used for initialization of parameters. "
              "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
              "where file_path is the model file path, "
@@ -229,7 +301,7 @@
         "--freeze_param",
         type=str,
         default=[],
-        nargs="*",
+        action="append",
         help="Freeze parameters",
     )
 
@@ -241,27 +313,122 @@
         help="whether to use dataloader for large dataset",
     )
     parser.add_argument(
-        "--train_data_file",
+        "--dataset_conf",
+        action=NestedDictAction,
+        default=dict(),
+        help=f"The keyword arguments for dataset",
+    )
+    parser.add_argument(
+        "--data_dir",
         type=str,
         default=None,
-        help="train_list for large dataset",
+        help="root path of data",
     )
     parser.add_argument(
-        "--valid_data_file",
+        "--train_set",
         type=str,
+        default="train",
+        help="train dataset",
+    )
+    parser.add_argument(
+        "--valid_set",
+        type=str,
+        default="validation",
+        help="dev dataset",
+    )
+    parser.add_argument(
+        "--data_file_names",
+        type=str,
+        default="wav.scp,text",
+        help="input data files",
+    )
+    parser.add_argument(
+        "--speed_perturb",
+        type=float,
+        nargs="+",
         default=None,
-        help="valid_list for large dataset",
+        help="speed perturb",
     )
     parser.add_argument(
-        "--train_data_path_and_name_and_type",
-        action="append",
-        default=[],
-        help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
+        "--use_preprocessor",
+        type=str2bool,
+        default=True,
+        help="Apply preprocessing to data or not",
+    )
+
+    # optimization related
+    parser.add_argument(
+        "--optim",
+        type=lambda x: x.lower(),
+        default="adam",
+        help="The optimizer type",
     )
     parser.add_argument(
-        "--valid_data_path_and_name_and_type",
-        action="append",
-        default=[],
+        "--optim_conf",
+        action=NestedDictAction,
+        default=dict(),
+        help="The keyword arguments for optimizer",
+    )
+    parser.add_argument(
+        "--scheduler",
+        type=lambda x: str_or_none(x.lower()),
+        default=None,
+        help="The lr scheduler type",
+    )
+    parser.add_argument(
+        "--scheduler_conf",
+        action=NestedDictAction,
+        default=dict(),
+        help="The keyword arguments for lr scheduler",
+    )
+
+    # most task related
+    parser.add_argument(
+        "--init",
+        type=lambda x: str_or_none(x.lower()),
+        default=None,
+        help="The initialization method",
+        choices=[
+            "chainer",
+            "xavier_uniform",
+            "xavier_normal",
+            "kaiming_uniform",
+            "kaiming_normal",
+            None,
+        ],
+    )
+    parser.add_argument(
+        "--token_list",
+        type=str_or_none,
+        default=None,
+        help="A text mapping int-id to token",
+    )
+    parser.add_argument(
+        "--token_type",
+        type=str,
+        default="bpe",
+        choices=["bpe", "char", "word"],
+        help="",
+    )
+    parser.add_argument(
+        "--bpemodel",
+        type=str_or_none,
+        default=None,
+        help="The model file fo sentencepiece",
+    )
+    parser.add_argument(
+        "--cleaner",
+        type=str_or_none,
+        choices=[None, "tacotron", "jaconv", "vietnamese"],
+        default=None,
+        help="Apply text cleaning",
+    )
+    parser.add_argument(
+        "--g2p",
+        type=str_or_none,
+        choices=g2p_choices,
+        default=None,
+        help="Specify g2p method if --token_type=phn",
     )
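
The new *_conf options go through NestedDictAction. Assuming it behaves
like the ESPnet helper of the same name (key=value pairs whose values are
YAML-typed and merged into a dict), usage would look like:

    import argparse
    from funasr.utils.nested_dict_action import NestedDictAction

    p = argparse.ArgumentParser()
    p.add_argument("--optim_conf", action=NestedDictAction, default=dict())
    args = p.parse_args(["--optim_conf", "lr=0.001"])
    assert args.optim_conf == {"lr": 0.001}  # value typed via YAML (assumed)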
 
     # pai related
@@ -312,19 +479,27 @@
         help="oss bucket.",
     )
 
-    # task related
-    parser.add_argument("--task_name", help="for different task")
-
     return parser
 
 
 if __name__ == '__main__':
     parser = get_parser()
-    args = parser.parse_args()
+    args, extra_task_params = parser.parse_known_args()
+    if extra_task_params:
+        args = build_args(args, parser, extra_task_params)
+
+    # set random seed
+    set_all_random_seed(args.seed)
+    torch.backends.cudnn.enabled = args.cudnn_enabled
+    torch.backends.cudnn.benchmark = args.cudnn_benchmark
+    torch.backends.cudnn.deterministic = args.cudnn_deterministic
 
     # ddp init
-    args.distributed = args.dist_world_size > 1
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+    args.distributed = args.ngpu > 1 or args.dist_world_size > 1
     distributed_option = build_distributed(args)
+
+    # for logging
     if not distributed_option.distributed or distributed_option.dist_rank == 0:
         logging.basicConfig(
             level="INFO",
@@ -337,14 +512,69 @@
             format=f"[{os.uname()[1].split('.')[0]}]"
                    f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
         )
-    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
-                                                                   distributed_option.dist_rank,
-                                                                   distributed_option.local_rank))
 
     # prepare files for dataloader
     prepare_data(args, distributed_option)
 
-    set_all_random_seed(args.seed)
-    torch.backends.cudnn.enabled = args.cudnn_enabled
-    torch.backends.cudnn.benchmark = args.cudnn_benchmark
-    torch.backends.cudnn.deterministic = args.cudnn_deterministic
+    model = build_model(args)
+    model = model.to(
+        dtype=getattr(torch, args.train_dtype),
+        device="cuda" if args.ngpu > 0 else "cpu",
+    )
+    for t in args.freeze_param:
+        for k, p in model.named_parameters():
+            if k.startswith(t + ".") or k == t:
+                logging.info(f"Setting {k}.requires_grad = False")
+                p.requires_grad = False
+
+    optimizers = build_optimizer(args, model=model)
+    schedulers = build_scheduler(args, optimizers)
+
+    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+                                                                   distributed_option.dist_rank,
+                                                                   distributed_option.local_rank))
+    logging.info(pytorch_cudnn_version())
+    logging.info("Args: {}".format(args))
+    logging.info(model_summary(model))
+    logging.info("Optimizer: {}".format(optimizers))
+    logging.info("Scheduler: {}".format(schedulers))
+
+    # dump args to config.yaml
+    if not distributed_option.distributed or distributed_option.dist_rank == 0:
+        os.makedirs(args.output_dir, exist_ok=True)
+        if args.use_pai:
+            logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.dict"))
+            buffer = BytesIO()
+            torch.save({"config": vars(args)}, buffer)
+            args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
+        else:
+            logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
+            with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
+                yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
+
+    for p in args.init_param:
+        logging.info(f"Loading pretrained params from {p}")
+        load_pretrained_model(
+            model=model,
+            init_param=p,
+            ignore_init_mismatch=args.ignore_init_mismatch,
+            map_location=f"cuda:{torch.cuda.current_device()}"
+            if args.ngpu > 0
+            else "cpu",
+            oss_bucket=args.oss_bucket,
+        )
+
+    # dataloader for training/validation
+    train_dataloader, valid_dataloader = build_dataloader(args)
+
+    # Trainer, including model, optimizers, etc.
+    trainer = build_trainer(
+        args=args,
+        model=model,
+        optimizers=optimizers,
+        schedulers=schedulers,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        distributed_option=distributed_option
+    )
+
+    trainer.run()
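
The freeze_param loop above freezes a parameter when its name either
equals the given string or starts with it as a module prefix. The same
matching rule in isolation, on toy module names rather than a FunASR
model:

    import torch

    model = torch.nn.ModuleDict(
        {"encoder": torch.nn.Linear(2, 2), "decoder": torch.nn.Linear(2, 2)}
    )
    for t in ["encoder"]:  # stand-in for args.freeze_param
        for k, p in model.named_parameters():
            # Matches "encoder.weight" and "encoder.bias", not "decoder.*".
            if k.startswith(t + ".") or k == t:
                p.requires_grad = False
    assert not model["encoder"].weight.requires_grad
    assert model["decoder"].weight.requires_grad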

--
Gitblit v1.9.1