From 0063e4356a8e8e7bb984a6d847afbb11d5fbaa4a Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期五, 21 七月 2023 15:07:34 +0800
Subject: [PATCH] Merge pull request #768 from alibaba-damo-academy/dev_lhn

---
 funasr/modules/attention.py |   44 +++++++++++++++++++++++++++++++++++---------
 1 files changed, 35 insertions(+), 9 deletions(-)

diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index fcb3ed4..ab59493 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -15,6 +15,7 @@
 
 import torch.nn.functional as F
 from funasr.modules.nets_utils import make_pad_mask
+import funasr.modules.lora.layers as lora
 
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
@@ -321,7 +322,7 @@
 
     """
 
-    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttentionSANM, self).__init__()
         assert n_feat % n_head == 0
@@ -331,8 +332,19 @@
         # self.linear_q = nn.Linear(n_feat, n_feat)
         # self.linear_k = nn.Linear(n_feat, n_feat)
         # self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        if lora_list is not None:
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+            if lora_qkv_list == [False, False, False]:
+                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+            else:
+                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+        else:
+            self.linear_out = nn.Linear(n_feat, n_feat)
+            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
         self.attn = None
         self.dropout = nn.Dropout(p=dropout_rate)
 
@@ -543,18 +555,32 @@
 
     """
 
-    def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
+    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttentionCrossAtt, self).__init__()
         assert n_feat % n_head == 0
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
         self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
-        # self.linear_k = nn.Linear(n_feat, n_feat)
-        # self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
-        self.linear_out = nn.Linear(n_feat, n_feat)
+        if lora_list is not None:
+            if "q" in lora_list:
+                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_q = nn.Linear(n_feat, n_feat)
+            lora_kv_list = ["k" in lora_list, "v" in lora_list]
+            if lora_kv_list == [False, False]:
+                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            else:
+                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
+                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+        else:
+            self.linear_q = nn.Linear(n_feat, n_feat)
+            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            self.linear_out = nn.Linear(n_feat, n_feat)
         self.attn = None
         self.dropout = nn.Dropout(p=dropout_rate)
 

--
Gitblit v1.9.1