| | |
| | | import math |
| | | from typing import Optional, List |
| | | |
| | | class LoRALayer(): |
| | | |
| | | class LoRALayer: |
| | | def __init__( |
| | | self, |
| | | r: int, |
| | |
| | | self.r = r |
| | | self.lora_alpha = lora_alpha |
| | | # Optional dropout |
| | | if lora_dropout > 0.: |
| | | if lora_dropout > 0.0: |
| | | self.lora_dropout = nn.Dropout(p=lora_dropout) |
| | | else: |
| | | self.lora_dropout = lambda x: x |
| | |
| | | **kwargs |
| | | ): |
| | | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) |
| | | LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0, |
| | | merge_weights=merge_weights) |
| | | LoRALayer.__init__( |
| | | self, r=r, lora_alpha=lora_alpha, lora_dropout=0, merge_weights=merge_weights |
| | | ) |
| | | # Actual trainable parameters |
| | | if r > 0: |
| | | self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings))) |
| | |
| | | |
| | | def reset_parameters(self): |
| | | nn.Embedding.reset_parameters(self) |
| | | if hasattr(self, 'lora_A'): |
| | | if hasattr(self, "lora_A"): |
| | | # initialize A the same way as the default for nn.Linear and B to zero |
| | | nn.init.zeros_(self.lora_A) |
| | | nn.init.normal_(self.lora_B) |
| | |
| | | result = nn.Embedding.forward(self, x) |
| | | if self.r > 0: |
| | | after_A = F.embedding( |
| | | x, self.lora_A.T, self.padding_idx, self.max_norm, |
| | | self.norm_type, self.scale_grad_by_freq, self.sparse |
| | | x, |
| | | self.lora_A.T, |
| | | self.padding_idx, |
| | | self.max_norm, |
| | | self.norm_type, |
| | | self.scale_grad_by_freq, |
| | | self.sparse, |
| | | ) |
| | | result += (after_A @ self.lora_B.T) * self.scaling |
| | | return result |
| | |
| | | out_features: int, |
| | | r: int = 0, |
| | | lora_alpha: int = 1, |
| | | lora_dropout: float = 0., |
| | | fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) |
| | | lora_dropout: float = 0.0, |
| | | fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) |
| | | merge_weights: bool = True, |
| | | **kwargs |
| | | ): |
| | | nn.Linear.__init__(self, in_features, out_features, **kwargs) |
| | | LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, |
| | | merge_weights=merge_weights) |
| | | LoRALayer.__init__( |
| | | self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights |
| | | ) |
| | | |
| | | self.fan_in_fan_out = fan_in_fan_out |
| | | # Actual trainable parameters |
| | |
| | | |
| | | def reset_parameters(self): |
| | | nn.Linear.reset_parameters(self) |
| | | if hasattr(self, 'lora_A'): |
| | | if hasattr(self, "lora_A"): |
| | | # initialize A the same way as the default for nn.Linear and B to zero |
| | | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) |
| | | nn.init.zeros_(self.lora_B) |
| | |
| | | def train(self, mode: bool = True): |
| | | def T(w): |
| | | return w.T if self.fan_in_fan_out else w |
| | | |
| | | nn.Linear.train(self, mode) |
| | | if self.merge_weights and self.merged: |
| | | # Make sure that the weights are not merged |
| | |
| | | def eval(self): |
| | | def T(w): |
| | | return w.T if self.fan_in_fan_out else w |
| | | |
| | | nn.Linear.eval(self) |
| | | if self.merge_weights and not self.merged: |
| | | # Merge the weights and mark it |
| | |
| | | def forward(self, x: torch.Tensor): |
| | | def T(w): |
| | | return w.T if self.fan_in_fan_out else w |
| | | |
| | | if self.r > 0 and not self.merged: |
| | | result = F.linear(x, T(self.weight), bias=self.bias) |
| | | if self.r > 0: |
| | |
| | | out_features: int, |
| | | r: int = 0, |
| | | lora_alpha: int = 1, |
| | | lora_dropout: float = 0., |
| | | lora_dropout: float = 0.0, |
| | | enable_lora: List[bool] = [False], |
| | | fan_in_fan_out: bool = False, |
| | | merge_weights: bool = True, |
| | | **kwargs |
| | | ): |
| | | nn.Linear.__init__(self, in_features, out_features, **kwargs) |
| | | LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, |
| | | merge_weights=merge_weights) |
| | | assert out_features % len(enable_lora) == 0, \ |
| | | 'The length of enable_lora must divide out_features' |
| | | LoRALayer.__init__( |
| | | self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights |
| | | ) |
| | | assert ( |
| | | out_features % len(enable_lora) == 0 |
| | | ), "The length of enable_lora must divide out_features" |
| | | self.enable_lora = enable_lora |
| | | self.fan_in_fan_out = fan_in_fan_out |
| | | # Actual trainable parameters |
| | | if r > 0 and any(enable_lora): |
| | | self.lora_A = nn.Parameter( |
| | | self.weight.new_zeros((r * sum(enable_lora), in_features))) |
| | | self.lora_A = nn.Parameter(self.weight.new_zeros((r * sum(enable_lora), in_features))) |
| | | self.lora_B = nn.Parameter( |
| | | self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r)) |
| | | ) # weights for Conv1D with groups=sum(enable_lora) |
| | | ) # weights for Conv1D with groups=sum(enable_lora) |
| | | self.scaling = self.lora_alpha / self.r |
| | | # Freezing the pre-trained weight matrix |
| | | self.weight.requires_grad = False |
| | | # Compute the indices |
| | | self.lora_ind = self.weight.new_zeros( |
| | | (out_features, ), dtype=torch.bool |
| | | ).view(len(enable_lora), -1) |
| | | self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view( |
| | | len(enable_lora), -1 |
| | | ) |
| | | self.lora_ind[enable_lora, :] = True |
| | | self.lora_ind = self.lora_ind.view(-1) |
| | | self.reset_parameters() |
| | |
| | | |
| | | def reset_parameters(self): |
| | | nn.Linear.reset_parameters(self) |
| | | if hasattr(self, 'lora_A'): |
| | | if hasattr(self, "lora_A"): |
| | | # initialize A the same way as the default for nn.Linear and B to zero |
| | | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) |
| | | nn.init.zeros_(self.lora_B) |
| | |
| | | def train(self, mode: bool = True): |
| | | def T(w): |
| | | return w.T if self.fan_in_fan_out else w |
| | | |
| | | nn.Linear.train(self, mode) |
| | | if self.merge_weights and self.merged: |
| | | # Make sure that the weights are not merged |
| | |
| | | delta_w = F.conv1d( |
| | | self.lora_A.data.unsqueeze(0), |
| | | self.lora_B.data.unsqueeze(-1), |
| | | groups=sum(self.enable_lora) |
| | | groups=sum(self.enable_lora), |
| | | ).squeeze(0) |
| | | self.weight.data -= self.zero_pad(T(delta_w * self.scaling)) |
| | | self.merged = False |
| | |
| | | def eval(self): |
| | | def T(w): |
| | | return w.T if self.fan_in_fan_out else w |
| | | |
| | | nn.Linear.eval(self) |
| | | if self.merge_weights and not self.merged: |
| | | # Merge the weights and mark it |
| | |
| | | delta_w = F.conv1d( |
| | | self.lora_A.data.unsqueeze(0), |
| | | self.lora_B.data.unsqueeze(-1), |
| | | groups=sum(self.enable_lora) |
| | | groups=sum(self.enable_lora), |
| | | ).squeeze(0) |
| | | self.weight.data += self.zero_pad(T(delta_w * self.scaling)) |
| | | self.merged = True |
| | |
| | | def forward(self, x: torch.Tensor): |
| | | def T(w): |
| | | return w.T if self.fan_in_fan_out else w |
| | | |
| | | if self.merged: |
| | | return F.linear(x, T(self.weight), bias=self.bias) |
| | | else: |
| | |
| | | after_B = F.conv1d( |
| | | after_A.transpose(-2, -1), |
| | | self.lora_B.unsqueeze(-1), |
| | | groups=sum(self.enable_lora) |
| | | groups=sum(self.enable_lora), |
| | | ).transpose(-2, -1) |
| | | result += self.zero_pad(after_B) * self.scaling |
| | | return result |
| | |
| | | kernel_size: int, |
| | | r: int = 0, |
| | | lora_alpha: int = 1, |
| | | lora_dropout: float = 0., |
| | | lora_dropout: float = 0.0, |
| | | merge_weights: bool = True, |
| | | **kwargs |
| | | ): |
| | | nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) |
| | | LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, |
| | | merge_weights=merge_weights) |
| | | LoRALayer.__init__( |
| | | self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights |
| | | ) |
| | | assert type(kernel_size) is int |
| | | # Actual trainable parameters |
| | | if r > 0: |
| | | self.lora_A = nn.Parameter( |
| | | self.weight.new_zeros((r*kernel_size, in_channels*kernel_size)) |
| | | self.weight.new_zeros((r * kernel_size, in_channels * kernel_size)) |
| | | ) |
| | | self.lora_B = nn.Parameter( |
| | | self.weight.new_zeros((out_channels*kernel_size, r*kernel_size)) |
| | | self.weight.new_zeros((out_channels * kernel_size, r * kernel_size)) |
| | | ) |
| | | self.scaling = self.lora_alpha / self.r |
| | | # Freezing the pre-trained weight matrix |
| | |
| | | |
| | | def reset_parameters(self): |
| | | nn.Conv2d.reset_parameters(self) |
| | | if hasattr(self, 'lora_A'): |
| | | if hasattr(self, "lora_A"): |
| | | # initialize A the same way as the default for nn.Linear and B to zero |
| | | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) |
| | | nn.init.zeros_(self.lora_B) |
| | |
| | | return F.conv2d( |
| | | x, |
| | | self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling, |
| | | self.bias, self.stride, self.padding, self.dilation, self.groups |
| | | self.bias, |
| | | self.stride, |
| | | self.padding, |
| | | self.dilation, |
| | | self.groups, |
| | | ) |
| | | return nn.Conv2d.forward(self, x) |
| | | |