From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/bin/train.py |  105 ++++++++++++++++++++++++++++++++++++----------------
 1 files changed, 73 insertions(+), 32 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 3e3f598..f5d10c4 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import argparse
 import logging
@@ -16,6 +18,7 @@
 from funasr.build_utils.build_scheduler import build_scheduler
 from funasr.build_utils.build_trainer import build_trainer
 from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
@@ -23,9 +26,9 @@
 from funasr.utils.prepare_data import prepare_data
 from funasr.utils.types import int_or_none
 from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
 
 
 def get_parser():
@@ -161,6 +164,7 @@
     )
     parser.add_argument(
         "--patience",
+        type=int_or_none,
         default=None,
         help="Number of epochs to wait without improvement "
              "before stopping the training",
@@ -260,13 +264,19 @@
              "training phase. If None is given, it is decided according the number "
              "of training samples automatically .",
     )
+    parser.add_argument(
+        "--use_tensorboard",
+        type=str2bool,
+        default=True,
+        help="Enable tensorboard logging",
+    )
 
     # pretrained model related
     parser.add_argument(
         "--init_param",
         type=str,
+        action="append",
         default=[],
-        nargs="*",
         help="Specify the file path used for initialization of parameters. "
              "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
              "where file_path is the model file path, "
@@ -292,7 +302,7 @@
         "--freeze_param",
         type=str,
         default=[],
-        nargs="*",
+        action="append",
         help="Freeze parameters",
     )
 
@@ -310,41 +320,35 @@
         help=f"The keyword arguments for dataset",
     )
     parser.add_argument(
-        "--train_data_file",
+        "--data_dir",
         type=str,
         default=None,
-        help="train_list for large dataset",
+        help="root path of data",
     )
     parser.add_argument(
-        "--valid_data_file",
+        "--train_set",
         type=str,
+        default="train",
+        help="train dataset",
+    )
+    parser.add_argument(
+        "--valid_set",
+        type=str,
+        default="validation",
+        help="dev dataset",
+    )
+    parser.add_argument(
+        "--data_file_names",
+        type=str,
+        default="wav.scp,text",
+        help="input data files",
+    )
+    parser.add_argument(
+        "--speed_perturb",
+        type=float,
+        nargs="+",
         default=None,
-        help="valid_list for large dataset",
-    )
-    parser.add_argument(
-        "--train_data_path_and_name_and_type",
-        type=str2triple_str,
-        action="append",
-        default=[],
-        help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
-    )
-    parser.add_argument(
-        "--valid_data_path_and_name_and_type",
-        type=str2triple_str,
-        action="append",
-        default=[],
-    )
-    parser.add_argument(
-        "--train_shape_file",
-        type=str,
-        action="append",
-        default=[],
-    )
-    parser.add_argument(
-        "--valid_shape_file",
-        type=str,
-        action="append",
-        default=[],
+        help="speed perturb",
     )
     parser.add_argument(
         "--use_preprocessor",
@@ -475,6 +479,18 @@
         default=None,
         help="oss bucket.",
     )
+    parser.add_argument(
+        "--enable_lora",
+        type=str2bool,
+        default=False,
+        help="Apply lora for finetuning.",
+    )
+    parser.add_argument(
+        "--lora_bias",
+        type=str,
+        default="none",
+        help="lora bias.",
+    )
 
     return parser
 
@@ -514,6 +530,18 @@
     prepare_data(args, distributed_option)
 
     model = build_model(args)
+    model = model.to(
+        dtype=getattr(torch, args.train_dtype),
+        device="cuda" if args.ngpu > 0 else "cpu",
+    )
+    if args.enable_lora:
+        mark_only_lora_as_trainable(model, args.lora_bias)
+    for t in args.freeze_param:
+        for k, p in model.named_parameters():
+            if k.startswith(t + ".") or k == t:
+                logging.info(f"Setting {k}.requires_grad = False")
+                p.requires_grad = False
+
     optimizers = build_optimizer(args, model=model)
     schedulers = build_scheduler(args, optimizers)
 
@@ -521,6 +549,7 @@
                                                                    distributed_option.dist_rank,
                                                                    distributed_option.local_rank))
     logging.info(pytorch_cudnn_version())
+    logging.info("Args: {}".format(args))
     logging.info(model_summary(model))
     logging.info("Optimizer: {}".format(optimizers))
     logging.info("Scheduler: {}".format(schedulers))
@@ -537,6 +566,18 @@
             else:
                 yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
 
+    for p in args.init_param:
+        logging.info(f"Loading pretrained params from {p}")
+        load_pretrained_model(
+            model=model,
+            init_param=p,
+            ignore_init_mismatch=args.ignore_init_mismatch,
+            map_location=f"cuda:{torch.cuda.current_device()}"
+            if args.ngpu > 0
+            else "cpu",
+            oss_bucket=args.oss_bucket,
+        )
+
     # dataloader for training/validation
     train_dataloader, valid_dataloader = build_dataloader(args)
 

--
Gitblit v1.9.1