From 3d9f094e9652d4b84894c6fd4eae39a4a753b0f0 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 16 May 2023 23:48:00 +0800
Subject: [PATCH] train

---
 funasr/bin/train.py |   78 ++++++++++++++++++++++++---------------
 1 file changed, 48 insertions(+), 30 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 3e3f598..53e5bde 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -16,6 +16,7 @@
 from funasr.build_utils.build_scheduler import build_scheduler
 from funasr.build_utils.build_trainer import build_trainer
 from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
@@ -23,7 +24,6 @@
 from funasr.utils.prepare_data import prepare_data
 from funasr.utils.types import int_or_none
 from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
 
@@ -161,6 +161,7 @@
     )
     parser.add_argument(
         "--patience",
+        type=int_or_none,
         default=None,
         help="Number of epochs to wait without improvement "
              "before stopping the training",
@@ -260,6 +261,12 @@
              "training phase. If None is given, it is decided according the number "
              "of training samples automatically .",
     )
+    parser.add_argument(
+        "--use_tensorboard",
+        type=str2bool,
+        default=True,
+        help="Enable tensorboard logging",
+    )
 
     # pretrained model related
     parser.add_argument(
@@ -310,47 +317,41 @@
         help=f"The keyword arguments for dataset",
     )
     parser.add_argument(
-        "--train_data_file",
+        "--data_dir",
         type=str,
         default=None,
-        help="train_list for large dataset",
+        help="root path of data",
     )
     parser.add_argument(
-        "--valid_data_file",
+        "--train_set",
         type=str,
+        default="train",
+        help="train dataset",
+    )
+    parser.add_argument(
+        "--valid_set",
+        type=str,
+        default="validation",
+        help="dev dataset",
+    )
+    parser.add_argument(
+        "--speed_perturb",
+        type=float,
+        nargs="+",
         default=None,
-        help="valid_list for large dataset",
-    )
-    parser.add_argument(
-        "--train_data_path_and_name_and_type",
-        type=str2triple_str,
-        action="append",
-        default=[],
-        help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
-    )
-    parser.add_argument(
-        "--valid_data_path_and_name_and_type",
-        type=str2triple_str,
-        action="append",
-        default=[],
-    )
-    parser.add_argument(
-        "--train_shape_file",
-        type=str,
-        action="append",
-        default=[],
-    )
-    parser.add_argument(
-        "--valid_shape_file",
-        type=str,
-        action="append",
-        default=[],
+        help="speed perturb",
     )
     parser.add_argument(
         "--use_preprocessor",
         type=str2bool,
         default=True,
         help="Apply preprocessing to data or not",
+    )
+    parser.add_argument(
+        "--embed_path",
+        type=str,
+        default=None,
+        help="for model which requires embeds",
     )
 
     # optimization related
@@ -514,6 +515,10 @@
     prepare_data(args, distributed_option)
 
     model = build_model(args)
+    model = model.to(
+        dtype=getattr(torch, args.train_dtype),
+        device="cuda" if args.ngpu > 0 else "cpu",
+    )
     optimizers = build_optimizer(args, model=model)
     schedulers = build_scheduler(args, optimizers)
 
@@ -521,6 +526,7 @@
                                                                    distributed_option.dist_rank,
                                                                    distributed_option.local_rank))
     logging.info(pytorch_cudnn_version())
+    logging.info("Args: {}".format(args))
     logging.info(model_summary(model))
     logging.info("Optimizer: {}".format(optimizers))
     logging.info("Scheduler: {}".format(schedulers))
@@ -537,6 +543,18 @@
             else:
                 yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
 
+    for p in args.init_param:
+        logging.info(f"Loading pretrained params from {p}")
+        load_pretrained_model(
+            model=model,
+            init_param=p,
+            ignore_init_mismatch=args.ignore_init_mismatch,
+            map_location=f"cuda:{torch.cuda.current_device()}"
+            if args.ngpu > 0
+            else "cpu",
+            oss_bucket=args.oss_bucket,
+        )
+
     # dataloader for training/validation
     train_dataloader, valid_dataloader = build_dataloader(args)
 

--
Gitblit v1.9.1