From a86d1676098f86444528646a409857ab02a4bbcb Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 21 七月 2023 15:26:44 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/bin/train.py | 15 +
funasr/modules/attention.py | 44 +++
funasr/modules/lora/layers.py | 323 +++++++++++++++++++++++++++++
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md | 55 ++++
funasr/models/encoder/sanm_encoder.py | 12 +
funasr/modules/lora/__init__.py | 0
funasr/modules/lora/utils.py | 50 ++++
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py | 18 +
egs_modelscope/asr/TEMPLATE/README.md | 55 +++-
funasr/bin/build_trainer.py | 14 +
funasr/tasks/abs_task.py | 15 +
funasr/models/decoder/sanm_decoder.py | 6
12 files changed, 566 insertions(+), 41 deletions(-)
diff --git a/egs_modelscope/asr/TEMPLATE/README.md b/egs_modelscope/asr/TEMPLATE/README.md
index 0219c5b..cf0ba84 100644
--- a/egs_modelscope/asr/TEMPLATE/README.md
+++ b/egs_modelscope/asr/TEMPLATE/README.md
@@ -1,6 +1,6 @@
# Speech Recognition
-> **Note**:
+> **Note**:
> The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope) to inference and finetine. Here we take the typic models as examples to demonstrate the usage.
## Inference
@@ -36,7 +36,7 @@
param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
chunk_stride = chunk_size[1] * 960 # 600ms銆�480ms
# first chunk, 600ms
-speech_chunk = speech[0:chunk_stride]
+speech_chunk = speech[0:chunk_stride]
rec_result = inference_pipeline(audio_in=speech_chunk, param_dict=param_dict)
print(rec_result)
# next chunk, 600ms
@@ -101,16 +101,16 @@
- `task`: `Tasks.auto_speech_recognition`
- `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
- `ngpu`: `1` (Default), decoding on GPU. If ngpu=0, decoding on CPU
-- `ncpu`: `1` (Default), sets the number of threads used for intraop parallelism on CPU
+- `ncpu`: `1` (Default), sets the number of threads used for intraop parallelism on CPU
- `output_dir`: `None` (Default), the output path of results if set
- `batch_size`: `1` (Default), batch size when decoding
#### Infer pipeline
-- `audio_in`: the input to decode, which could be:
+- `audio_in`: the input to decode, which could be:
- wav_path, `e.g.`: asr_example.wav,
- - pcm_path, `e.g.`: asr_example.pcm,
+ - pcm_path, `e.g.`: asr_example.pcm,
- audio bytes stream, `e.g.`: bytes data from a microphone
- audio sample point锛宍e.g.`: `audio, rate = soundfile.read("asr_example_zh.wav")`, the dtype is numpy.ndarray or torch.Tensor
- - wav.scp, kaldi style wav list (`wav_id \t wav_path`), `e.g.`:
+ - wav.scp, kaldi style wav list (`wav_id \t wav_path`), `e.g.`:
```text
asr_example1 ./audios/asr_example1.wav
asr_example2 ./audios/asr_example2.wav
@@ -168,15 +168,19 @@
[finetune.py](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/asr/TEMPLATE/finetune.py)
```python
import os
+
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
-from modelscope.msdatasets.audio.asr_dataset import ASRDataset
+
+from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
+
def modelscope_finetune(params):
if not os.path.exists(params.output_dir):
os.makedirs(params.output_dir, exist_ok=True)
# dataset split ["train", "validation"]
- ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
+ ds_dict = MsDataset.load(params.data_path)
kwargs = dict(
model=params.model,
data_dir=ds_dict,
@@ -184,21 +188,32 @@
work_dir=params.output_dir,
batch_bins=params.batch_bins,
max_epoch=params.max_epoch,
- lr=params.lr)
+ lr=params.lr,
+ mate_params=params.param_dict)
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
if __name__ == '__main__':
- from funasr.utils.modelscope_param import modelscope_args
- params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
- params.output_dir = "./checkpoint" # 妯″瀷淇濆瓨璺緞
- params.data_path = "speech_asr_aishell1_trainsets" # 鏁版嵁璺緞锛屽彲浠ヤ负modelscope涓凡涓婁紶鏁版嵁锛屼篃鍙互鏄湰鍦版暟鎹�
- params.dataset_type = "small" # 灏忔暟鎹噺璁剧疆small锛岃嫢鏁版嵁閲忓ぇ浜�1000灏忔椂锛岃浣跨敤large
- params.batch_bins = 2000 # batch size锛屽鏋渄ataset_type="small"锛宐atch_bins鍗曚綅涓篺bank鐗瑰緛甯ф暟锛屽鏋渄ataset_type="large"锛宐atch_bins鍗曚綅涓烘绉掞紝
- params.max_epoch = 50 # 鏈�澶ц缁冭疆鏁�
- params.lr = 0.00005 # 璁剧疆瀛︿範鐜�
-
+ params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", data_path="./data")
+ params.output_dir = "./checkpoint" # m妯″瀷淇濆瓨璺緞
+ params.data_path = "./example_data/" # 鏁版嵁璺緞
+ params.dataset_type = "small" # 灏忔暟鎹噺璁剧疆small锛岃嫢鏁版嵁閲忓ぇ浜�1000灏忔椂锛岃浣跨敤large
+ params.batch_bins = 2000 # batch size锛屽鏋渄ataset_type="small"锛宐atch_bins鍗曚綅涓篺bank鐗瑰緛甯ф暟锛屽鏋渄ataset_type="large"锛宐atch_bins鍗曚綅涓烘绉掞紝
+ params.max_epoch = 20 # 鏈�澶ц缁冭疆鏁�
+ params.lr = 0.00005 # 璁剧疆瀛︿範鐜�
+ init_param = [] # 鍒濆妯″瀷璺緞锛岄粯璁ゅ姞杞絤odelscope妯″瀷鍒濆鍖栵紝渚嬪: ["checkpoint/20epoch.pb"]
+ freeze_param = [] # 妯″瀷鍙傛暟freeze, 渚嬪: ["encoder"]
+ ignore_init_mismatch = True # 鏄惁蹇界暐妯″瀷鍙傛暟鍒濆鍖栦笉鍖归厤
+ use_lora = False # 鏄惁浣跨敤lora杩涜妯″瀷寰皟
+ params.param_dict = {"init_param":init_param, "freeze_param": freeze_param, "ignore_init_mismatch": ignore_init_mismatch}
+ if use_lora:
+ enable_lora = True
+ lora_bias = "all"
+ lora_params = {"lora_list":['q','v'], "lora_rank":8, "lora_alpha":16, "lora_dropout":0.1}
+ lora_config = {"enable_lora": enable_lora, "lora_bias": lora_bias, "lora_params": lora_params}
+ params.param_dict.update(lora_config)
+
modelscope_finetune(params)
```
@@ -215,6 +230,10 @@
- `batch_bins`: batch size. For dataset_type is `small`, `batch_bins` indicates the feature frames. For dataset_type is `large`, `batch_bins` indicates the duration in ms
- `max_epoch`: number of training epoch
- `lr`: learning rate
+ - `init_param`: init model path, load modelscope model initialization by default. For example: ["checkpoint/20epoch.pb"]
+ - `freeze_param`: Freeze model parameters. For example锛歔"encoder"]
+ - `ignore_init_mismatch`: Ignore size mismatch when loading pre-trained model
+ - `use_lora`: Fine-tuning model use lora, more detail please refer to [LORA](https://arxiv.org/pdf/2106.09685.pdf)
- Training data formats锛�
```sh
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
index 4e06daf..edc2cf1 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
@@ -2,14 +2,6 @@
- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary>
- Model size: 220M
-# Environments
-- date: `Tue Nov 22 18:48:39 CST 2022`
-- python version: `3.7.12`
-- FunASR version: `0.1.0`
-- pytorch version: `pytorch 1.7.0`
-- Git hash: ``
-- Commit date: ``
-
# Beachmark Results
## AISHELL-1
@@ -73,3 +65,50 @@
|SPEECHIO_ASR_ZH000013| 2.57 | 2.25 |
|SPEECHIO_ASR_ZH000014| 3.86 | 3.08 |
|SPEECHIO_ASR_ZH000015| 3.34 | 2.67 |
+
+
+# Fine-tuning Results
+
+## Fine-tuning
+- Train config:
+ - Training data: aishell-1
+ - Training info: lr 0.0002, dataset_type: small, batch bins 2000, 2 gpu, acc_grad 1, 20 epochs
+ - Decoding info: beam_size 1, average_num 10
+
+| model | dev cer(%) | test cer(%) |
+|:---------:|:-------------:|:-------------:|
+| Pretrain | 1.75 |1.95 |
+| Finetune | 1.62 |1.78 |
+
+- Train config:
+ - Training data: 16k sichuan dialect
+ - Training info: lr 0.0002, dataset_type: small, batch bins 2000, 2 gpu, acc_grad 1, 20 epochs
+ - Decoding info: beam_size 1, average_num 10
+
+
+| model | Training Data(h) | cn cer(%) | sichuan cer(%) |
+|:--------:|:-------------:|:-------:|:------------:|
+| Pretrain | | 8.57 | 19.81 |
+| Finetune | 50 | 8.8 | 12 |
+| | 100 | 9.24 | 11.63 |
+| | 200 | 9.82 | 10.47 |
+| | 300 | 9.95 | 10.44 |
+| | 1000 | 9.99 | 9.78 |
+
+
+## Lora Fine-tuning
+- Train config:
+ - Training data: 16k sichuan dialect
+ - Training info: lr 0.0002, dataset_type: small, batch bins 2000, 2 gpu, acc_grad 1, 20 epochs
+ - Lora info: lora_bias: "all", lora_list ['q','v'], lora_rank:8, lora_alpha:16, lora_dropout:0.1
+ - Decoding info: beam_size 1, average_num 10
+
+| model | Training Data(h) | Trainable Parameters(M) | cn cer(%) | sichuan cer(%) |
+|:-------------:|:----------------:|:-----------------------:|:---------:|:--------------:|
+| Pretrain | | | 8.57 | 19.81 |
+| | | | | |
+| Finetune | 50 | 220.9 | 8.8 | 12 |
+| Lora Finetune | 50 | 2.29 | 9.13 | 12.13 |
+| | | | | |
+| Finetune | 200 | 220.9 | 9.82 | 10.47 |
+| Lora Finetune | 200 | 2.29 | 9.21 | 11.28 |
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py
index 1935258..993f8ed 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py
@@ -19,7 +19,8 @@
work_dir=params.output_dir,
batch_bins=params.batch_bins,
max_epoch=params.max_epoch,
- lr=params.lr)
+ lr=params.lr,
+ mate_params=params.param_dict)
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
@@ -30,7 +31,18 @@
params.data_path = "./example_data/" # 鏁版嵁璺緞
params.dataset_type = "small" # 灏忔暟鎹噺璁剧疆small锛岃嫢鏁版嵁閲忓ぇ浜�1000灏忔椂锛岃浣跨敤large
params.batch_bins = 2000 # batch size锛屽鏋渄ataset_type="small"锛宐atch_bins鍗曚綅涓篺bank鐗瑰緛甯ф暟锛屽鏋渄ataset_type="large"锛宐atch_bins鍗曚綅涓烘绉掞紝
- params.max_epoch = 50 # 鏈�澶ц缁冭疆鏁�
- params.lr = 0.00005 # 璁剧疆瀛︿範鐜�
+ params.max_epoch = 20 # 鏈�澶ц缁冭疆鏁�
+ params.lr = 0.0002 # 璁剧疆瀛︿範鐜�
+ init_param = [] # 鍒濆妯″瀷璺緞锛岄粯璁ゅ姞杞絤odelscope妯″瀷鍒濆鍖栵紝渚嬪: ["checkpoint/20epoch.pb"]
+ freeze_param = [] # 妯″瀷鍙傛暟freeze, 渚嬪: ["encoder"]
+ ignore_init_mismatch = True # 鏄惁蹇界暐妯″瀷鍙傛暟鍒濆鍖栦笉鍖归厤
+ use_lora = False # 鏄惁浣跨敤lora杩涜妯″瀷寰皟
+ params.param_dict = {"init_param":init_param, "freeze_param": freeze_param, "ignore_init_mismatch": ignore_init_mismatch}
+ if use_lora:
+ enable_lora = True
+ lora_bias = "all"
+ lora_params = {"lora_list":['q','v'], "lora_rank":8, "lora_alpha":16, "lora_dropout":0.1}
+ lora_config = {"enable_lora": enable_lora, "lora_bias": lora_bias, "lora_params": lora_params}
+ params.param_dict.update(lora_config)
modelscope_finetune(params)
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index 891139a..e7f28ed 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -2,7 +2,6 @@
import yaml
-
def update_dct(fin_configs, root):
if root == {}:
return {}
@@ -55,7 +54,7 @@
scheduler_conf=None,
specaug=None,
specaug_conf=None,
- param_dict=None,
+ mate_params=None,
**kwargs):
mode = modelscope_dict['mode']
args, ASRTask = parse_args(mode=mode)
@@ -92,6 +91,14 @@
for key, value in finetune_configs.items():
if hasattr(args, key):
setattr(args, key, value)
+ if mate_params is not None:
+ for key, value in mate_params.items():
+ if hasattr(args, key):
+ setattr(args, key, value)
+ if mate_params is not None and "lora_params" in mate_params:
+ lora_params = mate_params['lora_params']
+ configs['encoder_conf'].update(lora_params)
+ configs['decoder_conf'].update(lora_params)
# prepare data
args.dataset_type = dataset_type
@@ -106,6 +113,9 @@
else:
raise ValueError(f"Not supported dataset_type={args.dataset_type}")
args.init_param = [init_param]
+ if mate_params is not None and "init_param" in mate_params:
+ if len(mate_params["init_param"]) != 0:
+ args.init_param = mate_params["init_param"]
args.cmvn_file = cmvn_file
if os.path.exists(seg_dict_file):
args.seg_dict_file = seg_dict_file
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 1dc3fb5..f5d10c4 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -28,6 +28,7 @@
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
def get_parser():
@@ -478,6 +479,18 @@
default=None,
help="oss bucket.",
)
+ parser.add_argument(
+ "--enable_lora",
+ type=str2bool,
+ default=False,
+ help="Apply lora for finetuning.",
+ )
+ parser.add_argument(
+ "--lora_bias",
+ type=str,
+ default="none",
+ help="lora bias.",
+ )
return parser
@@ -521,6 +534,8 @@
dtype=getattr(torch, args.train_dtype),
device="cuda" if args.ngpu > 0 else "cpu",
)
+ if args.enable_lora:
+ mark_only_lora_as_trainable(model, args.lora_bias)
for t in args.freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index d83f89f..c12e098 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -833,6 +833,10 @@
att_layer_num: int = 6,
kernel_size: int = 21,
sanm_shfit: int = 0,
+ lora_list: List[str] = None,
+ lora_rank: int = 8,
+ lora_alpha: int = 16,
+ lora_dropout: float = 0.1,
tf2torch_tensor_name_prefix_torch: str = "decoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
):
@@ -885,7 +889,7 @@
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate
+ attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 45163df..9e27d4a 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -146,6 +146,10 @@
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
+ lora_list: List[str] = None,
+ lora_rank: int = 8,
+ lora_alpha: int = 16,
+ lora_dropout: float = 0.1,
selfattention_layer_type: str = "sanm",
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
@@ -229,6 +233,10 @@
attention_dropout_rate,
kernel_size,
sanm_shfit,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
)
encoder_selfattn_layer_args = (
@@ -238,6 +246,10 @@
attention_dropout_rate,
kernel_size,
sanm_shfit,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
)
self.encoders0 = repeat(
1,
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index fcb3ed4..ab59493 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -15,6 +15,7 @@
import torch.nn.functional as F
from funasr.modules.nets_utils import make_pad_mask
+import funasr.modules.lora.layers as lora
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -321,7 +322,7 @@
"""
- def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+ def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttentionSANM, self).__init__()
assert n_feat % n_head == 0
@@ -331,8 +332,19 @@
# self.linear_q = nn.Linear(n_feat, n_feat)
# self.linear_k = nn.Linear(n_feat, n_feat)
# self.linear_v = nn.Linear(n_feat, n_feat)
- self.linear_out = nn.Linear(n_feat, n_feat)
- self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ if lora_list is not None:
+ if "o" in lora_list:
+ self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+ if lora_qkv_list == [False, False, False]:
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ else:
+ self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
@@ -543,18 +555,32 @@
"""
- def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
+ def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttentionCrossAtt, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
- self.linear_q = nn.Linear(n_feat, n_feat)
- # self.linear_k = nn.Linear(n_feat, n_feat)
- # self.linear_v = nn.Linear(n_feat, n_feat)
- self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
- self.linear_out = nn.Linear(n_feat, n_feat)
+ if lora_list is not None:
+ if "q" in lora_list:
+ self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ lora_kv_list = ["k" in lora_list, "v" in lora_list]
+ if lora_kv_list == [False, False]:
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+ else:
+ self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2,
+ r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+ if "o" in lora_list:
+ self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+ else:
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ else:
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+ self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
diff --git a/funasr/modules/lora/__init__.py b/funasr/modules/lora/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/modules/lora/__init__.py
diff --git a/funasr/modules/lora/layers.py b/funasr/modules/lora/layers.py
new file mode 100644
index 0000000..76f046c
--- /dev/null
+++ b/funasr/modules/lora/layers.py
@@ -0,0 +1,323 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import math
+from typing import Optional, List
+
+class LoRALayer():
+ def __init__(
+ self,
+ r: int,
+ lora_alpha: int,
+ lora_dropout: float,
+ merge_weights: bool,
+ ):
+ self.r = r
+ self.lora_alpha = lora_alpha
+ # Optional dropout
+ if lora_dropout > 0.:
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
+ else:
+ self.lora_dropout = lambda x: x
+ # Mark the weight as unmerged
+ self.merged = False
+ self.merge_weights = merge_weights
+
+
+class Embedding(nn.Embedding, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
+ merge_weights=merge_weights)
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
+ self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.Embedding.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.zeros_(self.lora_A)
+ nn.init.normal_(self.lora_B)
+
+ def train(self, mode: bool = True):
+ nn.Embedding.train(self, mode)
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0:
+ self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
+ self.merged = False
+
+ def eval(self):
+ nn.Linear.eval(self)
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ if self.r > 0 and not self.merged:
+ result = nn.Embedding.forward(self, x)
+ if self.r > 0:
+ after_A = F.embedding(
+ x, self.lora_A.T, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse
+ )
+ result += (after_A @ self.lora_B.T) * self.scaling
+ return result
+ else:
+ return nn.Embedding.forward(self, x)
+
+
+class Linear(nn.Linear, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+
+ self.fan_in_fan_out = fan_in_fan_out
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.T
+
+ def reset_parameters(self):
+ nn.Linear.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ def train(self, mode: bool = True):
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+ nn.Linear.train(self, mode)
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0:
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = False
+
+ def eval(self):
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+ nn.Linear.eval(self)
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+ if self.r > 0 and not self.merged:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ if self.r > 0:
+ result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
+ return result
+ else:
+ return F.linear(x, T(self.weight), bias=self.bias)
+
+
+class MergedLinear(nn.Linear, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ enable_lora: List[bool] = [False],
+ fan_in_fan_out: bool = False,
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+ assert out_features % len(enable_lora) == 0, \
+ 'The length of enable_lora must divide out_features'
+ self.enable_lora = enable_lora
+ self.fan_in_fan_out = fan_in_fan_out
+ # Actual trainable parameters
+ if r > 0 and any(enable_lora):
+ self.lora_A = nn.Parameter(
+ self.weight.new_zeros((r * sum(enable_lora), in_features)))
+ self.lora_B = nn.Parameter(
+ self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
+ ) # weights for Conv1D with groups=sum(enable_lora)
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ # Compute the indices
+ self.lora_ind = self.weight.new_zeros(
+ (out_features, ), dtype=torch.bool
+ ).view(len(enable_lora), -1)
+ self.lora_ind[enable_lora, :] = True
+ self.lora_ind = self.lora_ind.view(-1)
+ self.reset_parameters()
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.T
+
+ def reset_parameters(self):
+ nn.Linear.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ def zero_pad(self, x):
+ result = x.new_zeros((*x.shape[:-1], self.out_features))
+ result = result.view(-1, self.out_features)
+ result[:, self.lora_ind] = x.reshape(
+ -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
+ )
+ return result.view((*x.shape[:-1], self.out_features))
+
+ def train(self, mode: bool = True):
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+ nn.Linear.train(self, mode)
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0 and any(self.enable_lora):
+ delta_w = F.conv1d(
+ self.lora_A.data.unsqueeze(0),
+ self.lora_B.data.unsqueeze(-1),
+ groups=sum(self.enable_lora)
+ ).squeeze(0)
+ self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
+ self.merged = False
+
+ def eval(self):
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+ nn.Linear.eval(self)
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0 and any(self.enable_lora):
+ delta_w = F.conv1d(
+ self.lora_A.data.unsqueeze(0),
+ self.lora_B.data.unsqueeze(-1),
+ groups=sum(self.enable_lora)
+ ).squeeze(0)
+ self.weight.data += self.zero_pad(T(delta_w * self.scaling))
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+ if self.merged:
+ return F.linear(x, T(self.weight), bias=self.bias)
+ else:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ if self.r > 0:
+ after_A = F.linear(self.lora_dropout(x), self.lora_A)
+ after_B = F.conv1d(
+ after_A.transpose(-2, -1),
+ self.lora_B.unsqueeze(-1),
+ groups=sum(self.enable_lora)
+ ).transpose(-2, -1)
+ result += self.zero_pad(after_B) * self.scaling
+ return result
+
+
+class Conv2d(nn.Conv2d, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+ assert type(kernel_size) is int
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(
+ self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
+ )
+ self.lora_B = nn.Parameter(
+ self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
+ )
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.Conv2d.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ def train(self, mode: bool = True):
+ nn.Conv2d.train(self, mode)
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+ self.merged = False
+
+ def eval(self):
+ nn.Conv2d.eval(self)
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ if self.r > 0 and not self.merged:
+ return F.conv2d(
+ x,
+ self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
+ self.bias, self.stride, self.padding, self.dilation, self.groups
+ )
+ return nn.Conv2d.forward(self, x)
+
diff --git a/funasr/modules/lora/utils.py b/funasr/modules/lora/utils.py
new file mode 100644
index 0000000..e18bf44
--- /dev/null
+++ b/funasr/modules/lora/utils.py
@@ -0,0 +1,50 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+import torch
+import torch.nn as nn
+
+from typing import Dict
+
+from .layers import LoRALayer
+
+
+def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
+ for n, p in model.named_parameters():
+ if 'lora_' not in n and 'cif' not in n:
+ p.requires_grad = False
+ if bias == 'none':
+ return
+ elif bias == 'all':
+ for n, p in model.named_parameters():
+ if 'bias' in n:
+ p.requires_grad = True
+ elif bias == 'lora_only':
+ for m in model.modules():
+ if isinstance(m, LoRALayer) and \
+ hasattr(m, 'bias') and \
+ m.bias is not None:
+ m.bias.requires_grad = True
+ else:
+ raise NotImplementedError
+
+
+def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
+ my_state_dict = model.state_dict()
+ if bias == 'none':
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
+ elif bias == 'all':
+ return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
+ elif bias == 'lora_only':
+ to_return = {}
+ for k in my_state_dict:
+ if 'lora_' in k:
+ to_return[k] = my_state_dict[k]
+ bias_name = k.split('lora_')[0]+'bias'
+ if bias_name in my_state_dict:
+ to_return[bias_name] = my_state_dict[bias_name]
+ return to_return
+ else:
+ raise NotImplementedError
+
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 91d33c5..f7f13d2 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -71,6 +71,7 @@
from funasr.utils.types import str_or_none
from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
try:
import wandb
@@ -952,6 +953,18 @@
default=None,
help="oss bucket.",
)
+ group.add_argument(
+ "--enable_lora",
+ type=str2bool,
+ default=False,
+ help="Apply lora for finetuning.",
+ )
+ group.add_argument(
+ "--lora_bias",
+ type=str,
+ default="none",
+ help="lora bias.",
+ )
cls.trainer.add_arguments(parser)
cls.add_task_arguments(parser)
@@ -1246,6 +1259,8 @@
dtype=getattr(torch, args.train_dtype),
device="cuda" if args.ngpu > 0 else "cpu",
)
+ if args.enable_lora:
+ mark_only_lora_as_trainable(model, args.lora_bias)
for t in args.freeze_param:
for k, p in model.named_parameters():
if k.startswith(t + ".") or k == t:
--
Gitblit v1.9.1