From 16c41542451f399bdb716f1d7cad31cf52f6f8c3 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: Tue, 18 Jul 2023 16:47:27 +0800
Subject: [PATCH] add lora finetune code
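
LoRA keeps the pre-trained weights frozen and learns a low-rank update
W' = W + (lora_alpha / r) * B @ A, so only a small set of adapter parameters
needs to be trained and stored. This patch ports the LoRA layers from the
MIT-licensed Microsoft reference implementation (see the headers in
funasr/modules/lora/) and wires them into FunASR fine-tuning:

* funasr/modules/lora/layers.py: LoRA variants of nn.Embedding and nn.Linear,
  a MergedLinear for fused projections, and convolution wrappers.
* funasr/modules/lora/utils.py: mark_only_lora_as_trainable() and
  lora_state_dict() helpers for freezing and checkpointing.
* funasr/bin/train.py: new --enable_lora and --lora_bias options.
* funasr/bin/build_trainer.py: forward extra keyword arguments onto the
  finetune args namespace.

Usage sketch (the two new flags are appended to whatever training command is
already in use; everything else is unchanged):

    <existing train command> --enable_lora true --lora_bias none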
---
funasr/bin/train.py | 15 ++
funasr/modules/lora/layers.py | 319 +++++++++++++++++++++++++++++++++++++++++++++
funasr/modules/lora/__init__.py | 0
funasr/modules/lora/utils.py | 50 +++++++
funasr/bin/build_trainer.py | 2
5 files changed, 386 insertions(+), 0 deletions(-)
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index 891139a..bd30a83 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -144,6 +144,8 @@
args.patience = None
args.local_rank = local_rank
args.distributed = distributed
+    # Propagate any extra finetune keyword arguments onto the args namespace.
+    for key, value in kwargs.items():
+        setattr(args, key, value)
ASRTask.finetune_args = args
return ASRTask
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 1dc3fb5..c9c0b02 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -28,6 +28,7 @@
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
def get_parser():
@@ -478,6 +479,18 @@
default=None,
help="oss bucket.",
)
+    parser.add_argument(
+        "--enable_lora",
+        type=str2bool,
+        default=False,
+        help="Apply LoRA adapters during fine-tuning.",
+    )
+    parser.add_argument(
+        "--lora_bias",
+        type=str,
+        default="none",
+        help="Bias training mode for LoRA: 'none', 'all', or 'lora_only'.",
+    )
return parser
@@ -521,6 +534,8 @@
dtype=getattr(torch, args.train_dtype),
device="cuda" if args.ngpu > 0 else "cpu",
)
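+    # When LoRA is enabled, freeze the backbone and leave only the LoRA
+    # parameters (plus biases, depending on --lora_bias) trainable.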
+ if args.enable_lora:
+ mark_only_lora_as_trainable(model, args.lora_bias)
for t in args.freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
diff --git a/funasr/modules/lora/__init__.py b/funasr/modules/lora/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/modules/lora/__init__.py
diff --git a/funasr/modules/lora/layers.py b/funasr/modules/lora/layers.py
new file mode 100644
index 0000000..9115b28
--- /dev/null
+++ b/funasr/modules/lora/layers.py
@@ -0,0 +1,319 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import math
+from typing import Optional, List
+
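+# LoRA reparameterizes a frozen weight W as
+#     W' = W + (lora_alpha / r) * B @ A
+# where A (r x fan_in) and B (fan_out x r) are the only trainable parameters.
+# With merge_weights=True the low-rank update is folded into W in eval mode
+# and removed again when switching back to training, so inference incurs no
+# extra cost.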
+class LoRALayer():
+ def __init__(
+ self,
+ r: int,
+ lora_alpha: int,
+ lora_dropout: float,
+ merge_weights: bool,
+ ):
+ self.r = r
+ self.lora_alpha = lora_alpha
+ # Optional dropout
+ if lora_dropout > 0.:
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
+ else:
+ self.lora_dropout = lambda x: x
+ # Mark the weight as unmerged
+ self.merged = False
+ self.merge_weights = merge_weights
+
+
+class Embedding(nn.Embedding, LoRALayer):
+    # LoRA applied to an embedding layer
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
+ merge_weights=merge_weights)
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
+ self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.Embedding.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+            # Initialize A to zero and B from a normal distribution, so the
+            # initial LoRA update B @ A is zero (note: this is the reverse of
+            # the Linear initialization below).
+            nn.init.zeros_(self.lora_A)
+            nn.init.normal_(self.lora_B)
+
+ def train(self, mode: bool = True):
+ nn.Embedding.train(self, mode)
+ if mode:
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0:
+ self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
+ self.merged = False
+ else:
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ if self.r > 0 and not self.merged:
+ result = nn.Embedding.forward(self, x)
+ after_A = F.embedding(
+ x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse
+ )
+ result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
+ return result
+ else:
+ return nn.Embedding.forward(self, x)
+
+
+class Linear(nn.Linear, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+
+ self.fan_in_fan_out = fan_in_fan_out
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.transpose(0, 1)
+
+ def reset_parameters(self):
+ nn.Linear.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ def train(self, mode: bool = True):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+ nn.Linear.train(self, mode)
+ if mode:
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0:
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = False
+ else:
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+ if self.r > 0 and not self.merged:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+ return result
+ else:
+ return F.linear(x, T(self.weight), bias=self.bias)
+
+
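+# MergedLinear handles a fused projection (e.g. one matrix producing Q, K and
+# V): enable_lora selects which equally sized output slices get a LoRA
+# adapter, a grouped 1-D convolution applies one low-rank update per enabled
+# slice, and zero_pad() scatters the result back into the full output width.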
+class MergedLinear(nn.Linear, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ enable_lora: List[bool] = [False],
+ fan_in_fan_out: bool = False,
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+ assert out_features % len(enable_lora) == 0, \
+ 'The length of enable_lora must divide out_features'
+ self.enable_lora = enable_lora
+ self.fan_in_fan_out = fan_in_fan_out
+ # Actual trainable parameters
+ if r > 0 and any(enable_lora):
+ self.lora_A = nn.Parameter(
+ self.weight.new_zeros((r * sum(enable_lora), in_features)))
+ self.lora_B = nn.Parameter(
+ self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
+ ) # weights for Conv1D with groups=sum(enable_lora)
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ # Compute the indices
+ self.lora_ind = self.weight.new_zeros(
+ (out_features, ), dtype=torch.bool
+ ).view(len(enable_lora), -1)
+ self.lora_ind[enable_lora, :] = True
+ self.lora_ind = self.lora_ind.view(-1)
+ self.reset_parameters()
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.transpose(0, 1)
+
+ def reset_parameters(self):
+ nn.Linear.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ def zero_pad(self, x):
+ result = x.new_zeros((*x.shape[:-1], self.out_features))
+ result = result.view(-1, self.out_features)
+ result[:, self.lora_ind] = x.reshape(
+ -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
+ )
+ return result.view((*x.shape[:-1], self.out_features))
+
+ def train(self, mode: bool = True):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+ nn.Linear.train(self, mode)
+ if mode:
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0 and any(self.enable_lora):
+ delta_w = F.conv1d(
+ self.lora_A.data.unsqueeze(0),
+ self.lora_B.data.unsqueeze(-1),
+ groups=sum(self.enable_lora)
+ ).squeeze(0)
+ self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
+ self.merged = False
+ else:
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0 and any(self.enable_lora):
+ delta_w = F.conv1d(
+ self.lora_A.data.unsqueeze(0),
+ self.lora_B.data.unsqueeze(-1),
+ groups=sum(self.enable_lora)
+ ).squeeze(0)
+ self.weight.data += self.zero_pad(T(delta_w * self.scaling))
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+ if self.merged:
+ return F.linear(x, T(self.weight), bias=self.bias)
+ else:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ if self.r > 0:
+ after_A = F.linear(self.lora_dropout(x), self.lora_A)
+ after_B = F.conv1d(
+ after_A.transpose(-2, -1),
+ self.lora_B.unsqueeze(-1),
+ groups=sum(self.enable_lora)
+ ).transpose(-2, -1)
+ result += self.zero_pad(after_B) * self.scaling
+ return result
+
+
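+# ConvLoRA wraps a convolution module: the low-rank product lora_B @ lora_A is
+# reshaped to the frozen kernel's shape and added on top of it (and merged
+# into conv.weight in eval mode when merge_weights is set).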
+class ConvLoRA(nn.Module, LoRALayer):
+ def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
+ super(ConvLoRA, self).__init__()
+ self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
+ assert isinstance(kernel_size, int)
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(
+ self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
+ )
+ self.lora_B = nn.Parameter(
+ self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size))
+ )
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.conv.weight.requires_grad = False
+ self.reset_parameters()
+ self.merged = False
+
+ def reset_parameters(self):
+ self.conv.reset_parameters()
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ def train(self, mode=True):
+ super(ConvLoRA, self).train(mode)
+ if mode:
+ if self.merge_weights and self.merged:
+ if self.r > 0:
+ # Make sure that the weights are not merged
+ self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
+ self.merged = False
+ else:
+ if self.merge_weights and not self.merged:
+ if self.r > 0:
+ # Merge the weights and mark it
+ self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
+ self.merged = True
+
+ def forward(self, x):
+ if self.r > 0 and not self.merged:
+ return self.conv._conv_forward(
+ x,
+ self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
+ self.conv.bias
+ )
+ return self.conv(x)
+
+class Conv2d(ConvLoRA):
+ def __init__(self, *args, **kwargs):
+ super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)
+
+class Conv1d(ConvLoRA):
+ def __init__(self, *args, **kwargs):
+ super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)
+
+# Other convolution types can be wrapped in the same way, for example:
+
+class Conv3d(ConvLoRA):
+ def __init__(self, *args, **kwargs):
+ super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)
+
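+# Minimal usage sketch (illustrative only; the shapes and hyper-parameters
+# below are made up):
+#
+#     import funasr.modules.lora.layers as lora
+#     proj = lora.Linear(512, 512, r=8, lora_alpha=16, lora_dropout=0.1)
+#     y = proj(torch.randn(4, 512))   # behaves like nn.Linear(512, 512)
+#     # only proj.lora_A / proj.lora_B require grad; proj.weight is frozen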
diff --git a/funasr/modules/lora/utils.py b/funasr/modules/lora/utils.py
new file mode 100644
index 0000000..e18bf44
--- /dev/null
+++ b/funasr/modules/lora/utils.py
@@ -0,0 +1,50 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+import torch
+import torch.nn as nn
+
+from typing import Dict
+
+from .layers import LoRALayer
+
+
+def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
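+    # Freeze every parameter except the LoRA factors; names containing 'cif'
+    # (FunASR's CIF predictor) are also left trainable during fine-tuning.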
+ for n, p in model.named_parameters():
+ if 'lora_' not in n and 'cif' not in n:
+ p.requires_grad = False
+ if bias == 'none':
+ return
+ elif bias == 'all':
+ for n, p in model.named_parameters():
+ if 'bias' in n:
+ p.requires_grad = True
+ elif bias == 'lora_only':
+ for m in model.modules():
+ if isinstance(m, LoRALayer) and \
+ hasattr(m, 'bias') and \
+ m.bias is not None:
+ m.bias.requires_grad = True
+ else:
+ raise NotImplementedError
+
+
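+# Return only the LoRA (and, depending on `bias`, bias) entries of the model's
+# state_dict, so a fine-tuned checkpoint can be stored as a small delta on top
+# of the frozen pre-trained weights.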
+def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
+ my_state_dict = model.state_dict()
+ if bias == 'none':
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
+ elif bias == 'all':
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
+ elif bias == 'lora_only':
+ to_return = {}
+ for k in my_state_dict:
+ if 'lora_' in k:
+ to_return[k] = my_state_dict[k]
+ bias_name = k.split('lora_')[0]+'bias'
+ if bias_name in my_state_dict:
+ to_return[bias_name] = my_state_dict[bias_name]
+ return to_return
+ else:
+ raise NotImplementedError
+
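+# Typical fine-tuning flow (illustrative only; `model` is any module built
+# with the LoRA layers above):
+#
+#     mark_only_lora_as_trainable(model, bias='lora_only')
+#     ...  # run the usual training loop
+#     torch.save(lora_state_dict(model, bias='lora_only'), 'lora_only.pt')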
--
Gitblit v1.9.1