From 36c43d4c9f3ae98f026889b2f5f9726826a208d8 Mon Sep 17 00:00:00 2001
From: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
Date: Thu, 20 Jul 2023 18:33:54 +0800
Subject: [PATCH] Add LoRA finetuning support
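
LoRA (Low-Rank Adaptation) keeps each pretrained weight W frozen and
learns a low-rank update that is added to its output:

    h = W x + (lora_alpha / r) * B A x

where A (r x d_in) and B (d_out x r) are the only trainable parameters.
At eval time B A is merged into W, so inference costs nothing extra.

Illustrative usage sketch (modelscope_finetune and params are defined in
finetune.py below; the hyper-parameter values are examples only):

    params.param_dict = {"init_param": [], "freeze_param": [],
                         "ignore_init_mismatch": True}
    params.param_dict.update({
        "enable_lora": True,
        "lora_bias": "all",
        "lora_params": {"lora_list": ["q", "v"], "lora_rank": 8,
                        "lora_alpha": 16, "lora_dropout": 0.1},
    })
    modelscope_finetune(params)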
---
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py | 16 ++
funasr/bin/build_trainer.py | 15 ++
funasr/modules/attention.py | 16 ++
funasr/modules/lora/layers.py | 248 +++++++++++++++++++++++++------------------
4 files changed, 164 insertions(+), 131 deletions(-)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py
index 1935258..eb24e82 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py
@@ -19,7 +19,8 @@
work_dir=params.output_dir,
batch_bins=params.batch_bins,
max_epoch=params.max_epoch,
- lr=params.lr)
+ lr=params.lr,
+ mate_params=params.param_dict)
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
@@ -30,7 +31,18 @@
params.data_path = "./example_data/" # data path
params.dataset_type = "small" # use "small" for small datasets; if there are more than 1000 hours of data, use "large"
params.batch_bins = 2000 # batch size; if dataset_type="small", batch_bins is the number of fbank feature frames; if dataset_type="large", batch_bins is in milliseconds
- params.max_epoch = 50 # maximum number of training epochs
+ params.max_epoch = 5 # maximum number of training epochs
params.lr = 0.00005 # learning rate
+ init_param = []
+ freeze_param = []
+ ignore_init_mismatch = True
+ use_lora = False
+ params.param_dict = {"init_param": init_param, "freeze_param": freeze_param, "ignore_init_mismatch": ignore_init_mismatch}
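+ # LoRA settings (assumed to be consumed by build_trainer.py, which copies
+ # lora_params into encoder_conf/decoder_conf):
+ #   lora_list:    projections to adapt, any of "q", "k", "v", "o"
+ #   lora_rank:    rank r of the low-rank update B @ A
+ #   lora_alpha:   scaling numerator; the update is scaled by lora_alpha / r
+ #   lora_dropout: dropout applied to the input of the LoRA branch
+ #   lora_bias:    which biases stay trainable ("none", "all" or "lora_only", as in loralib)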
+ if use_lora:
+ enable_lora = True
+ lora_bias = "all"
+ lora_params = {"lora_list": ["q", "v"], "lora_rank": 8, "lora_alpha": 16, "lora_dropout": 0.1}
+ lora_config = {"enable_lora": enable_lora, "lora_bias": lora_bias, "lora_params": lora_params}
+ params.param_dict.update(lora_config)
modelscope_finetune(params)
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index 0f87186..b794484 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -92,6 +92,14 @@
for key, value in finetune_configs.items():
if hasattr(args, key):
setattr(args, key, value)
+ if mate_params is not None:
+ for key, value in mate_params.items():
+ if hasattr(args, key):
+ setattr(args, key, value)
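+ # Push the LoRA hyper-parameters into the attention configs so that
+ # funasr/modules/attention.py instantiates lora.Linear / lora.MergedLinear.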
+ if mate_params is not None and "lora_params" in mate_params:
+ lora_params = mate_params['lora_params']
+ configs['encoder_conf'].update(lora_params)
+ configs['decoder_conf'].update(lora_params)
# prepare data
args.dataset_type = dataset_type
@@ -106,6 +114,9 @@
else:
raise ValueError(f"Not supported dataset_type={args.dataset_type}")
args.init_param = [init_param]
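+ # A non-empty param_dict["init_param"] overrides the default pretrained checkpoint.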
+ if mate_params is not None and "init_param" in mate_params:
+ if len(mate_params["init_param"]) != 0:
+ args.init_param = mate_params["init_param"]
args.cmvn_file = cmvn_file
if os.path.exists(seg_dict_file):
args.seg_dict_file = seg_dict_file
@@ -144,10 +155,6 @@
args.patience = None
args.local_rank = local_rank
args.distributed = distributed
- if mate_params is not None:
- for key, value in mate_params.items():
- if hasattr(args, key):
- setattr(args, key, value)
ASRTask.finetune_args = args
return ASRTask
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index f01e340..ab59493 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -338,7 +338,10 @@
else:
self.linear_out = nn.Linear(n_feat, n_feat)
lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
- self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
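+ # Use a plain Linear when no q/k/v projection is adapted: lora.MergedLinear
+ # expects at least one True entry in enable_lora.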
+ if not any(lora_qkv_list):
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ else:
+ self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
else:
self.linear_out = nn.Linear(n_feat, n_feat)
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
@@ -562,11 +565,18 @@
if lora_list is not None:
if "q" in lora_list:
self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
lora_kv_list = ["k" in lora_list, "v" in lora_list]
- self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
- r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+ if not any(lora_kv_list):
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2)
+ else:
+ self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
+ r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
if "o" in lora_list:
self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
else:
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
diff --git a/funasr/modules/lora/layers.py b/funasr/modules/lora/layers.py
index 9115b28..76f046c 100644
--- a/funasr/modules/lora/layers.py
+++ b/funasr/modules/lora/layers.py
@@ -11,9 +11,9 @@
class LoRALayer():
def __init__(
- self,
- r: int,
- lora_alpha: int,
+ self,
+ r: int,
+ lora_alpha: int,
lora_dropout: float,
merge_weights: bool,
):
@@ -61,40 +61,42 @@
def train(self, mode: bool = True):
nn.Embedding.train(self, mode)
- if mode:
- if self.merge_weights and self.merged:
- # Make sure that the weights are not merged
- if self.r > 0:
- self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
- self.merged = False
- else:
- if self.merge_weights and not self.merged:
- # Merge the weights and mark it
- if self.r > 0:
- self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
- self.merged = True
-
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0:
+ self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
+ self.merged = False
+
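+ # train() always unmerges and eval() merges: weights stay raw during
+ # training, and B @ A is folded into self.weight for inference.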
+ def eval(self):
+ nn.Embedding.eval(self)
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += (self.lora_B @ self.lora_A).T * self.scaling
+ self.merged = True
+
def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged:
result = nn.Embedding.forward(self, x)
- after_A = F.embedding(
- x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse
- )
- result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
+ after_A = F.embedding(
+ x, self.lora_A.T, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse
+ )
+ result += (after_A @ self.lora_B.T) * self.scaling
return result
else:
return nn.Embedding.forward(self, x)
-
+
class Linear(nn.Linear, LoRALayer):
# LoRA implemented in a dense layer
def __init__(
- self,
- in_features: int,
- out_features: int,
- r: int = 0,
- lora_alpha: int = 1,
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
lora_dropout: float = 0.,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
merge_weights: bool = True,
@@ -114,7 +116,7 @@
self.weight.requires_grad = False
self.reset_parameters()
if fan_in_fan_out:
- self.weight.data = self.weight.data.transpose(0, 1)
+ self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
@@ -125,27 +127,31 @@
def train(self, mode: bool = True):
def T(w):
- return w.transpose(0, 1) if self.fan_in_fan_out else w
+ return w.T if self.fan_in_fan_out else w
nn.Linear.train(self, mode)
- if mode:
- if self.merge_weights and self.merged:
- # Make sure that the weights are not merged
- if self.r > 0:
- self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
- self.merged = False
- else:
- if self.merge_weights and not self.merged:
- # Merge the weights and mark it
- if self.r > 0:
- self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
- self.merged = True
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0:
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = False
+
+ def eval(self):
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+ nn.Linear.eval(self)
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
- return w.transpose(0, 1) if self.fan_in_fan_out else w
+ return w.T if self.fan_in_fan_out else w
if self.r > 0 and not self.merged:
- result = F.linear(x, T(self.weight), bias=self.bias)
- result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
return result
else:
return F.linear(x, T(self.weight), bias=self.bias)
@@ -154,11 +160,11 @@
class MergedLinear(nn.Linear, LoRALayer):
# LoRA implemented in a dense layer
def __init__(
- self,
- in_features: int,
- out_features: int,
- r: int = 0,
- lora_alpha: int = 1,
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
lora_dropout: float = 0.,
enable_lora: List[bool] = [False],
fan_in_fan_out: bool = False,
@@ -190,7 +196,7 @@
self.lora_ind = self.lora_ind.view(-1)
self.reset_parameters()
if fan_in_fan_out:
- self.weight.data = self.weight.data.transpose(0, 1)
+ self.weight.data = self.weight.data.T
def reset_parameters(self):
nn.Linear.reset_parameters(self)
@@ -209,34 +215,37 @@
def train(self, mode: bool = True):
def T(w):
- return w.transpose(0, 1) if self.fan_in_fan_out else w
+ return w.T if self.fan_in_fan_out else w
nn.Linear.train(self, mode)
- if mode:
- if self.merge_weights and self.merged:
- # Make sure that the weights are not merged
- if self.r > 0 and any(self.enable_lora):
- delta_w = F.conv1d(
- self.lora_A.data.unsqueeze(0),
- self.lora_B.data.unsqueeze(-1),
- groups=sum(self.enable_lora)
- ).squeeze(0)
- self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
- self.merged = False
- else:
- if self.merge_weights and not self.merged:
- # Merge the weights and mark it
- if self.r > 0 and any(self.enable_lora):
- delta_w = F.conv1d(
- self.lora_A.data.unsqueeze(0),
- self.lora_B.data.unsqueeze(-1),
- groups=sum(self.enable_lora)
- ).squeeze(0)
- self.weight.data += self.zero_pad(T(delta_w * self.scaling))
- self.merged = True
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0 and any(self.enable_lora):
+ delta_w = F.conv1d(
+ self.lora_A.data.unsqueeze(0),
+ self.lora_B.data.unsqueeze(-1),
+ groups=sum(self.enable_lora)
+ ).squeeze(0)
+ self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
+ self.merged = False
+
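+ # F.conv1d with groups=sum(enable_lora) computes every enabled slice's
+ # B_i @ A_i in one call; zero_pad scatters the result into the output
+ # rows that belong to the enabled projections.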
+ def eval(self):
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+ nn.Linear.eval(self)
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0 and any(self.enable_lora):
+ delta_w = F.conv1d(
+ self.lora_A.data.unsqueeze(0),
+ self.lora_B.data.unsqueeze(-1),
+ groups=sum(self.enable_lora)
+ ).squeeze(0)
+ self.weight.data += self.zero_pad(T(delta_w * self.scaling))
+ self.merged = True
def forward(self, x: torch.Tensor):
def T(w):
- return w.transpose(0, 1) if self.fan_in_fan_out else w
+ return w.T if self.fan_in_fan_out else w
if self.merged:
return F.linear(x, T(self.weight), bias=self.bias)
else:
@@ -244,76 +253,71 @@
if self.r > 0:
after_A = F.linear(self.lora_dropout(x), self.lora_A)
after_B = F.conv1d(
- after_A.transpose(-2, -1),
- self.lora_B.unsqueeze(-1),
+ after_A.transpose(-2, -1),
+ self.lora_B.unsqueeze(-1),
groups=sum(self.enable_lora)
).transpose(-2, -1)
result += self.zero_pad(after_B) * self.scaling
return result
-
-class ConvLoRA(nn.Module, LoRALayer):
- def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
- super(ConvLoRA, self).__init__()
- self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
- LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
- assert isinstance(kernel_size, int)
+
+class Conv2d(nn.Conv2d, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+ assert isinstance(kernel_size, int)
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(
- self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
+ self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
)
self.lora_B = nn.Parameter(
- self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size))
+ self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
)
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
- self.conv.weight.requires_grad = False
+ self.weight.requires_grad = False
self.reset_parameters()
- self.merged = False
def reset_parameters(self):
- self.conv.reset_parameters()
+ nn.Conv2d.reset_parameters(self)
if hasattr(self, 'lora_A'):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
- def train(self, mode=True):
- super(ConvLoRA, self).train(mode)
- if mode:
- if self.merge_weights and self.merged:
- if self.r > 0:
- # Make sure that the weights are not merged
- self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
- self.merged = False
- else:
- if self.merge_weights and not self.merged:
- if self.r > 0:
- # Merge the weights and mark it
- self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
- self.merged = True
+ def train(self, mode: bool = True):
+ nn.Conv2d.train(self, mode)
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0:
+ self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+ self.merged = False
- def forward(self, x):
+ def eval(self):
+ nn.Conv2d.eval(self)
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it; guard on r > 0 since lora_A/lora_B
+ # only exist when a LoRA update was created.
+ if self.r > 0:
+ self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
if self.r > 0 and not self.merged:
- return self.conv._conv_forward(
- x,
- self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
- self.conv.bias
+ return F.conv2d(
+ x,
+ self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
+ self.bias, self.stride, self.padding, self.dilation, self.groups
)
- return self.conv(x)
-
-class Conv2d(ConvLoRA):
- def __init__(self, *args, **kwargs):
- super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)
-
-class Conv1d(ConvLoRA):
- def __init__(self, *args, **kwargs):
- super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)
-
-# Can Extend to other ones like this
-
-class Conv3d(ConvLoRA):
- def __init__(self, *args, **kwargs):
- super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)
+ return nn.Conv2d.forward(self, x)
--
Gitblit v1.9.1