From a86d1676098f86444528646a409857ab02a4bbcb Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 21 Jul 2023 15:26:44 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/bin/train.py                                                                                          |   15 +
 funasr/modules/attention.py                                                                                  |   44 +++
 funasr/modules/lora/layers.py                                                                                |  323 +++++++++++++++++++++++++++++
 egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md  |   55 ++++
 funasr/models/encoder/sanm_encoder.py                                                                        |   12 +
 funasr/modules/lora/__init__.py                                                                              |    0 
 funasr/modules/lora/utils.py                                                                                 |   50 ++++
 egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py |   18 +
 egs_modelscope/asr/TEMPLATE/README.md                                                                        |   55 +++-
 funasr/bin/build_trainer.py                                                                                  |   14 +
 funasr/tasks/abs_task.py                                                                                     |   15 +
 funasr/models/decoder/sanm_decoder.py                                                                        |    6 
 12 files changed, 566 insertions(+), 41 deletions(-)

diff --git a/egs_modelscope/asr/TEMPLATE/README.md b/egs_modelscope/asr/TEMPLATE/README.md
index 0219c5b..cf0ba84 100644
--- a/egs_modelscope/asr/TEMPLATE/README.md
+++ b/egs_modelscope/asr/TEMPLATE/README.md
@@ -1,6 +1,6 @@
 # Speech Recognition
 
-> **Note**: 
+> **Note**:
 > The modelscope pipeline supports all the models in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope) for inference and fine-tuning. Here we take typical models as examples to demonstrate the usage.
 
 ## Inference
@@ -36,7 +36,7 @@
 param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
 chunk_stride = chunk_size[1] * 960 # 600ms, 480ms
 # first chunk, 600ms
-speech_chunk = speech[0:chunk_stride] 
+speech_chunk = speech[0:chunk_stride]
 rec_result = inference_pipeline(audio_in=speech_chunk, param_dict=param_dict)
 print(rec_result)
 # next chunk, 600ms
@@ -101,16 +101,16 @@
 - `task`: `Tasks.auto_speech_recognition`
 - `model`: model name in [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or model path in local disk
 - `ngpu`: `1` (Default), decoding on GPU. If ngpu=0, decoding on CPU
-- `ncpu`: `1` (Default), sets the number of threads used for intraop parallelism on CPU 
+- `ncpu`: `1` (Default), sets the number of threads used for intraop parallelism on CPU
 - `output_dir`: `None` (Default), the output path of results if set
 - `batch_size`: `1` (Default), batch size when decoding
 #### Infer pipeline
-- `audio_in`: the input to decode, which could be: 
+- `audio_in`: the input to decode, which could be:
   - wav_path, `e.g.`: asr_example.wav,
-  - pcm_path, `e.g.`: asr_example.pcm, 
+  - pcm_path, `e.g.`: asr_example.pcm,
   - audio bytes stream, `e.g.`: bytes data from a microphone
  - audio sample points, `e.g.`: `audio, rate = soundfile.read("asr_example_zh.wav")`, the dtype is numpy.ndarray or torch.Tensor
-  - wav.scp, kaldi style wav list (`wav_id \t wav_path`), `e.g.`: 
+  - wav.scp, kaldi style wav list (`wav_id \t wav_path`), `e.g.`:
   ```text
   asr_example1  ./audios/asr_example1.wav
   asr_example2  ./audios/asr_example2.wav
@@ -168,15 +168,19 @@
 [finetune.py](https://github.com/alibaba-damo-academy/FunASR/blob/main/egs_modelscope/asr/TEMPLATE/finetune.py)
 ```python
 import os
+
 from modelscope.metainfo import Trainers
 from modelscope.trainers import build_trainer
-from modelscope.msdatasets.audio.asr_dataset import ASRDataset
+
+from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
+
 
 def modelscope_finetune(params):
     if not os.path.exists(params.output_dir):
         os.makedirs(params.output_dir, exist_ok=True)
     # dataset split ["train", "validation"]
-    ds_dict = ASRDataset.load(params.data_path, namespace='speech_asr')
+    ds_dict = MsDataset.load(params.data_path)
     kwargs = dict(
         model=params.model,
         data_dir=ds_dict,
@@ -184,21 +188,32 @@
         work_dir=params.output_dir,
         batch_bins=params.batch_bins,
         max_epoch=params.max_epoch,
-        lr=params.lr)
+        lr=params.lr,
+        mate_params=params.param_dict)
     trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
     trainer.train()
 
 
 if __name__ == '__main__':
-    from funasr.utils.modelscope_param import modelscope_args
-    params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
-    params.output_dir = "./checkpoint"                      # model save path
-    params.data_path = "speech_asr_aishell1_trainsets"      # data path; either a dataset already uploaded to ModelScope or local data
-    params.dataset_type = "small"                           # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
-    params.batch_bins = 2000                                # batch size; unit is fbank feature frames if dataset_type="small", milliseconds if dataset_type="large"
-    params.max_epoch = 50                                   # maximum number of training epochs
-    params.lr = 0.00005                                     # learning rate
-    
+    params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", data_path="./data")
+    params.output_dir = "./checkpoint"             # model save path
+    params.data_path = "./example_data/"           # data path
+    params.dataset_type = "small"                  # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
+    params.batch_bins = 2000                       # batch size; unit is fbank feature frames if dataset_type="small", milliseconds if dataset_type="large"
+    params.max_epoch = 20                          # maximum number of training epochs
+    params.lr = 0.00005                            # learning rate
+    init_param = []                                # initial model path(s); by default the ModelScope model is loaded, e.g.: ["checkpoint/20epoch.pb"]
+    freeze_param = []                              # model parameters to freeze, e.g.: ["encoder"]
+    ignore_init_mismatch = True                    # whether to ignore parameter mismatches during model initialization
+    use_lora = False                               # whether to fine-tune the model with LoRA
+    params.param_dict = {"init_param": init_param, "freeze_param": freeze_param, "ignore_init_mismatch": ignore_init_mismatch}
+    if use_lora:
+        enable_lora = True
+        lora_bias = "all"
+        lora_params = {"lora_list":['q','v'], "lora_rank":8, "lora_alpha":16, "lora_dropout":0.1}
+        lora_config = {"enable_lora": enable_lora, "lora_bias": lora_bias, "lora_params": lora_params}
+        params.param_dict.update(lora_config)
+
     modelscope_finetune(params)
 ```
 
@@ -215,6 +230,10 @@
     - `batch_bins`: batch size. If `dataset_type` is `small`, `batch_bins` indicates the number of fbank feature frames; if `large`, it indicates the duration in ms
     - `max_epoch`: number of training epoch
     - `lr`: learning rate
+    - `init_param`: initial model path(s); by default the ModelScope model is loaded. For example: ["checkpoint/20epoch.pb"]
+    - `freeze_param`: model parameters to freeze. For example: ["encoder"]
+    - `ignore_init_mismatch`: ignore size mismatches when loading the pre-trained model
+    - `use_lora`: fine-tune the model with LoRA; for more details, refer to [LoRA](https://arxiv.org/pdf/2106.09685.pdf)
 
 - Training data formats:
 ```sh
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
index 4e06daf..edc2cf1 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
@@ -2,14 +2,6 @@
 - Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary>
 - Model size: 220M
 
-# Environments
-- date: `Tue Nov 22 18:48:39 CST 2022`
-- python version: `3.7.12`
-- FunASR version: `0.1.0`
-- pytorch version: `pytorch 1.7.0`
-- Git hash: ``
-- Commit date: ``
-
 # Benchmark Results
 
 ## AISHELL-1
@@ -73,3 +65,50 @@
 |SPEECHIO_ASR_ZH000013| 2.57 | 2.25 |
 |SPEECHIO_ASR_ZH000014| 3.86 | 3.08 |
 |SPEECHIO_ASR_ZH000015| 3.34 | 2.67 |
+
+
+# Fine-tuning Results
+
+## Fine-tuning
+- Train config: 
+  - Training data: aishell-1
+  - Training info: lr 0.0002, dataset_type: small, batch bins 2000, 2 gpu, acc_grad 1, 20 epochs
+  - Decoding info: beam_size 1, average_num 10
+
+| model    | dev cer(%) | test cer(%) |
+|:---------:|:-------------:|:-------------:|
+| Pretrain       | 1.75          |1.95           |
+| Finetune      | 1.62          |1.78           |
+
+- Train config: 
+  - Training data: 16k sichuan dialect
+  - Training info: lr 0.0002, dataset_type: small, batch bins 2000, 2 gpu, acc_grad 1, 20 epochs
+  - Decoding info: beam_size 1, average_num 10
+  
+  
+|   model  | Training Data(h) | cn cer(%) | sichuan cer(%) |
+|:--------:|:-------------:|:-------:|:------------:|
+| Pretrain |               |   8.57  |     19.81    |
+| Finetune |      50      |   8.8   |      12      |
+|          |      100     |   9.24  |     11.63    |
+|          |      200     |   9.82  |     10.47    |
+|          |      300     |   9.95  |     10.44    |
+|          |     1000     |   9.99  |     9.78     |
+
+
+## LoRA Fine-tuning
+- Train config: 
+  - Training data: 16k sichuan dialect
+  - Training info: lr 0.0002, dataset_type: small, batch bins 2000, 2 gpu, acc_grad 1, 20 epochs
+  - LoRA info: lora_bias: "all", lora_list: ['q','v'], lora_rank: 8, lora_alpha: 16, lora_dropout: 0.1
+  - Decoding info: beam_size 1, average_num 10
+  
+|     model     | Training Data(h) | Trainable Parameters(M) | cn cer(%) | sichuan cer(%) |
+|:-------------:|:----------------:|:-----------------------:|:---------:|:--------------:|
+|    Pretrain   |                  |                         |    8.57   |      19.81     |
+|               |                  |                         |           |                |
+|    Finetune   |        50        |          220.9          |    8.8    |       12       |
+| LoRA Finetune |        50        |           2.29          |    9.13   |      12.13     |
+|               |                  |                         |           |                |
+|    Finetune   |        200       |          220.9          |    9.82   |      10.47     |
+| LoRA Finetune |        200       |           2.29          |    9.21   |      11.28     |
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py
index 1935258..993f8ed 100644
--- a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/finetune.py
@@ -19,7 +19,8 @@
         work_dir=params.output_dir,
         batch_bins=params.batch_bins,
         max_epoch=params.max_epoch,
-        lr=params.lr)
+        lr=params.lr,
+        mate_params=params.param_dict)
     trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
     trainer.train()
 
@@ -30,7 +31,18 @@
     params.data_path = "./example_data/"            # data path
     params.dataset_type = "small"                   # use "small" for small datasets; if the data exceeds 1000 hours, use "large"
     params.batch_bins = 2000                        # batch size; unit is fbank feature frames if dataset_type="small", milliseconds if dataset_type="large"
-    params.max_epoch = 50                           # maximum number of training epochs
-    params.lr = 0.00005                             # learning rate
+    params.max_epoch = 20                           # maximum number of training epochs
+    params.lr = 0.0002                              # learning rate
+    init_param = []                                 # initial model path(s); by default the ModelScope model is loaded, e.g.: ["checkpoint/20epoch.pb"]
+    freeze_param = []                               # model parameters to freeze, e.g.: ["encoder"]
+    ignore_init_mismatch = True                     # whether to ignore parameter mismatches during model initialization
+    use_lora = False                                # whether to fine-tune the model with LoRA
+    params.param_dict = {"init_param": init_param, "freeze_param": freeze_param, "ignore_init_mismatch": ignore_init_mismatch}
+    if use_lora:
+        enable_lora = True
+        lora_bias = "all"
+        lora_params = {"lora_list":['q','v'], "lora_rank":8, "lora_alpha":16, "lora_dropout":0.1}
+        lora_config = {"enable_lora": enable_lora, "lora_bias": lora_bias, "lora_params": lora_params}
+        params.param_dict.update(lora_config)
     
     modelscope_finetune(params)
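For reference, a minimal illustration (not part of the patch) of the `param_dict` the script above assembles and passes to the trainer as `mate_params` when `use_lora = True`:

```python
# Illustration only: the dictionary built by the __main__ block above
# when use_lora=True, exactly as build_trainer receives it via mate_params.
param_dict = {
    "init_param": [],                  # empty -> keep the default ModelScope checkpoint
    "freeze_param": [],                # e.g. ["encoder"] to freeze the encoder
    "ignore_init_mismatch": True,
    "enable_lora": True,
    "lora_bias": "all",
    "lora_params": {"lora_list": ['q', 'v'], "lora_rank": 8,
                    "lora_alpha": 16, "lora_dropout": 0.1},
}
```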
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index 891139a..e7f28ed 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -2,7 +2,6 @@
 
 import yaml
 
-
 def update_dct(fin_configs, root):
     if root == {}:
         return {}
@@ -55,7 +54,7 @@
                   scheduler_conf=None,
                   specaug=None,
                   specaug_conf=None,
-                  param_dict=None,
+                  mate_params=None,
                   **kwargs):
     mode = modelscope_dict['mode']
     args, ASRTask = parse_args(mode=mode)
@@ -92,6 +91,14 @@
     for key, value in finetune_configs.items():
         if hasattr(args, key):
             setattr(args, key, value)
+    if mate_params is not None:
+        for key, value in mate_params.items():
+            if hasattr(args, key):
+                setattr(args, key, value)
+        if "lora_params" in mate_params:
+            lora_params = mate_params['lora_params']
+            configs['encoder_conf'].update(lora_params)
+            configs['decoder_conf'].update(lora_params)
 
     # prepare data
     args.dataset_type = dataset_type
@@ -106,6 +113,9 @@
     else:
         raise ValueError(f"Not supported dataset_type={args.dataset_type}")
     args.init_param = [init_param]
+    if mate_params is not None and "init_param" in mate_params:
+        if len(mate_params["init_param"]) != 0:
+            args.init_param = mate_params["init_param"]
     args.cmvn_file = cmvn_file
     if os.path.exists(seg_dict_file):
         args.seg_dict_file = seg_dict_file
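The hunks above consume `mate_params` in three steps: keys that match an existing `args` attribute override it (e.g. `enable_lora`, `lora_bias`, `freeze_param`), `lora_params` is merged into both the encoder and decoder configs so the SANM layers receive the LoRA hyperparameters, and a non-empty `init_param` list replaces the default checkpoint. A condensed sketch of that flow, with `args` and `configs` standing in for the objects `build_trainer` already holds:

```python
# Condensed sketch of the mate_params handling added above (assumptions:
# args is the parsed namespace, configs is the fine-tuning config dict).
def apply_mate_params(args, configs, mate_params):
    if mate_params is None:
        return
    for key, value in mate_params.items():
        if hasattr(args, key):              # e.g. enable_lora, lora_bias
            setattr(args, key, value)
    if "lora_params" in mate_params:
        # Propagate LoRA hyperparameters to both attention stacks.
        configs['encoder_conf'].update(mate_params['lora_params'])
        configs['decoder_conf'].update(mate_params['lora_params'])
    if mate_params.get("init_param"):       # non-empty list overrides the default
        args.init_param = mate_params["init_param"]
```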
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 1dc3fb5..f5d10c4 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -28,6 +28,7 @@
 from funasr.utils.types import str2bool
 from funasr.utils.types import str_or_none
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
 
 
 def get_parser():
@@ -478,6 +479,18 @@
         default=None,
         help="oss bucket.",
     )
+    parser.add_argument(
+        "--enable_lora",
+        type=str2bool,
+        default=False,
+        help="Apply lora for finetuning.",
+    )
+    parser.add_argument(
+        "--lora_bias",
+        type=str,
+        default="none",
+        help="lora bias.",
+    )
 
     return parser
 
@@ -521,6 +534,8 @@
         dtype=getattr(torch, args.train_dtype),
         device="cuda" if args.ngpu > 0 else "cpu",
     )
+    if args.enable_lora:
+        mark_only_lora_as_trainable(model, args.lora_bias)
     for t in args.freeze_param:
         for k, p in model.named_parameters():
             if k.startswith(t + ".") or k == t:
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index d83f89f..c12e098 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -833,6 +833,10 @@
         att_layer_num: int = 6,
         kernel_size: int = 21,
         sanm_shfit: int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
         tf2torch_tensor_name_prefix_torch: str = "decoder",
         tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
     ):
@@ -885,7 +889,7 @@
                     attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                 ),
                 MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate
+                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
                 ),
                 PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 45163df..9e27d4a 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -146,6 +146,10 @@
         interctc_use_conditioning: bool = False,
         kernel_size : int = 11,
         sanm_shfit : int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
         selfattention_layer_type: str = "sanm",
         tf2torch_tensor_name_prefix_torch: str = "encoder",
         tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
@@ -229,6 +233,10 @@
                 attention_dropout_rate,
                 kernel_size,
                 sanm_shfit,
+                lora_list,
+                lora_rank,
+                lora_alpha,
+                lora_dropout,
             )
 
             encoder_selfattn_layer_args = (
@@ -238,6 +246,10 @@
                 attention_dropout_rate,
                 kernel_size,
                 sanm_shfit,
+                lora_list,
+                lora_rank,
+                lora_alpha,
+                lora_dropout,
             )
         self.encoders0 = repeat(
             1,
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index fcb3ed4..ab59493 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -15,6 +15,7 @@
 
 import torch.nn.functional as F
 from funasr.modules.nets_utils import make_pad_mask
+import funasr.modules.lora.layers as lora
 
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
@@ -321,7 +322,7 @@
 
     """
 
-    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttentionSANM, self).__init__()
         assert n_feat % n_head == 0
@@ -331,8 +332,19 @@
         # self.linear_q = nn.Linear(n_feat, n_feat)
         # self.linear_k = nn.Linear(n_feat, n_feat)
         # self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        if lora_list is not None:
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+            if lora_qkv_list == [False, False, False]:
+                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+            else:
+                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+        else:
+            self.linear_out = nn.Linear(n_feat, n_feat)
+            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
         self.attn = None
         self.dropout = nn.Dropout(p=dropout_rate)
 
@@ -543,18 +555,32 @@
 
     """
 
-    def __init__(self, n_head, n_feat, dropout_rate, encoder_output_size=None):
+    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
         """Construct an MultiHeadedAttention object."""
         super(MultiHeadedAttentionCrossAtt, self).__init__()
         assert n_feat % n_head == 0
         # We assume d_v always equals d_k
         self.d_k = n_feat // n_head
         self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
-        # self.linear_k = nn.Linear(n_feat, n_feat)
-        # self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
-        self.linear_out = nn.Linear(n_feat, n_feat)
+        if lora_list is not None:
+            if "q" in lora_list:
+                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_q = nn.Linear(n_feat, n_feat)
+            lora_kv_list = ["k" in lora_list, "v" in lora_list]
+            if lora_kv_list == [False, False]:
+                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            else:
+                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
+                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+        else:
+            self.linear_q = nn.Linear(n_feat, n_feat)
+            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            self.linear_out = nn.Linear(n_feat, n_feat)
         self.attn = None
         self.dropout = nn.Dropout(p=dropout_rate)
 
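The branching above wraps only the projections named in `lora_list`. Because SANM fuses q/k/v (and the cross-attention fuses k/v) into a single weight matrix, `lora.MergedLinear` is used so that LoRA adapts only the selected slices of the fused output. A small sketch of how the selection mask is derived:

```python
# Sketch: deriving the enable_lora mask passed to lora.MergedLinear above.
lora_list = ['q', 'v']                  # as configured in the README example
lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
print(lora_qkv_list)                    # [True, False, True]
# MergedLinear(in_feat, n_feat * 3, enable_lora=lora_qkv_list) then adapts
# only the q and v thirds of the fused q/k/v projection.
```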
diff --git a/funasr/modules/lora/__init__.py b/funasr/modules/lora/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/modules/lora/__init__.py
diff --git a/funasr/modules/lora/layers.py b/funasr/modules/lora/layers.py
new file mode 100644
index 0000000..76f046c
--- /dev/null
+++ b/funasr/modules/lora/layers.py
@@ -0,0 +1,323 @@
+#  ------------------------------------------------------------------------------------------
+#  Copyright (c) Microsoft Corporation. All rights reserved.
+#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+#  ------------------------------------------------------------------------------------------
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import math
+from typing import Optional, List
+
+class LoRALayer():
+    def __init__(
+        self,
+        r: int,
+        lora_alpha: int,
+        lora_dropout: float,
+        merge_weights: bool,
+    ):
+        self.r = r
+        self.lora_alpha = lora_alpha
+        # Optional dropout
+        if lora_dropout > 0.:
+            self.lora_dropout = nn.Dropout(p=lora_dropout)
+        else:
+            self.lora_dropout = lambda x: x
+        # Mark the weight as unmerged
+        self.merged = False
+        self.merge_weights = merge_weights
+
+
+class Embedding(nn.Embedding, LoRALayer):
+    # LoRA implemented in a dense layer
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        merge_weights: bool = True,
+        **kwargs
+    ):
+        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
+                           merge_weights=merge_weights)
+        # Actual trainable parameters
+        if r > 0:
+            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
+            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
+            self.scaling = self.lora_alpha / self.r
+            # Freezing the pre-trained weight matrix
+            self.weight.requires_grad = False
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.Embedding.reset_parameters(self)
+        if hasattr(self, 'lora_A'):
+            # initialize A to zero and B from a normal distribution
+            nn.init.zeros_(self.lora_A)
+            nn.init.normal_(self.lora_B)
+
+    def train(self, mode: bool = True):
+        nn.Embedding.train(self, mode)
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            if self.r > 0:
+                self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
+            self.merged = False
+
+    def eval(self):
+        nn.Embedding.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            if self.r > 0:
+                self.weight.data += (self.lora_B @ self.lora_A).T * self.scaling
+            self.merged = True
+
+    def forward(self, x: torch.Tensor):
+        if self.r > 0 and not self.merged:
+            result = nn.Embedding.forward(self, x)
+            if self.r > 0:
+                after_A = F.embedding(
+                    x, self.lora_A.T, self.padding_idx, self.max_norm,
+                    self.norm_type, self.scale_grad_by_freq, self.sparse
+                )
+                result += (after_A @ self.lora_B.T) * self.scaling
+            return result
+        else:
+            return nn.Embedding.forward(self, x)
+
+
+class Linear(nn.Linear, LoRALayer):
+    # LoRA implemented in a dense layer
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.,
+        fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        merge_weights: bool = True,
+        **kwargs
+    ):
+        nn.Linear.__init__(self, in_features, out_features, **kwargs)
+        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                           merge_weights=merge_weights)
+
+        self.fan_in_fan_out = fan_in_fan_out
+        # Actual trainable parameters
+        if r > 0:
+            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+            self.scaling = self.lora_alpha / self.r
+            # Freezing the pre-trained weight matrix
+            self.weight.requires_grad = False
+        self.reset_parameters()
+        if fan_in_fan_out:
+            self.weight.data = self.weight.data.T
+
+    def reset_parameters(self):
+        nn.Linear.reset_parameters(self)
+        if hasattr(self, 'lora_A'):
+            # initialize A the same way as the default for nn.Linear and B to zero
+            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B)
+
+    def train(self, mode: bool = True):
+        def T(w):
+            return w.T if self.fan_in_fan_out else w
+        nn.Linear.train(self, mode)
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            if self.r > 0:
+                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+            self.merged = False
+
+    def eval(self):
+        def T(w):
+            return w.T if self.fan_in_fan_out else w
+        nn.Linear.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            if self.r > 0:
+                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+            self.merged = True
+
+    def forward(self, x: torch.Tensor):
+        def T(w):
+            return w.T if self.fan_in_fan_out else w
+        if self.r > 0 and not self.merged:
+            result = F.linear(x, T(self.weight), bias=self.bias)
+            if self.r > 0:
+                result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
+            return result
+        else:
+            return F.linear(x, T(self.weight), bias=self.bias)
+
+
+class MergedLinear(nn.Linear, LoRALayer):
+    # LoRA implemented in a dense layer
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.,
+        enable_lora: List[bool] = [False],
+        fan_in_fan_out: bool = False,
+        merge_weights: bool = True,
+        **kwargs
+    ):
+        nn.Linear.__init__(self, in_features, out_features, **kwargs)
+        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                           merge_weights=merge_weights)
+        assert out_features % len(enable_lora) == 0, \
+            'The length of enable_lora must divide out_features'
+        self.enable_lora = enable_lora
+        self.fan_in_fan_out = fan_in_fan_out
+        # Actual trainable parameters
+        if r > 0 and any(enable_lora):
+            self.lora_A = nn.Parameter(
+                self.weight.new_zeros((r * sum(enable_lora), in_features)))
+            self.lora_B = nn.Parameter(
+                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
+            ) # weights for Conv1D with groups=sum(enable_lora)
+            self.scaling = self.lora_alpha / self.r
+            # Freezing the pre-trained weight matrix
+            self.weight.requires_grad = False
+            # Compute the indices
+            self.lora_ind = self.weight.new_zeros(
+                (out_features, ), dtype=torch.bool
+            ).view(len(enable_lora), -1)
+            self.lora_ind[enable_lora, :] = True
+            self.lora_ind = self.lora_ind.view(-1)
+        self.reset_parameters()
+        if fan_in_fan_out:
+            self.weight.data = self.weight.data.T
+
+    def reset_parameters(self):
+        nn.Linear.reset_parameters(self)
+        if hasattr(self, 'lora_A'):
+            # initialize A the same way as the default for nn.Linear and B to zero
+            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B)
+
+    def zero_pad(self, x):
+        result = x.new_zeros((*x.shape[:-1], self.out_features))
+        result = result.view(-1, self.out_features)
+        result[:, self.lora_ind] = x.reshape(
+            -1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
+        )
+        return result.view((*x.shape[:-1], self.out_features))
+
+    def train(self, mode: bool = True):
+        def T(w):
+            return w.T if self.fan_in_fan_out else w
+        nn.Linear.train(self, mode)
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            if self.r > 0 and any(self.enable_lora):
+                delta_w = F.conv1d(
+                    self.lora_A.data.unsqueeze(0),
+                    self.lora_B.data.unsqueeze(-1),
+                    groups=sum(self.enable_lora)
+                ).squeeze(0)
+                self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
+            self.merged = False
+
+    def eval(self):
+        def T(w):
+            return w.T if self.fan_in_fan_out else w
+        nn.Linear.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            if self.r > 0 and any(self.enable_lora):
+                delta_w = F.conv1d(
+                    self.lora_A.data.unsqueeze(0),
+                    self.lora_B.data.unsqueeze(-1),
+                    groups=sum(self.enable_lora)
+                ).squeeze(0)
+                self.weight.data += self.zero_pad(T(delta_w * self.scaling))
+            self.merged = True
+
+    def forward(self, x: torch.Tensor):
+        def T(w):
+            return w.T if self.fan_in_fan_out else w
+        if self.merged:
+            return F.linear(x, T(self.weight), bias=self.bias)
+        else:
+            result = F.linear(x, T(self.weight), bias=self.bias)
+            if self.r > 0:
+                after_A = F.linear(self.lora_dropout(x), self.lora_A)
+                after_B = F.conv1d(
+                    after_A.transpose(-2, -1),
+                    self.lora_B.unsqueeze(-1),
+                    groups=sum(self.enable_lora)
+                ).transpose(-2, -1)
+                result += self.zero_pad(after_B) * self.scaling
+            return result
+
+
+class Conv2d(nn.Conv2d, LoRALayer):
+    # LoRA implemented in a dense layer
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.,
+        merge_weights: bool = True,
+        **kwargs
+    ):
+        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
+        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+                           merge_weights=merge_weights)
+        assert type(kernel_size) is int
+        # Actual trainable parameters
+        if r > 0:
+            self.lora_A = nn.Parameter(
+                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
+            )
+            self.lora_B = nn.Parameter(
+                self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
+            )
+            self.scaling = self.lora_alpha / self.r
+            # Freezing the pre-trained weight matrix
+            self.weight.requires_grad = False
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.Conv2d.reset_parameters(self)
+        if hasattr(self, 'lora_A'):
+            # initialize A the same way as the default for nn.Linear and B to zero
+            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+            nn.init.zeros_(self.lora_B)
+
+    def train(self, mode: bool = True):
+        nn.Conv2d.train(self, mode)
+        if self.merge_weights and self.merged:
+            # Make sure that the weights are not merged
+            if self.r > 0:
+                self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+            self.merged = False
+
+    def eval(self):
+        nn.Conv2d.eval(self)
+        if self.merge_weights and not self.merged:
+            # Merge the weights and mark it
+            if self.r > 0:
+                self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+            self.merged = True
+
+    def forward(self, x: torch.Tensor):
+        if self.r > 0 and not self.merged:
+            return F.conv2d(
+                x,
+                self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
+                self.bias, self.stride, self.padding, self.dilation, self.groups
+            )
+        return nn.Conv2d.forward(self, x)
+
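Since `merge_weights` defaults to `True`, switching a layer to eval mode folds `B @ A` (scaled by `lora_alpha / r`) into the frozen weight, so inference pays no extra matmul; `train()` un-merges it again. A minimal sketch of that behavior (because `lora_B` is initialized to zero, the outputs match a plain `nn.Linear` before any training step):

```python
import torch
import funasr.modules.lora.layers as lora

layer = lora.Linear(256, 256, r=8, lora_alpha=16, lora_dropout=0.1)
x = torch.randn(4, 256)

y_unmerged = layer(x)    # frozen weight plus the (still zero) LoRA path
layer.eval()             # merges B @ A * (lora_alpha / r) into the weight
y_merged = layer(x)      # a single F.linear on the merged weight
assert torch.allclose(y_unmerged, y_merged, atol=1e-6)
```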
diff --git a/funasr/modules/lora/utils.py b/funasr/modules/lora/utils.py
new file mode 100644
index 0000000..e18bf44
--- /dev/null
+++ b/funasr/modules/lora/utils.py
@@ -0,0 +1,50 @@
+#  ------------------------------------------------------------------------------------------
+#  Copyright (c) Microsoft Corporation. All rights reserved.
+#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+#  ------------------------------------------------------------------------------------------
+import torch
+import torch.nn as nn
+
+from typing import Dict
+
+from .layers import LoRALayer
+
+
+def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
+    for n, p in model.named_parameters():
+        if 'lora_' not in n and 'cif' not in n:
+            p.requires_grad = False
+    if bias == 'none':
+        return
+    elif bias == 'all':
+        for n, p in model.named_parameters():
+            if 'bias' in n:
+                p.requires_grad = True
+    elif bias == 'lora_only':
+        for m in model.modules():
+            if isinstance(m, LoRALayer) and \
+                hasattr(m, 'bias') and \
+                m.bias is not None:
+                    m.bias.requires_grad = True
+    else:
+        raise NotImplementedError
+
+
+def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
+    my_state_dict = model.state_dict()
+    if bias == 'none':
+        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
+    elif bias == 'all':
+        return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
+    elif bias == 'lora_only':
+        to_return = {}
+        for k in my_state_dict:
+            if 'lora_' in k:
+                to_return[k] = my_state_dict[k]
+                bias_name = k.split('lora_')[0]+'bias'
+                if bias_name in my_state_dict:
+                    to_return[bias_name] = my_state_dict[bias_name]
+        return to_return
+    else:
+        raise NotImplementedError
+
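A minimal usage sketch (the toy `nn.Sequential` stands in for a full FunASR model): freeze everything except the adapters, then checkpoint only the LoRA tensors:

```python
import torch
import torch.nn as nn
import funasr.modules.lora.layers as lora
from funasr.modules.lora.utils import mark_only_lora_as_trainable, lora_state_dict

model = nn.Sequential(lora.Linear(64, 64, r=4, lora_alpha=8), nn.Linear(64, 8))
mark_only_lora_as_trainable(model, bias="lora_only")

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['0.bias', '0.lora_A', '0.lora_B'] -- the adapter plus the LoRA layer's bias

# Save only the adapter weights (and their biases, given bias="lora_only").
torch.save(lora_state_dict(model, bias="lora_only"), "lora_weights.pb")
```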
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 91d33c5..f7f13d2 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -71,6 +71,7 @@
 from funasr.utils.types import str_or_none
 from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+from funasr.modules.lora.utils import mark_only_lora_as_trainable
 
 try:
     import wandb
@@ -952,6 +953,18 @@
             default=None,
             help="oss bucket.",
         )
+        group.add_argument(
+            "--enable_lora",
+            type=str2bool,
+            default=False,
+            help="Apply lora for finetuning.",
+        )
+        group.add_argument(
+            "--lora_bias",
+            type=str,
+            default="none",
+            help="lora bias.",
+        )
 
         cls.trainer.add_arguments(parser)
         cls.add_task_arguments(parser)
@@ -1246,6 +1259,8 @@
             dtype=getattr(torch, args.train_dtype),
             device="cuda" if args.ngpu > 0 else "cpu",
         )
+        if args.enable_lora:
+            mark_only_lora_as_trainable(model, args.lora_bias)
         for t in args.freeze_param:
             for k, p in model.named_parameters():
                 if k.startswith(t + ".") or k == t:

--
Gitblit v1.9.1