From 7ac54b3c97491ee9ac8a8ebbb7033240864f805a Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期二, 18 七月 2023 19:26:01 +0800
Subject: [PATCH] add lora finetune code
---
funasr/modules/attention.py | 34 +++++++++++++++++++++++++---------
funasr/models/encoder/sanm_encoder.py | 12 ++++++++++++
funasr/models/decoder/sanm_decoder.py | 6 +++++-
3 files changed, 42 insertions(+), 10 deletions(-)
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index d83f89f..c12e098 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -833,6 +833,10 @@
att_layer_num: int = 6,
kernel_size: int = 21,
sanm_shfit: int = 0,
+ lora_list: List[str] = None,
+ lora_rank: int = 8,
+ lora_alpha: int = 16,
+ lora_dropout: float = 0.1,
tf2torch_tensor_name_prefix_torch: str = "decoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
):
@@ -885,7 +889,7 @@
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate
+ attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 45163df..9e27d4a 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -146,6 +146,10 @@
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
+ lora_list: List[str] = None,
+ lora_rank: int = 8,
+ lora_alpha: int = 16,
+ lora_dropout: float = 0.1,
selfattention_layer_type: str = "sanm",
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
@@ -229,6 +233,10 @@
attention_dropout_rate,
kernel_size,
sanm_shfit,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
)
encoder_selfattn_layer_args = (
@@ -238,6 +246,10 @@
attention_dropout_rate,
kernel_size,
sanm_shfit,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
)
self.encoders0 = repeat(
1,
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index fcb3ed4..f01e340 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -15,6 +15,7 @@
import torch.nn.functional as F
from funasr.modules.nets_utils import make_pad_mask
+import funasr.modules.lora.layers as lora
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -321,7 +322,7 @@
"""
- def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+ def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttentionSANM, self).__init__()
assert n_feat % n_head == 0
@@ -331,8 +332,16 @@
# self.linear_q = nn.Linear(n_feat, n_feat)
# self.linear_k = nn.Linear(n_feat, n_feat)
# self.linear_v = nn.Linear(n_feat, n_feat)
- self.linear_out = nn.Linear(n_feat, n_feat)
- self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ if lora_list is not None:
+ if "o" in lora_list:
+ self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+ self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
@@ -543,18 +552,25 @@
"""
- def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
+ def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttentionCrossAtt, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
- self.linear_q = nn.Linear(n_feat, n_feat)
- # self.linear_k = nn.Linear(n_feat, n_feat)
- # self.linear_v = nn.Linear(n_feat, n_feat)
- self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
- self.linear_out = nn.Linear(n_feat, n_feat)
+ if lora_list is not None:
+ if "q" in lora_list:
+ self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ lora_kv_list = ["k" in lora_list, "v" in lora_list]
+ self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
+ r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+ if "o" in lora_list:
+ self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+ self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
--
Gitblit v1.9.1