From 36c43d4c9f3ae98f026889b2f5f9726826a208d8 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 20 七月 2023 18:33:54 +0800
Subject: [PATCH] add lora finetune code

---
 funasr/modules/lora/layers.py |  248 +++++++++++++++++++++++++------------------------
 1 files changed, 126 insertions(+), 122 deletions(-)

diff --git a/funasr/modules/lora/layers.py b/funasr/modules/lora/layers.py
index 9115b28..76f046c 100644
--- a/funasr/modules/lora/layers.py
+++ b/funasr/modules/lora/layers.py
@@ -11,9 +11,9 @@
 
 class LoRALayer():
     def __init__(
-        self, 
-        r: int, 
-        lora_alpha: int, 
+        self,
+        r: int,
+        lora_alpha: int,
         lora_dropout: float,
         merge_weights: bool,
     ):
@@ -61,40 +61,42 @@
 
     def train(self, mode: bool = True):
         nn.Embedding.train(self, mode)
-        if mode:
-            if self.merge_weights and self.merged:
-                # Make sure that the weights are not merged
-                if self.r > 0:
-                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
-                self.merged = False
-        else:
-            if self.merge_weights and not self.merged:
-                # Merge the weights and mark it
-                if self.r > 0:
-                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
-                self.merged = True
-        
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            if self.r > 0:
+                self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
+            self.merged = False
+
+    def eval(self):
+        nn.Linear.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            if self.r > 0:
+                self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
+            self.merged = True
+
     def forward(self, x: torch.Tensor):
         if self.r > 0 and not self.merged:
             result = nn.Embedding.forward(self, x)
-            after_A = F.embedding(
-                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
-                self.norm_type, self.scale_grad_by_freq, self.sparse
-            )
-            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
+            if self.r > 0:
+                after_A = F.embedding(
+                    x, self.lora_A.T, self.padding_idx, self.max_norm,
+                    self.norm_type, self.scale_grad_by_freq, self.sparse
+                )
+                result += (after_A @ self.lora_B.T) * self.scaling
             return result
         else:
             return nn.Embedding.forward(self, x)
