From 36c43d4c9f3ae98f026889b2f5f9726826a208d8 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: Thu, 20 Jul 2023 18:33:54 +0800
Subject: [PATCH] Add LoRA fine-tuning code
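
Make the LoRA adapters in funasr/modules/attention.py optional per
projection: when a projection ("q", "k", "v", or "o") is not named in
lora_list, the layer now falls back to a plain nn.Linear instead of
always constructing a loralib MergedLinear, so disabled projections
carry no adapter parameters.

A minimal sketch of the same selection logic for the fused q/k/v
projection (assuming the loralib package; build_qkv and its arguments
are illustrative names, not the FunASR API):

    import torch.nn as nn
    import loralib as lora

    def build_qkv(in_feat, n_feat, lora_list, lora_rank=8,
                  lora_alpha=16, lora_dropout=0.0):
        # One flag per slice of the fused projection, in q/k/v order.
        enable = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
        if not any(enable):
            # No adapter requested: plain fused projection, no LoRA weights.
            # (Equivalent to the patch's comparison against an all-False list.)
            return nn.Linear(in_feat, n_feat * 3)
        # MergedLinear attaches rank-r adapters only to the enabled slices.
        return lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank,
                                 lora_alpha=lora_alpha,
                                 lora_dropout=lora_dropout,
                                 enable_lora=enable)

After the model is built, loralib's mark_only_lora_as_trainable(model)
freezes every non-LoRA weight so that fine-tuning updates only the
adapters.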
---
funasr/modules/attention.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index f01e340..ab59493 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -338,7 +338,10 @@
             else:
                 self.linear_out = nn.Linear(n_feat, n_feat)
             lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
-            self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+            if lora_qkv_list == [False, False, False]:
+                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+            else:
+                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
         else:
             self.linear_out = nn.Linear(n_feat, n_feat)
             self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
@@ -562,11 +565,18 @@
         if lora_list is not None:
             if "q" in lora_list:
                 self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_q = nn.Linear(n_feat, n_feat)
             lora_kv_list = ["k" in lora_list, "v" in lora_list]
-            self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
-                                                r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+            if lora_kv_list == [False, False]:
+                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2)
+            else:
+                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
+                                                    r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
             if "o" in lora_list:
                 self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
         else:
             self.linear_q = nn.Linear(n_feat, n_feat)
             self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
--
Gitblit v1.9.1