From 178fa2cd1928d835d06e4a6fc32c9457e484e203 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期三, 03 一月 2024 14:19:18 +0800
Subject: [PATCH] update SDK_advanced_guide_online
---
funasr/bin/train.py | 19 +++++++++++++++++--
1 files changed, 17 insertions(+), 2 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index f4fc0a7..6aebf8a 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -17,7 +17,7 @@
from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@@ -28,6 +28,7 @@
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
def get_parser():
@@ -301,7 +302,7 @@
"--freeze_param",
type=str,
default=[],
- nargs="*",
+ action="append",
help="Freeze parameters",
)
@@ -478,6 +479,18 @@
default=None,
help="oss bucket.",
)
+ parser.add_argument(
+ "--enable_lora",
+ type=str2bool,
+ default=False,
+ help="Apply lora for finetuning.",
+ )
+ parser.add_argument(
+ "--lora_bias",
+ type=str,
+ default="none",
+ help="lora bias.",
+ )
return parser
@@ -521,6 +534,8 @@
dtype=getattr(torch, args.train_dtype),
device="cuda" if args.ngpu > 0 else "cpu",
)
+ if args.enable_lora:
+ mark_only_lora_as_trainable(model, args.lora_bias)
for t in args.freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
--
Gitblit v1.9.1