-            
+
 
 class Linear(nn.Linear, LoRALayer):
     # LoRA implemented in a dense layer
     def __init__(
-        self, 
-        in_features: int, 
-        out_features: int, 
-        r: int = 0, 
-        lora_alpha: int = 1, 
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
         lora_dropout: float = 0.,
         fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
         merge_weights: bool = True,
@@ -114,7 +116,7 @@
             self.weight.requires_grad = False
         self.reset_parameters()
         if fan_in_fan_out:
-            self.weight.data = self.weight.data.transpose(0, 1)
+            self.weight.data = self.weight.data.T
 
     def reset_parameters(self):
         nn.Linear.reset_parameters(self)
@@ -125,27 +127,31 @@
 
     def train(self, mode: bool = True):
         def T(w):
-            return w.transpose(0, 1) if self.fan_in_fan_out else w
+            return w.T if self.fan_in_fan_out else w
         nn.Linear.train(self, mode)
-        if mode:
-            if self.merge_weights and self.merged:
-                # Make sure that the weights are not merged
-                if self.r > 0:
-                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
-                self.merged = False
-        else:
-            if self.merge_weights and not self.merged:
-                # Merge the weights and mark it
-                if self.r > 0:
-                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
-                self.merged = True       
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            if self.r > 0:
+                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+            self.merged = False
+
+    def eval(self):
+        def T(w):
+            return w.T if self.fan_in_fan_out else w
+        nn.Linear.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            if self.r > 0:
+                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+            self.merged = True
 
     def forward(self, x: torch.Tensor):
         def T(w):
-            return w.transpose(0, 1) if self.fan_in_fan_out else w
+            return w.T if self.fan_in_fan_out else w
         if self.r > 0 and not self.merged:
-            result = F.linear(x, T(self.weight), bias=self.bias)            
-            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+            result = F.linear(x, T(self.weight), bias=self.bias)
+            if self.r > 0:
+                result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
             return result
         else:
             return F.linear(x, T(self.weight), bias=self.bias)
@@ -154,11 +160,11 @@
 class MergedLinear(nn.Linear, LoRALayer):
     # LoRA implemented in a dense layer
     def __init__(
-        self, 
-        in_features: int, 
-        out_features: int, 
-        r: int = 0, 
-        lora_alpha: int = 1, 
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
         lora_dropout: float = 0.,
         enable_lora: List[bool] = [False],
         fan_in_fan_out: bool = False,
@@ -190,7 +196,7 @@
             self.lora_ind = self.lora_ind.view(-1)
         self.reset_parameters()
         if fan_in_fan_out:
-            self.weight.data = self.weight.data.transpose(0, 1)
+            self.weight.data = self.weight.data.T
 
     def reset_parameters(self):
         nn.Linear.reset_parameters(self)
@@ -209,34 +215,37 @@
 
     def train(self, mode: bool = True):
         def T(w):
-            return w.transpose(0, 1) if self.fan_in_fan_out else w
+            return w.T if self.fan_in_fan_out else w
         nn.Linear.train(self, mode)
-        if mode:
-            if self.merge_weights and self.merged:
-                # Make sure that the weights are not merged
-                if self.r > 0 and any(self.enable_lora):
-                    delta_w = F.conv1d(
-                        self.lora_A.data.unsqueeze(0), 
-                        self.lora_B.data.unsqueeze(-1), 
-                        groups=sum(self.enable_lora)
-                    ).squeeze(0)
-                    self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
-                self.merged = False
-        else:
-            if self.merge_weights and not self.merged:
-                # Merge the weights and mark it
-                if self.r > 0 and any(self.enable_lora):
-                    delta_w = F.conv1d(
-                        self.lora_A.data.unsqueeze(0), 
-                        self.lora_B.data.unsqueeze(-1), 
-                        groups=sum(self.enable_lora)
-                    ).squeeze(0)
-                    self.weight.data += self.zero_pad(T(delta_w * self.scaling))
-                self.merged = True        
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            if self.r > 0 and any(self.enable_lora):
+                delta_w = F.conv1d(
+                    self.lora_A.data.unsqueeze(0),
+                    self.lora_B.data.unsqueeze(-1),
+                    groups=sum(self.enable_lora)
+                ).squeeze(0)
+                self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
+            self.merged = False
+
+    def eval(self):
+        def T(w):
+            return w.T if self.fan_in_fan_out else w
+        nn.Linear.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            if self.r > 0 and any(self.enable_lora):
+                delta_w = F.conv1d(
+                    self.lora_A.data.unsqueeze(0),
+                    self.lora_B.data.unsqueeze(-1),
+                    groups=sum(self.enable_lora)
+                ).squeeze(0)
+                self.weight.data += self.zero_pad(T(delta_w * self.scaling))
+            self.merged = True
 
     def forward(self, x: torch.Tensor):
         def T(w):
-            return w.transpose(0, 1) if self.fan_in_fan_out else w
+            return w.T if self.fan_in_fan_out else w
         if self.merged:
             return F.linear(x, T(self.weight), bias=self.bias)
         else:
@@ -244,76 +253,71 @@
             if self.r > 0:
                 after_A = F.linear(self.lora_dropout(x), self.lora_A)
                 after_B = F.conv1d(
-                    after_A.transpose(-2, -1), 
-                    self.lora_B.unsqueeze(-1), 
+                    after_A.transpose(-2, -1),
+                    self.lora_B.unsqueeze(-1),
                     groups=sum(self.enable_lora)
                 ).transpose(-2, -1)
                 result += self.zero_pad(after_B) * self.scaling
             return result
-        
 
-class ConvLoRA(nn.Module, LoRALayer):
-    def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
-        super(ConvLoRA, self).__init__()
-        self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
-        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
-        assert isinstance(kernel_size, int)
+
+class Conv2d(nn.Conv2d, LoRALayer):
+    # LoRA implemented in a dense layer
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.,
+        merge_weights: bool = True,
+        **kwargs
+    ):
+        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
+        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                           merge_weights=merge_weights)
+        assert type(kernel_size) is int
         # Actual trainable parameters
         if r > 0:
             self.lora_A = nn.Parameter(
-                self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
+                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
             )
             self.lora_B = nn.Parameter(
-              self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size))
+                self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
             )
             self.scaling = self.lora_alpha / self.r
             # Freezing the pre-trained weight matrix
-            self.conv.weight.requires_grad = False
+            self.weight.requires_grad = False
         self.reset_parameters()
-        self.merged = False
 
     def reset_parameters(self):
-        self.conv.reset_parameters()
+        nn.Conv2d.reset_parameters(self)
         if hasattr(self, 'lora_A'):
             # initialize A the same way as the default for nn.Linear and B to zero
             nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
             nn.init.zeros_(self.lora_B)
 
-    def train(self, mode=True):
-        super(ConvLoRA, self).train(mode)
-        if mode:
-            if self.merge_weights and self.merged:
-                if self.r > 0:
-                    # Make sure that the weights are not merged
-                    self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
-                self.merged = False
-        else:
-            if self.merge_weights and not self.merged:
-                if self.r > 0:
-                    # Merge the weights and mark it
-                    self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
-                self.merged = True
+    def train(self, mode: bool = True):
+        nn.Conv2d.train(self, mode)
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+            self.merged = False
 
-    def forward(self, x):
+    def eval(self):
+        nn.Conv2d.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+            self.merged = True
+
+    def forward(self, x: torch.Tensor):
         if self.r > 0 and not self.merged:
-            return self.conv._conv_forward(
-                x, 
-                self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
-                self.conv.bias
+            return F.conv2d(
+                x,
+                self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
+                self.bias, self.stride, self.padding, self.dilation, self.groups
             )
-        return self.conv(x)
-
-class Conv2d(ConvLoRA):
-    def __init__(self, *args, **kwargs):
-        super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)
-
-class Conv1d(ConvLoRA):
-    def __init__(self, *args, **kwargs):
-        super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)
-
-# Can Extend to other ones like this
-
-class Conv3d(ConvLoRA):
-    def __init__(self, *args, **kwargs):
-        super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)
+        return nn.Conv2d.forward(self, x)
 

--
Gitblit v1.9.1