From 16c41542451f399bdb716f1d7cad31cf52f6f8c3 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: Tue, 18 Jul 2023 16:47:27 +0800
Subject: [PATCH] add lora finetune code
---
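mark_only_lora_as_trainable freezes every parameter whose name does not
contain "lora_", so only the injected low-rank adapter weights (and,
depending on --lora_bias, bias terms) receive gradients. A minimal sketch
of the helper, assuming funasr/modules/lora/utils.py follows the loralib
reference implementation (the LoRALayer import path below is an
assumption, not verbatim FunASR code):

    import torch.nn as nn
    # Assumed import path; FunASR keeps its LoRA layers under funasr/modules/lora/.
    from funasr.modules.lora.layers import LoRALayer

    def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
        # Freeze everything that is not a LoRA parameter.
        for name, param in model.named_parameters():
            if "lora_" not in name:
                param.requires_grad = False
        if bias == "none":
            return
        if bias == "all":
            # Additionally train every bias term in the model.
            for name, param in model.named_parameters():
                if "bias" in name:
                    param.requires_grad = True
        elif bias == "lora_only":
            # Train biases only on modules that carry LoRA weights.
            for module in model.modules():
                if (isinstance(module, LoRALayer)
                        and getattr(module, "bias", None) is not None):
                    module.bias.requires_grad = True
        else:
            raise NotImplementedError(bias)

With the flags added below, a LoRA fine-tuning run would be launched
roughly as follows (other training arguments elided):

    python -m funasr.bin.train ... --enable_lora true --lora_bias lora_only
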
funasr/bin/train.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+), 0 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 1dc3fb5..c9c0b02 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -28,6 +28,7 @@
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
def get_parser():
@@ -478,6 +479,18 @@
default=None,
help="oss bucket.",
)
+ parser.add_argument(
+ "--enable_lora",
+ type=str2bool,
+ default=False,
+ help="Apply lora for finetuning.",
+ )
+ parser.add_argument(
+ "--lora_bias",
+ type=str,
+ default="none",
+ help="oss bucket.",
+ )
return parser
@@ -521,6 +534,8 @@
dtype=getattr(torch, args.train_dtype),
device="cuda" if args.ngpu > 0 else "cpu",
)
+ if args.enable_lora:
+ mark_only_lora_as_trainable(model, args.lora_bias)
for t in args.freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
--
Gitblit v1.9.1