import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRALayer():
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout applied to the input of the LoRA branch
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights

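
class Embedding(nn.Embedding, LoRALayer):
    # LoRA implemented in an embedding layer.
    # (Constructor sketch: the shapes of lora_A / lora_B are inferred from how they are
    # used in the train() and forward() methods below.)
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0.,
                           merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # One factor starts at zero so the initial LoRA update is zero
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)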
    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            after_A = F.embedding(
                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)

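# A minimal usage sketch (illustrative assumptions, not part of the layer definitions):
# the class above is a drop-in replacement for nn.Embedding.
#
#   emb = Embedding(num_embeddings=50257, embedding_dim=768, r=4, lora_alpha=8)
#   # Only emb.lora_A / emb.lora_B require gradients; emb.weight stays frozen.
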
class Linear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)

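# A minimal usage sketch (illustrative assumptions, not part of the layer definitions):
# replace an nn.Linear with the LoRA Linear above and train only the low-rank factors.
#
#   layer = Linear(in_features=768, out_features=768, r=8, lora_alpha=16, lora_dropout=0.05)
#   for name, param in layer.named_parameters():
#       param.requires_grad = 'lora_' in name   # lora_A and lora_B only
#   layer.eval()   # nn.Module.eval() calls train(False), which merges B @ A into the frozen weight
#   layer.train()  # switching back to training mode un-merges it again
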
class MergedLinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer, applied to only a subset of the packed outputs
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0 and any(enable_lora):
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            )  # weights for a grouped Conv1D with groups=sum(enable_lora)
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the indices of the output features that receive a LoRA update
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

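    # The zero_pad helper used below is not shown in this excerpt; the sketch here is an
    # assumption reconstructed from its call sites: it scatters a LoRA update that covers
    # only the enabled groups back into the full output dimension (last axis), with zeros
    # in the positions whose group has LoRA disabled.
    def zero_pad(self, x):
        result = x.new_zeros((*x.shape[:-1], self.out_features))
        result = result.view(-1, self.out_features)
        result[:, self.lora_ind] = x.reshape(
            -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
        )
        return result.view((*x.shape[:-1], self.out_features))
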
    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0 and any(self.enable_lora):
                    delta_w = F.conv1d(
                        self.lora_A.data.unsqueeze(0),
                        self.lora_B.data.unsqueeze(-1),
                        groups=sum(self.enable_lora)
                    ).squeeze(0)
                    self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0 and any(self.enable_lora):
                    delta_w = F.conv1d(
                        self.lora_A.data.unsqueeze(0),
                        self.lora_B.data.unsqueeze(-1),
                        groups=sum(self.enable_lora)
                    ).squeeze(0)
                    self.weight.data += self.zero_pad(T(delta_w * self.scaling))
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.merged:
            return F.linear(x, T(self.weight), bias=self.bias)
        else:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                after_A = F.linear(self.lora_dropout(x), self.lora_A)
                after_B = F.conv1d(
                    after_A.transpose(-2, -1),
                    self.lora_B.unsqueeze(-1),
                    groups=sum(self.enable_lora)
                ).transpose(-2, -1)
                result += self.zero_pad(after_B) * self.scaling
            return result

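# A minimal usage sketch (illustrative assumptions): MergedLinear targets packed projections
# such as a fused QKV matrix, adapting only some of the packed outputs. Here LoRA is applied
# to the query and value projections but not to the key projection.
#
#   qkv_proj = MergedLinear(
#       in_features=768, out_features=3 * 768,
#       r=8, lora_alpha=16, enable_lora=[True, False, True]
#   )
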
class ConvLoRA(nn.Module, LoRALayer):
    # LoRA implemented on top of a wrapped convolution module (nn.Conv1d / nn.Conv2d / nn.Conv3d)
    def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
        super(ConvLoRA, self).__init__()
        self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        assert isinstance(kernel_size, int)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(
                self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.conv.weight.new_zeros((out_channels // self.conv.groups * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.conv.weight.requires_grad = False
        self.reset_parameters()
        self.merged = False

    def reset_parameters(self):
        self.conv.reset_parameters()
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode=True):
        super(ConvLoRA, self).train(mode)
        if mode:
            if self.merge_weights and self.merged:
                if self.r > 0:
                    # Make sure that the weights are not merged
                    self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                if self.r > 0:
                    # Merge the weights and mark it
                    self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
                self.merged = True

    def forward(self, x):
        if self.r > 0 and not self.merged:
            # Apply the convolution with the (unmerged) LoRA update added on the fly
            return self.conv._conv_forward(
                x,
                self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
                self.conv.bias
            )
        return self.conv(x)

class Conv2d(ConvLoRA):
    def __init__(self, *args, **kwargs):
        super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)

class Conv1d(ConvLoRA):
    def __init__(self, *args, **kwargs):
        super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)

# Can extend to other convolution types in the same way
class Conv3d(ConvLoRA):
    def __init__(self, *args, **kwargs):
        super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)
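

# A small end-to-end sketch (illustrative assumptions, not part of the layer definitions):
# wrap a Conv2d with LoRA, freeze everything except the low-rank factors, and check that
# merging the weights on eval() leaves the output unchanged.
if __name__ == "__main__":
    conv = Conv2d(3, 16, 3, r=4, lora_alpha=8, padding=1)
    for name, param in conv.named_parameters():
        param.requires_grad = 'lora_' in name  # train only lora_A / lora_B
    x = torch.randn(2, 3, 32, 32)
    conv.train()
    y_train = conv(x)
    conv.eval()   # eval() calls train(False), which merges B @ A into conv.weight
    y_eval = conv(x)
    print(torch.allclose(y_train, y_eval, atol=1e-6))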