From 0a4a1d5257dace9561d95b38a9386539908dcd5e Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 23 四月 2024 12:48:52 +0800
Subject: [PATCH] Dev gzf exp (#1645)

---
 funasr/models/conformer_rwkv/decoder.py             |  507 ++++++++++++
 funasr/models/sense_voice/decoder.py                |  273 ++++++
 funasr/datasets/audio_datasets/espnet_samplers.py   |   11 
 funasr/models/conformer_rwkv/model.py               |   19 
 funasr/models/sense_voice/rwkv_v6.py                |  425 ++++++++++
 funasr/models/conformer_rwkv/__init__.py            |    0 
 funasr/datasets/audio_datasets/index_ds.py          |  202 +++-
 funasr/models/sense_voice/model.py                  |  211 +++++
 funasr/models/conformer_rwkv/template.yaml          |  123 +++
 funasr/models/sense_voice/cuda/wkv5_op.cpp          |   22 
 funasr/models/sense_voice/cuda/wkv6_op.cpp          |   23 
 funasr/bin/train.py                                 |    5 
 funasr/models/sense_voice/cuda/wkv5_cuda.cu         |  202 +++++
 examples/aishell/conformer/conf/conformer_rwkv.yaml |  124 +++
 funasr/models/sense_voice/whisper_lib/model.py      |    4 
 funasr/models/sense_voice/cuda/wkv6_cuda.cu         |  243 ++++++
 16 files changed, 2321 insertions(+), 73 deletions(-)

diff --git a/examples/aishell/conformer/conf/conformer_rwkv.yaml b/examples/aishell/conformer/conf/conformer_rwkv.yaml
new file mode 100644
index 0000000..4742838
--- /dev/null
+++ b/examples/aishell/conformer/conf/conformer_rwkv.yaml
@@ -0,0 +1,124 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: Conformer
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+
+# encoder
+encoder: ConformerEncoder
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder architecture type
+    normalize_before: true
+    pos_enc_layer_type: rel_pos
+    selfattention_layer_type: rel_selfattn
+    activation_type: swish
+    macaron_style: true
+    use_cnn_module: true
+    cnn_module_kernel: 15
+
+# decoder
+decoder: TransformerRWKVDecoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+    input_layer: embed
+    rwkv_cfg:
+      n_embd: 256
+      dropout: 0
+      head_size_a: 64
+      ctx_len: 512
+      dim_att: 256 #${model_conf.rwkv_cfg.n_embd}
+      dim_ffn: null
+      head_size_divisor: 4
+      n_layer: 6
+      pre_ffn: 0
+      ln0: false
+      ln1: false
+      init_rwkv: true
+
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
+specaug: SpecAug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  keep_nbest_models: 10
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: EspnetStyleBatchSampler
+    batch_type: length # example or length
+    batch_size: 25000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+    buffer_size: 1024
+    shuffle: True
+    num_workers: 4
+    preprocessor_speech: SpeechPreprocessSpeedPerturb
+    preprocessor_speech_conf:
+      speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 4ab2d8a..ab49c82 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -90,7 +90,8 @@
     # freeze_param
     freeze_param = kwargs.get("freeze_param", None)
     if freeze_param is not None:
-        freeze_param = eval(freeze_param)
+        if "," in freeze_param:
+            freeze_param = eval(freeze_param)
         if isinstance(freeze_param, Sequence):
             freeze_param = (freeze_param,)
         logging.info("freeze_param is not None: %s", freeze_param)
@@ -104,7 +105,7 @@
     if use_ddp:
         model = model.cuda(local_rank)
         model = DDP(model, device_ids=[local_rank],
-                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True))
+                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
     elif use_fsdp:
         # model = FSDP(model).cuda(local_rank)
 
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index bca0753..4bb34f3 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -32,8 +32,9 @@
     def __init__(self, dataset,
                  batch_size,
                  batch_type="token",
-                 num_replicas=None,
                  rank=None,
+                 num_replicas=None,
+                 rank_split=False,
                  shuffle=True,
                  drop_last=False,
                  is_training: bool = True,
@@ -45,6 +46,10 @@
             rank = dist.get_rank()
             num_replicas = dist.get_world_size()
         except:
+            rank = 0
+            num_replicas = 1
+        if rank_split:
+            logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank")
             rank = 0
             num_replicas = 1
         self.rank = rank
@@ -65,8 +70,8 @@
         self.length_scale_source = kwargs.get("length_scale_source", 1.0)
 
 
-        super().__init__(dataset, num_replicas=num_replicas, rank=rank,
-                         shuffle=shuffle, drop_last=drop_last)
+        # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
+        #                  shuffle=shuffle, drop_last=drop_last)
     def __iter__(self):
         if self.shuffle:
             g = torch.Generator()
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 53419e8..3270531 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -9,66 +9,66 @@
 from funasr.register import tables
 
 
-@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
-class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
-    
-    def __init__(self, path):
-        super().__init__()
-        
-        contents = []
-        with open(path, encoding='utf-8') as fin:
-            for line in fin:
-                data = json.loads(line.strip())
-                if "text" in data:  # for sft
-                    self.contents.append(data['text'])
-                if "source" in data:  # for speech lab pretrain
-                    prompt = data["prompt"]
-                    source = data["source"]
-                    target = data["target"]
-                    source_len = data["source_len"]
-                    target_len = data["target_len"]
-
-                    contents.append({"source": source,
-                                     "prompt": prompt,
-                                     "target": target,
-                                     "source_len": source_len,
-                                     "target_len": target_len,
-                                     }
-                                    )
-        
-        self.contents = []
-        total_num = len(contents)
-        try:
-            rank = dist.get_rank()
-            world_size = dist.get_world_size()
-        except:
-            rank = 0
-            world_size = 1
-            logging.warning("distributed is not initialized, only single shard")
-        num_per_rank = total_num // world_size
-        
-        # rank = 0
-        # import ipdb; ipdb.set_trace()
-        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
-    
-        logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
-
-    def __len__(self):
-        return len(self.contents)
-    
-    def __getitem__(self, index):
-        try:
-            data = self.contents[index]
-        except:
-            print(index)
-        return data
-    
-    def get_source_len(self, data_dict):
-        return data_dict["source_len"]
-
-    def get_target_len(self, data_dict):
-        
-        return data_dict["target_len"] if "target_len" in data_dict else 0
+# @tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
+# class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
+#
+#     def __init__(self, path):
+#         super().__init__()
+#
+#         contents = []
+#         with open(path, encoding='utf-8') as fin:
+#             for line in fin:
+#                 data = json.loads(line.strip())
+#                 if "text" in data:  # for sft
+#                     self.contents.append(data['text'])
+#                 if "source" in data:  # for speech lab pretrain
+#                     prompt = data["prompt"]
+#                     source = data["source"]
+#                     target = data["target"]
+#                     source_len = data["source_len"]
+#                     target_len = data["target_len"]
+#
+#                     contents.append({"source": source,
+#                                      "prompt": prompt,
+#                                      "target": target,
+#                                      "source_len": source_len,
+#                                      "target_len": target_len,
+#                                      }
+#                                     )
+#
+#         self.contents = []
+#         total_num = len(contents)
+#         try:
+#             rank = dist.get_rank()
+#             world_size = dist.get_world_size()
+#         except:
+#             rank = 0
+#             world_size = 1
+#             logging.warning("distributed is not initialized, only single shard")
+#         num_per_rank = total_num // world_size
+#
+#         # rank = 0
+#         # import ipdb; ipdb.set_trace()
+#         self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
+#
+#         logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
+#
+#     def __len__(self):
+#         return len(self.contents)
+#
+#     def __getitem__(self, index):
+#         try:
+#             data = self.contents[index]
+#         except:
+#             print(index)
+#         return data
+#
+#     def get_source_len(self, data_dict):
+#         return data_dict["source_len"]
+#
+#     def get_target_len(self, data_dict):
+#
+#         return data_dict["target_len"] if "target_len" in data_dict else 0
 
 @tables.register("index_ds_classes", "IndexDSJsonl")
 @tables.register("index_ds_classes", "IndexDSJsonlRankFull")
@@ -143,3 +143,85 @@
     def get_target_len(self, data_dict):
         
         return data_dict.get("target_len", 0)
+
+
+@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
+class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
+    
+    def __init__(self, path: str, **kwargs):
+        super().__init__()
+        self.max_source_length = kwargs.get("max_source_length", 2048)
+        self.min_source_length = kwargs.get("min_source_length", 0)
+        self.max_target_length = kwargs.get("max_target_length", 2048)
+        self.min_target_length = kwargs.get("min_target_length", 0)
+
+        with open(path, encoding='utf-8') as fin:
+            file_list = fin.readlines()
+
+        total_num = len(file_list)
+        try:
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+        except:
+            rank = 0
+            world_size = 1
+            logging.warning("distributed is not initialized, only single shard")
+        num_per_rank = total_num // world_size
+        if num_per_rank * world_size < total_num:
+            logging.warning(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
+
+        file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
+
+        contents = []
+        for file_json in file_list_rank:
+            
+            with open(file_json.strip(), encoding='utf-8') as fin:
+                for line in fin:
+                    data = json.loads(line.strip())
+                    if "text" in data:  # for sft
+                        contents.append(data['text'])
+                    if "source" in data:  # for speech lab pretrain
+                        prompt = data.get("prompt", "<ASR>")
+                        source = data["source"].replace("/cpfs01", "/cpfs_speech/data")
+                        target = data["target"]
+                        source_len = data.get("source_len", 1)
+                        target_len = data.get("target_len", 0)
+                        
+                        if source_len < self.min_source_length or source_len > self.max_source_length:
+                            continue
+                        if target_len < self.min_target_length or target_len > self.max_target_length:
+                            continue
+                        contents_i = {"source": source,
+                                      "prompt": prompt,
+                                      "target": target,
+                                      "source_len": source_len,
+                                      "target_len": target_len,
+                                      }
+                        text_language = data.get("text_language", None)
+                        if text_language is not None:
+                            contents_i["text_language"] = text_language
+                        audio_language = data.get("audio_language", None)
+                        if audio_language is not None:
+                            contents_i["audio_language"] = audio_language
+                        contents.append(contents_i)
+        
+        self.contents = contents
+        
+        logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}")
+    
+    def __len__(self):
+        return len(self.contents)
+    
+    def __getitem__(self, index):
+        try:
+            data = self.contents[index]
+        except:
+            print(index)
+        return data
+    
+    def get_source_len(self, data_dict):
+        return data_dict.get("source_len", 1)
+    
+    def get_target_len(self, data_dict):
+        
+        return data_dict.get("target_len", 0)
diff --git a/funasr/models/conformer_rwkv/__init__.py b/funasr/models/conformer_rwkv/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/conformer_rwkv/__init__.py
diff --git a/funasr/models/conformer_rwkv/decoder.py b/funasr/models/conformer_rwkv/decoder.py
new file mode 100644
index 0000000..d7f113d
--- /dev/null
+++ b/funasr/models/conformer_rwkv/decoder.py
@@ -0,0 +1,507 @@
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Decoder definition."""
+from typing import Any
+from typing import List
+from typing import Sequence
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
+from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.lightconv import LightweightConvolution
+from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
+from funasr.models.transformer.utils.mask import subsequent_mask
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
+from omegaconf import OmegaConf
+from funasr.register import tables
+
+class DecoderLayer(nn.Module):
+    """Single decoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source-attention (encoder-decoder attention) module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+    """
+
+    def __init__(
+            self,
+            size,
+            self_attn,
+            src_attn,
+            feed_forward,
+            dropout_rate,
+            normalize_before=True,
+            concat_after=False,
+            layer_id=None,
+            args={},
+    ):
+        """Construct an DecoderLayer object."""
+        super(DecoderLayer, self).__init__()
+        self.size = size
+        self.self_attn = self_attn.to(torch.bfloat16)
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.norm3 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+        self.layer_id = layer_id
+        self.ln0 = None
+        if self.layer_id == 0 and not args.get("ln0", True):
+            self.ln0 = LayerNorm(args.n_embd)
+            if args.get("init_rwkv", True):
+                print("init_rwkv")
+                layer_id = 0
+                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+                nn.init.constant_(self.ln0.weight, scale)
+            
+        # init
+        if args.get("init_rwkv", True):
+            print("init_rwkv")
+            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+            nn.init.constant_(self.norm1.weight, scale)
+            nn.init.constant_(self.self_attn.ln2.weight, scale)
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+
+        if self.layer_id == 0 and self.ln0 is not None:
+            tgt = self.ln0(tgt)
+            
+        residual = tgt
+        
+        
+        tgt = self.norm1(tgt)
+
+        if cache is None:
+    
+            x = residual + self.dropout(self.self_attn(tgt, mask=tgt_mask))
+        else:
+            
+            # tgt_q = tgt[:, -1:, :]
+            # residual_q = residual[:, -1:, :]
+            tgt_q_mask = None
+            
+            x = residual + self.dropout(self.self_attn(tgt, mask=tgt_q_mask))
+            x = x[:, -1, :]
+
+
+
+        # x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+        
+
+        residual = x
+        x = self.norm2(x)
+        x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
+        residual = x
+        x = self.norm3(x)
+        x = residual + self.dropout(self.feed_forward(x))
+
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        return x, tgt_mask, memory, memory_mask
+
+
+
+class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
+    """Base class of Transfomer decoder module.
+
+    Args:
+        vocab_size: output dim
+        encoder_output_size: dimension of attention
+        attention_heads: the number of heads of multi head attention
+        linear_units: the number of units of position-wise feed forward
+        num_blocks: the number of decoder blocks
+        dropout_rate: dropout rate
+        self_attention_dropout_rate: dropout rate for attention
+        input_layer: input layer type
+        use_output_layer: whether to use output layer
+        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+        normalize_before: whether to use layer_norm before the first block
+        concat_after: whether to concat attention layer's input and output
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied.
+            i.e. x -> x + att(x)
+    """
+
+    def __init__(
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+    ):
+        super().__init__()
+        attention_dim = encoder_output_size
+
+        if input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        elif input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_output_layer:
+            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.output_layer = None
+
+        # Must set by the inheritance
+        self.decoders = None
+
+    def forward(
+            self,
+            hs_pad: torch.Tensor,
+            hlens: torch.Tensor,
+            ys_in_pad: torch.Tensor,
+            ys_in_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        tgt_mask = tgt_mask & m
+
+        memory = hs_pad
+        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
+            memory.device
+        )
+        # Padding for Longformer
+        if memory_mask.shape[-1] != memory.shape[1]:
+            padlen = memory.shape[1] - memory_mask.shape[-1]
+            memory_mask = torch.nn.functional.pad(
+                memory_mask, (0, padlen), "constant", False
+            )
+
+        x = self.embed(tgt)
+        x, tgt_mask, memory, memory_mask = self.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, olens
+
+    def forward_one_step(
+            self,
+            tgt: torch.Tensor,
+            tgt_mask: torch.Tensor,
+            memory: torch.Tensor,
+            cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Args:
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask,  (batch, maxlen_out)
+                      dtype=torch.uint8 in PyTorch 1.2-
+                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+            memory: encoded memory, float32  (batch, maxlen_in, feat)
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+            y.shape` is (batch, maxlen_out, token)
+        """
+        x = self.embed(tgt)
+        if cache is None:
+            cache = [None] * len(self.decoders)
+        new_cache = []
+        for c, decoder in zip(cache, self.decoders):
+            x, tgt_mask, memory, memory_mask = decoder(
+                x, tgt_mask, memory, None, cache=c
+            )
+            new_cache.append(x)
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.output_layer is not None:
+            y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+        return y, new_cache
+
+    def score(self, ys, state, x):
+        """Score."""
+        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
+        logp, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
+        )
+        return logp.squeeze(0), state
+
+    def batch_score(
+            self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+    ) -> Tuple[torch.Tensor, List[Any]]:
+        """Score new token batch.
+
+        Args:
+            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (torch.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[torch.Tensor, List[Any]]: Tuple of
+                batchfied scores for next token with shape of `(n_batch, n_vocab)`
+                and next state list for ys.
+
+        """
+        # merge states
+        n_batch = len(ys)
+        n_layers = len(self.decoders)
+        if states[0] is None:
+            batch_state = None
+        else:
+            # transpose state of [batch, layer] into [layer, batch]
+            batch_state = [
+                torch.stack([states[b][i] for b in range(n_batch)])
+                for i in range(n_layers)
+            ]
+
+        # batch decoding
+        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
+        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
+
+        # transpose state of [layer, batch] into [batch, layer]
+        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+        return logp, state_list
+
+@tables.register("decoder_classes", "TransformerRWKVDecoder")
+class TransformerRWKVDecoder(BaseTransformerDecoder):
+    def __init__(
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            self_attention_dropout_rate: float = 0.0,
+            src_attention_dropout_rate: float = 0.0,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            **kwargs,
+    ):
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+        from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+        rwkv_cfg = kwargs.get("rwkv_cfg", {})
+        args = OmegaConf.create(rwkv_cfg)
+        # self.attn = RWKVLayer(args=args, layer_id=layer_id)
+        attention_dim = encoder_output_size
+        self.decoders = repeat(
+            num_blocks,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                RWKVLayer(args=args, layer_id=lnum),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                lnum,
+                args=args,
+            ),
+        )
+        
+        # init
+        if args.get("init_rwkv", True):
+            print("init_rwkv")
+            nn.init.uniform_(self.embed[0].weight, a=-1e-4, b=1e-4)
+
+    def forward(
+            self,
+            hs_pad: torch.Tensor,
+            hlens: torch.Tensor,
+            ys_in_pad: torch.Tensor,
+            ys_in_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        tgt_mask = tgt_mask & m
+
+        memory = hs_pad
+        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
+            memory.device
+        )
+        # Padding for Longformer
+        if memory_mask.shape[-1] != memory.shape[1]:
+            padlen = memory.shape[1] - memory_mask.shape[-1]
+            memory_mask = torch.nn.functional.pad(
+                memory_mask, (0, padlen), "constant", False
+            )
+
+        x = self.embed(tgt)
+        x, tgt_mask, memory, memory_mask = self.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, olens
+
+    def forward_one_step(
+            self,
+            tgt: torch.Tensor,
+            tgt_mask: torch.Tensor,
+            memory: torch.Tensor,
+            cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Args:
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask,  (batch, maxlen_out)
+                      dtype=torch.uint8 in PyTorch 1.2-
+                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+            memory: encoded memory, float32  (batch, maxlen_in, feat)
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+            y.shape` is (batch, maxlen_out, token)
+        """
+        x = self.embed(tgt)
+        if cache is None:
+            cache = [None] * len(self.decoders)
+        new_cache = []
+        for c, decoder in zip(cache, self.decoders):
+            x, tgt_mask, memory, memory_mask = decoder(
+                x, tgt_mask, memory, None, cache=c
+            )
+            new_cache.append(x)
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.output_layer is not None:
+            y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+        return y, new_cache
\ No newline at end of file
diff --git a/funasr/models/conformer_rwkv/model.py b/funasr/models/conformer_rwkv/model.py
new file mode 100644
index 0000000..171014b
--- /dev/null
+++ b/funasr/models/conformer_rwkv/model.py
@@ -0,0 +1,19 @@
+import logging
+
+import torch
+
+from funasr.models.transformer.model import Transformer
+from funasr.register import tables
+
+@tables.register("model_classes", "Conformer")
+class Conformer(Transformer):
+    """CTC-attention hybrid Encoder-Decoder model; all behavior is inherited from funasr.models.transformer.model.Transformer."""
+
+    
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+
+        super().__init__(*args, **kwargs)  # delegate full construction to the Transformer base class
diff --git a/funasr/models/conformer_rwkv/template.yaml b/funasr/models/conformer_rwkv/template.yaml
new file mode 100644
index 0000000..cd71105
--- /dev/null
+++ b/funasr/models/conformer_rwkv/template.yaml
@@ -0,0 +1,123 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: Conformer
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+
+# encoder
+encoder: ConformerEncoder
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder architecture type
+    normalize_before: true
+    pos_enc_layer_type: rel_pos
+    selfattention_layer_type: rel_selfattn
+    activation_type: swish
+    macaron_style: true
+    use_cnn_module: true
+    cnn_module_kernel: 15
+
+# decoder
+decoder: TransformerRWKVDecoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+    input_layer: embed
+    rwkv_cfg:
+      n_embd: 256
+      dropout: 0
+      head_size_a: 64
+      ctx_len: 512
+      dim_att: 256 #${model_conf.rwkv_cfg.n_embd}
+      dim_ffn: null
+      head_size_divisor: 4
+      n_layer: 6
+      pre_ffn: 0
+      ln0: false
+      ln1: false
+
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
+specaug: SpecAug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  keep_nbest_models: 10
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: EspnetStyleBatchSampler
+    batch_type: length # example or length
+    batch_size: 25000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
+    max_token_length: 2048 # filter out samples whose source_token_len+target_token_len exceeds max_token_length
+    buffer_size: 1024
+    shuffle: True
+    num_workers: 4
+    preprocessor_speech: SpeechPreprocessSpeedPerturb
+    preprocessor_speech_conf:
+      speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
diff --git a/funasr/models/sense_voice/cuda/wkv5_cuda.cu b/funasr/models/sense_voice/cuda/wkv5_cuda.cu
new file mode 100644
index 0000000..3e6b859
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv5_cuda.cu
@@ -0,0 +1,202 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H,
+                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
+                               F *__restrict__ const _y)
+{
+    const int b = blockIdx.x / H;  // batch index
+    const int h = blockIdx.x % H;  // head index
+    const int i = threadIdx.x;  // channel index within this head (0.._N_-1)
+    _w += h*_N_;  // advance to this head's decay vector
+    _u += h*_N_;  // advance to this head's bonus vector
+
+    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];  // per-head vectors staged in shared memory
+    float state[_N_] = {0};  // per-thread recurrent state, one entry per channel j
+
+    __syncthreads();
+    w[i] = _w[i];
+    u[i] = float(_u[i]);
+    __syncthreads();
+
+    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)  // walk the T timesteps of (b, h)
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float v = float(_v[t]);
+        float y = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j+=4)  // float4-vectorized; cuda_forward asserts _N_ % 4 == 0
+        {
+            const float4& r_ = (float4&)(r[j]);
+            const float4& k_ = (float4&)(k[j]);
+            const float4& w_ = (float4&)(w[j]);
+            const float4& u_ = (float4&)(u[j]);
+            float4& s = (float4&)(state[j]);
+            float4 x;
+
+            x.x = k_.x * v;
+            x.y = k_.y * v;
+            x.z = k_.z * v;
+            x.w = k_.w * v;
+
+            y += r_.x * (u_.x * x.x + s.x);  // output: r * (bonus * current kv + state)
+            y += r_.y * (u_.y * x.y + s.y);
+            y += r_.z * (u_.z * x.z + s.z);
+            y += r_.w * (u_.w * x.w + s.w);
+
+            s.x = s.x * w_.x + x.x;  // decay old state, add current kv
+            s.y = s.y * w_.y + x.y;
+            s.z = s.z * w_.z + x.z;
+            s.w = s.w * w_.w + x.w;
+        }
+        _y[t] = F(y);  // cast the float accumulator back to F (bf16)
+    }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C, const int H,
+    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
+    F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
+{
+    const int b = blockIdx.x / H;  // batch index
+    const int h = blockIdx.x % H;  // head index
+    const int i = threadIdx.x;  // channel index within this head
+    _w += h*_N_;
+    _u += h*_N_;
+    __w += h*_N_;
+
+    __shared__ float w_[_N_], u_[_N_];
+    __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
+    __syncthreads();
+    w_[i] = _w[i];
+    u_[i] = float(_u[i]);
+    __syncthreads();
+
+    const float w = w_[i];  // this channel's decay
+    const float ww = __w[i];  // NOTE(review): extra per-channel factor applied to gw below — presumably the chain-rule term of the decay parameterization; confirm against the Python caller
+    const float u = u_[i];  // this channel's bonus
+
+    float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
+
+    float gw = 0, gu = 0;
+    const int t000 = b*T*C + h*_N_ + i;  // offset of the first timestep
+    const int t111 = (b+1)*T*C + h*_N_ + i;  // one past the last timestep
+    const int t222 = t111 - 2*C;  // last timestep minus two steps
+
+    for (int t = t000; t < t111; t += C)  // pass 1 (forward in time): gr and gu
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        const float k = float(_k[t]);
+        float gr = 0, gu_ = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = state[j];
+            float x = k * v[j];
+
+            gr += (u * x + s) * gy[j];
+            gu_ += x * gy[j];
+            s = s * w + x;
+        }
+        _gr[t] = F(gr);
+        gu += float(_r[t]) * gu_;
+    }
+    _gu[b*C + h*_N_ + i] = F(gu);  // one gu per (batch, channel)
+    
+    for (int t = t000; t < t222; t += C)  // pass 2 (forward in time, offset by 2 steps): gw
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t + 2*C]);
+        __syncthreads();
+
+        const float k = float(_k[t]);
+        float gw_ = 0;
+        
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = saaaa[j];
+            float& s2 = sbbbb[j];
+            float x = k * v[j];
+            
+            float tmp = w * (x + s);
+            s = tmp;
+            s2 = tmp + w * s2;  // second-order accumulator for d/dw of the decayed state
+            gw_ += s2 * gy[j];
+        }
+        gw += float(_r[t + 2*C]) * gw_;
+    }    
+    _gw[b*C + h*_N_ + i] = F(ww * gw);  // scale by the ww chain-rule factor (see NOTE above)
+
+    for (int t = t111 - C; t >= t000; t -= C)  // pass 3 (backward in time): gk
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        const float rr = float(_r[t]);
+        float gk = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = scccc[j];
+            float x = rr * gy[j];
+            
+            gk += (u * x + s) * v[j];
+            s = x + s * w;
+        }
+        _gk[t] = F(gk);
+    }
+
+    for (int t = t111 - C; t >= t000; t -= C)  // pass 4 (backward in time): gv
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float gyy = float(_gy[t]);
+        float gv = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = sdddd[j];
+            float x = gyy * r[j];
+            
+            gv += (u_[j] * x + s) * k[j];
+            s = x + s * w_[j];
+        }
+        _gv[t] = F(gv);
+    }
+}
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
+{
+    assert(H*_N_ == C);  // C channels = H heads of _N_ channels each
+    assert(_N_%4 == 0);  // required by the float4-vectorized inner loop
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);  // one block per (batch, head), one thread per channel
+}
+
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
+{
+    assert(H*_N_ == C);  // C channels = H heads of _N_ channels each
+    assert(_N_%4 == 0);  // same layout contract as the forward launch
+    kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);  // ww: extra per-channel factor folded into gw (see kernel_backward)
+}
diff --git a/funasr/models/sense_voice/cuda/wkv5_op.cpp b/funasr/models/sense_voice/cuda/wkv5_op.cpp
new file mode 100644
index 0000000..4c9ece1
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv5_op.cpp
@@ -0,0 +1,22 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);  // implemented in wkv5_cuda.cu
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);  // implemented in wkv5_cuda.cu
+
+void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {  // unpack data pointers and launch the forward kernel; y is written in place
+    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {  // gradient tensors gr..gu are written in place
+    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
+}
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {  // pybind11 bindings for torch.utils.cpp_extension loading
+    m.def("forward", &forward, "wkv5 forward");
+    m.def("backward", &backward, "wkv5 backward");
+}
+
+TORCH_LIBRARY(wkv5, m) {  // also register under the torch.ops.wkv5 namespace
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
diff --git a/funasr/models/sense_voice/cuda/wkv6_cuda.cu b/funasr/models/sense_voice/cuda/wkv6_cuda.cu
new file mode 100644
index 0000000..d98f57f
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv6_cuda.cu
@@ -0,0 +1,243 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+// typedef float bf16;
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H,
+                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
+                               F *__restrict__ const _y)
+{
+    const int b = blockIdx.x / H;  // batch index
+    const int h = blockIdx.x % H;  // head index
+    const int i = threadIdx.x;  // channel index within this head
+    _u += h*_N_;  // advance to this head's bonus vector (decay _w is per-timestep, indexed by t below)
+
+    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];  // per-head vectors staged in shared memory
+    float state[_N_] = {0};  // per-thread recurrent state, one entry per channel j
+
+    __syncthreads();
+    u[i] = float(_u[i]);
+    __syncthreads();
+
+    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)  // walk the T timesteps of (b, h)
+    {
+        __syncthreads();
+        w[i] = exp(_w[t]);  // wkv6: data-dependent decay, refreshed every timestep
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float v = float(_v[t]);
+        float y = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j+=4)  // float4-vectorized; cuda_forward asserts _N_ % 4 == 0
+        {
+            const float4& r_ = (float4&)(r[j]);
+            const float4& k_ = (float4&)(k[j]);
+            const float4& w_ = (float4&)(w[j]);
+            const float4& u_ = (float4&)(u[j]);
+            float4& s = (float4&)(state[j]);
+            float4 x;
+
+            x.x = k_.x * v;
+            x.y = k_.y * v;
+            x.z = k_.z * v;
+            x.w = k_.w * v;
+
+            y += r_.x * (u_.x * x.x + s.x);  // output: r * (bonus * current kv + state)
+            y += r_.y * (u_.y * x.y + s.y);
+            y += r_.z * (u_.z * x.z + s.z);
+            y += r_.w * (u_.w * x.w + s.w);
+
+            s.x = s.x * w_.x + x.x;  // decay old state, add current kv
+            s.y = s.y * w_.y + x.y;
+            s.z = s.z * w_.z + x.z;
+            s.w = s.w * w_.w + x.w;
+        }
+        _y[t] = F(y);  // cast the float accumulator back to F
+    }
+}
+
+template <typename F>
+__global__ void kernel_backward_111(const int B, const int T, const int C, const int H,
+    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
+    F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu)
+{
+    const int b = blockIdx.x / H;  // batch index
+    const int h = blockIdx.x % H;  // head index
+    const int i = threadIdx.x;  // channel index within this head
+    _u += h*_N_;
+
+    __shared__ float u_[_N_];
+    __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
+    __syncthreads();
+    u_[i] = float(_u[i]);
+    __syncthreads();
+
+    const float u = u_[i];  // this channel's bonus
+
+    float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
+
+    const int t_0 = b*T*C + h*_N_ + i;  // offset of the first timestep
+    const int t_T_1 = t_0 + (T-1)*C;  // offset of the last timestep
+    const int t_T = t_0 + T*C;  // one past the last timestep
+
+    float gu = 0;
+    for (int t = t_0; t < t_T; t += C)  // pass 1 (forward in time): gr and gu
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        const float k = float(_k[t]);
+        const float w = exp(_w[t]);  // per-timestep decay, as in the forward kernel
+        float gr = 0, gu_ = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = state[j];
+            float x = k * v[j];
+
+            gr += (u * x + s) * gy[j];
+            gu_ += x * gy[j];
+            s = s * w + x;
+        }
+        _gr[t] = F(gr);
+        gu += float(_r[t]) * gu_;
+    }
+    _gu[b*C + h*_N_ + i] = F(gu);  // one gu per (batch, channel)
+
+    for (int t = t_T_1; t >= t_0; t -= C)  // pass 2 (backward in time): gk
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        const float rr = float(_r[t]);
+        const float w = exp(_w[t]);
+        float gk = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = scccc[j];
+            float x = rr * gy[j];
+            
+            gk += (u * x + s) * v[j];
+            s = x + s * w;
+        }
+        _gk[t] = F(gk);
+    }
+
+    for (int t = t_T_1; t >= t_0; t -= C)  // pass 3 (backward in time): gv
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        w_[i] = exp(_w[t]);
+        __syncthreads();
+
+        const float gyy = float(_gy[t]);
+        float gv = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = sdddd[j];
+            float x = gyy * r[j];
+            
+            gv += (u_[j] * x + s) * k[j];
+            s = x + s * w_[j];
+        }
+        _gv[t] = F(gv);
+    }
+}
+
+template <typename F>
+__global__ void kernel_backward_222(const int B, const int T, const int C, const int H,
+    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
+    F *__restrict__ const _gw)
+{
+    const int b = blockIdx.x / H;  // batch index
+    const int h = blockIdx.x % H;  // head index
+    const int i = threadIdx.x;  // channel index within this head
+
+    __shared__ float v[_N_], gy[_N_];
+    float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};  // sbbbb: one partial result per timestep (hence size _T_-2)
+
+    const int t_0 = b*T*C + h*_N_ + i;  // first timestep
+    const int t_1 = t_0 + C;  // second timestep
+    const int t_2 = t_0 + 2*C;  // third timestep
+    const int t_T_1 = t_0 + (T-1)*C;  // last timestep
+
+    for (int t = t_T_1; t > t_1; t -= C)  // pass 1 (backward in time): per-step partials into sbbbb
+    {
+        __syncthreads();
+        gy[i] = float(_gy[t]);
+        v[i] = float(_v[t-2*C]);  // values lag the gradient by two steps
+        __syncthreads();
+
+        const float r = float(_r[t]);
+        const float w = exp(_w[t-C]);
+        float sum = 0.0f;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = saaaa[j];
+            float x = r * gy[j];
+            s = (s + x) * w;
+            sum += s * v[j];
+        }
+        sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);  // index by timestep relative to t_2
+    }
+
+    float sss = sbbbb[0];  // running accumulator over per-step contributions
+    _gw[t_0] = 0;  // first timestep gets zero decay gradient
+    _gw[t_1] = F(sss * _w[t_1]);
+
+    for (int t = t_2; t < t_T_1; t += C)  // pass 2 (forward in time): fold partials into gw
+    {
+        __syncthreads();
+        gy[i] = float(_gy[t]);
+        v[i] = float(_v[t-2*C]);
+        __syncthreads();
+
+        const float w = exp(_w[t-C]);
+        const float k = float(_k[t-2*C]);
+        float sum = 0.0f;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = scccc[j];
+            float x = k * v[j];
+            s = (s + x) * w;
+            sum += s * gy[j];
+        }
+        sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
+        _gw[t] = F(sss * _w[t]);  // scale by the raw (pre-exp) decay value at t
+    }
+    _gw[t_T_1] = 0;  // last timestep gets zero decay gradient
+}
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
+{
+    assert(H*_N_ == C);  // C channels = H heads of _N_ channels each
+    assert(_N_%4 == 0);  // required by the float4-vectorized inner loop
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);  // one block per (batch, head), one thread per channel
+}
+
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
+{
+    assert(H*_N_ == C);  // C channels = H heads of _N_ channels each
+    assert(_N_%4 == 0);  // same layout contract as the forward launch
+    kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu);  // gr, gk, gv, gu
+    kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gw);  // gw (decay gradient) computed separately
+}
diff --git a/funasr/models/sense_voice/cuda/wkv6_op.cpp b/funasr/models/sense_voice/cuda/wkv6_op.cpp
new file mode 100644
index 0000000..22da520
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv6_op.cpp
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+ typedef at::BFloat16 bf16;
+//typedef float bf16;
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);  // implemented in wkv6_cuda.cu
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);  // implemented in wkv6_cuda.cu
+
+void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {  // unpack data pointers and launch the forward kernel; y is written in place
+    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {  // gradient tensors gr..gu are written in place; no ww factor in wkv6
+    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
+}
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {  // pybind11 bindings for torch.utils.cpp_extension loading
+    m.def("forward", &forward, "wkv6 forward");
+    m.def("backward", &backward, "wkv6 backward");
+}
+
+TORCH_LIBRARY(wkv6, m) {  // also register under the torch.ops.wkv6 namespace
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index bae2832..9087ea1 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -5,6 +5,32 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.register import tables
+import base64
+import gzip
+from dataclasses import dataclass
+from typing import Dict, Iterable, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x: Tensor) -> Tensor:
+        return super().forward(x.float()).type(x.dtype)  # compute LayerNorm in float32, then cast back to the input dtype
+
+
+class Linear(nn.Linear):
+    def forward(self, x: Tensor) -> Tensor:
+        return F.linear(
+            x,
+            self.weight.to(x.dtype),  # cast parameters to the input's dtype on the fly
+            None if self.bias is None else self.bias.to(x.dtype),
+        )
+
+
 
 def sense_voice_decode_forward(
 	self,
@@ -38,10 +64,10 @@
 	
 	offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
 	tgt, memory = x, xa
-	tgt[tgt==-1] = 0
+	tgt[tgt == -1] = 0
 	tgt = (
 		self.token_embedding(tgt)
-		+ self.positional_embedding[offset : offset + tgt.size(1)]
+		+ self.positional_embedding[offset: offset + tgt.size(1)]
 	)
 	# tgt = self.dropout(tgt)
 	
@@ -54,13 +80,248 @@
 	
 	for layer, block in enumerate(self.blocks):
 		x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
-
-
+	
 	x = self.ln(x)
 	x = (
 		x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
 	).float()
 	
-	
 	return x
-	
\ No newline at end of file
+
+
+class MultiHeadAttention(nn.Module):
+	def __init__(self, n_state: int, n_head: int):
+		super().__init__()
+		self.n_head = n_head
+		self.query = Linear(n_state, n_state)
+		self.key = Linear(n_state, n_state, bias=False)  # key projection has no bias
+		self.value = Linear(n_state, n_state)
+		self.out = Linear(n_state, n_state)
+	
+	def forward(
+		self,
+		x: Tensor,
+		xa: Optional[Tensor] = None,
+		mask: Optional[Tensor] = None,
+		kv_cache: Optional[dict] = None,
+		**kwargs,
+	):
+		is_pad_mask = kwargs.get("is_pad_mask", False)
+		
+		q = self.query(x)
+		
+		if kv_cache is None or xa is None or self.key not in kv_cache:
+			# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+			# otherwise, perform key/value projections for self- or cross-attention as usual.
+			k = self.key(x if xa is None else xa)
+			v = self.value(x if xa is None else xa)
+		else:
+			# for cross-attention, calculate keys and values once and reuse in subsequent calls.
+			k = kv_cache[self.key]
+			v = kv_cache[self.value]
+		
+		wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
+		return self.out(wv), qk
+	
+	def qkv_attention(
+		self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
+	):
+		is_pad_mask = kwargs.get("is_pad_mask", False)
+		n_batch, n_ctx, n_state = q.shape
+		scale = (n_state // self.n_head) ** -0.25  # applied to both q and k, so qk is scaled by 1/sqrt(d_head) overall
+		q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale  # (batch, head, time1, d_head)
+		k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale  # (batch, head, d_head, time2)
+		v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)  # (batch, head, time2, d_head)
+		
+		qk = q @ k
+		if mask is not None:
+			if not is_pad_mask:
+				qk = qk + mask[:n_ctx, :n_ctx]  # additive causal mask
+			else:
+				mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+				min_value = float(
+					np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
+				)  # most negative finite value representable in qk's dtype
+				qk = qk.masked_fill(mask, min_value)
+		
+		qk = qk.float()
+		
+		w = F.softmax(qk, dim=-1).to(q.dtype)
+		if mask is not None and is_pad_mask:
+			w = w.masked_fill(mask, 0.0)  # zero attention weights on padded positions after softmax
+		return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()  # merge heads back to (batch, time1, n_state)
+
+
+from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+from omegaconf import OmegaConf
+
+
+class ResidualAttentionBlockRWKV(nn.Module):
+	def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, layer_id=0, **kwargs):
+		super().__init__()
+		
+		rwkv_cfg = kwargs.get("rwkv_cfg", {})
+		args = OmegaConf.create(rwkv_cfg)
+		self.attn = RWKVLayer(args=args, layer_id=layer_id)  # RWKV time-mixing replaces self-attention
+		if args.get("datatype", "bf16") == "bf16":
+			self.attn.to(torch.bfloat16)
+			
+		self.ln0 = None
+		if layer_id == 0 and not args.get("ln0", True):  # extra input LayerNorm on the first layer only
+			self.ln0 = LayerNorm(args.n_embd)
+			if args.get("init_rwkv", True):
+				print("init_rwkv")
+				layer_id = 0  # NOTE(review): redundant — this branch only runs when layer_id == 0
+				scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7  # depth-dependent LN weight init
+				nn.init.constant_(self.ln0.weight, scale)
+		self.layer_id = layer_id
+		self.args = args
+		
+		self.ln1 = None
+		if not args.get("ln1", True):  # optional pre-attention LayerNorm
+			self.ln1 = LayerNorm(args.n_embd)
+			# depth-dependent LayerNorm weight init (enabled by init_rwkv)
+			if args.get("init_rwkv", True):
+				print("init_rwkv")
+				scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+				nn.init.constant_(self.ln1.weight, scale)
+		
+		self.cross_attn = (
+			MultiHeadAttention(n_state, n_head) if cross_attention else None
+		)
+		self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+		
+		n_mlp = n_state * 4
+		self.mlp = nn.Sequential(
+			Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+		)
+		self.mlp_ln = LayerNorm(n_state)
+	
+	def forward(
+		self,
+		x: Tensor,
+		xa: Optional[Tensor] = None,
+		mask: Optional[Tensor] = None,
+		kv_cache: Optional[dict] = None,
+		**kwargs,
+	):
+		is_pad_mask = kwargs.get("is_pad_mask", False)
+		is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
+		
+		if self.layer_id == 0 and self.ln0 is not None:
+			x = self.ln0(x)
+		
+		if self.ln1 is None:
+			x = x + self.attn(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+		else:
+			x = x + self.attn(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+	
+		if self.cross_attn:
+			x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]  # NOTE(review): no mask is forwarded here — a memory_mask passed via kwargs is silently dropped; confirm intended
+		x = x + self.mlp(self.mlp_ln(x))
+		
+		return x
+
+@tables.register("decoder_classes", "SenseVoiceDecoder")
+class SenseVoiceDecoder(nn.Module):
+	def __init__(
+		self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs
+	):
+		super().__init__()
+		
+		self.token_embedding = nn.Embedding(n_vocab, n_state)
+		self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))  # learned absolute positions; torch.empty is uninitialized — presumably overwritten by a checkpoint load, confirm
+
+
+		self.blocks = nn.ModuleList(
+			[
+				ResidualAttentionBlockRWKV(n_state, n_head, cross_attention=True, layer_id=i, **kwargs)
+				for i in range(n_layer)
+			]
+		)
+		self.ln = LayerNorm(n_state)
+		
+		mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)  # additive causal mask: -inf above the diagonal
+		self.register_buffer("mask", mask, persistent=False)
+		
+		self.use_padmask = kwargs.get("use_padmask", True)
+	# def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+	# 	"""
+	# 	x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+	# 		the text tokens
+	# 	xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+	# 		the encoded audio features to be attended on
+	# 	"""
+	# 	offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+	# 	x = (
+	# 		self.token_embedding(x)
+	# 		+ self.positional_embedding[offset: offset + x.shape[-1]]
+	# 	)
+	# 	x = x.to(xa.dtype)
+	#
+	# 	for block in self.blocks:
+	# 		x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+	#
+	# 	x = self.ln(x)
+	# 	logits = (
+	# 		x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+	# 	).float()
+	#
+	# 	return logits
+
+	
+	def forward(
+		self,
+		x: torch.Tensor,
+		xa: torch.Tensor,
+		kv_cache: Optional[dict] = None,
+		**kwargs,
+	):
+		"""Forward decoder.
+	
+		Args:
+			x: input token ids, int64 (batch, maxlen_out);
+				entries equal to -1 are treated as padding and remapped to id 0
+			xa: encoded audio memory (batch, maxlen_in, feat)
+			kv_cache: optional decoding cache; when non-empty, the cached
+				length supplies the positional-embedding offset
+			**kwargs:
+				hlens: (batch,) memory lengths used to build the
+					cross-attention pad mask when self.use_padmask is set
+		Returns:
+	
+			x: token scores before softmax, float32
+				(batch, maxlen_out, n_vocab), obtained by projecting onto
+				the tied token_embedding weights
+		"""
+		# resolve optional decoding kwargs
+		use_padmask = self.use_padmask
+		hlens = kwargs.get("hlens", None)
+		
+		ys_in_lens = kwargs.get("ys_in_lens", None)  # NOTE(review): read but never used below
+		
+		offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0  # number of positions already decoded
+		tgt, memory = x, xa
+		tgt[tgt == -1] = 0  # NOTE: in-place on the caller's tensor — padding ids remapped to 0
+		tgt = (
+			self.token_embedding(tgt)
+			+ self.positional_embedding[offset: offset + tgt.size(1)]
+		)
+		# tgt = self.dropout(tgt)
+		
+		x = tgt.to(memory.dtype)
+		
+		if use_padmask and hlens is not None:
+			memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)  # True at valid memory positions
+		else:
+			memory_mask = None
+		
+		for layer, block in enumerate(self.blocks):  # `layer` index is unused
+			x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
+		
+		x = self.ln(x)
+		x = (
+			x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+		).float()  # project onto tied embedding weights -> vocab scores
+		
+		return x
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index b5272a1..fa1c047 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -226,4 +226,213 @@
         results.append(result_i)
     
         return results, meta_data
-    
\ No newline at end of file
+
+
+@tables.register("model_classes", "SenseVoiceRWKV")
+class SenseVoiceRWKV(nn.Module):
+    """Whisper-based SenseVoice model whose text decoder is replaced by a
+    decoder looked up from the registry (default "SenseVoiceDecoder").
+
+    The wrapped ``whisper.model.Whisper`` instance keeps its encoder (with a
+    patched forward supporting padding masks) while the decoder is swapped.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        
+        dims = kwargs.get("dims", {})
+        dims = whisper.model.ModelDimensions(**dims)
+        model = whisper.model.Whisper(dims=dims)
+        
+        # encoder: patch in a forward that downsamples and honors pad masks.
+        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
+        model.encoder.use_padmask = kwargs.get("use_padmask", True)
+        from .encoder import sense_voice_encode_forward
+        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)
+        
+        # decoder: drop Whisper's decoder and build the registered replacement.
+        del model.decoder
+        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
+        decoder_class = tables.decoder_classes.get(decoder)
+        # NOTE(review): kwargs.get("decoder_conf") returns None when the key
+        # is absent, and ** on None raises TypeError — "decoder_conf" is
+        # effectively mandatory; consider .get("decoder_conf", {}).
+        decoder = decoder_class(n_vocab=dims.n_vocab,
+                                n_ctx=dims.n_text_ctx,
+                                n_state=dims.n_text_state,
+                                n_head=dims.n_text_head,
+                                n_layer=dims.n_text_layer,
+                                **kwargs.get("decoder_conf"))
+        model.decoder = decoder
+        
+        self.model = model
+        
+        self.encoder_output_size = self.model.dims.n_audio_state
+        
+        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
+        self.ignore_id = kwargs.get("ignore_id", -1)
+        self.vocab_size = kwargs.get("vocab_size", -1)
+        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
+        self.criterion_att = LabelSmoothingLoss(
+            size=self.vocab_size,
+            padding_idx=self.ignore_id,
+            smoothing=kwargs.get("lsm_weight", 0.0),
+            normalize_length=self.length_normalized_loss,
+        )
+        
+        specaug = kwargs.get("specaug", None)
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**kwargs.get("specaug_conf", {}))
+        self.specaug = specaug
+    
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        """Frontend + Encoder + Decoder + Calc loss.
+
+        Args:
+            speech: (Batch, Length, ...) input features
+            speech_lengths: (Batch,) — 2-D inputs are reduced to column 0
+            text: (Batch, Length) target token ids
+            text_lengths: (Batch,) — 2-D inputs are reduced to column 0
+            kwargs:
+                target_mask: mask selecting which target positions count
+                    toward the loss; forwarded to _calc_att_loss
+        Returns:
+            (loss, stats, weight) gathered for DataParallel.
+        """
+        target_mask = kwargs.get("target_mask", None)
+        
+        # import pdb;
+        # pdb.set_trace()
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+        
+        batch_size = speech.shape[0]
+        
+        if self.activation_checkpoint:
+            from torch.utils.checkpoint import checkpoint
+            encoder_out, encoder_out_lens = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False)
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        
+        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
+        )
+        loss = loss_att
+        stats = {}
+        stats["acc"] = acc_att
+        stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
+        
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        # When length-normalized, weight by the token count instead of the
+        # sentence count so the gathered loss averages correctly.
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+    
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ):
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+            # Data augmentation (training only), forced to full precision.
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+        
+        # Forward encoder; permute to (batch, n_mels, frames) as the Whisper
+        # encoder expects channel-first features.
+        encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)
+        
+        return encoder_out, encoder_out_lens
+    
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        **kwargs,
+    ):
+        """Teacher-forced decoder pass + label-smoothed CE on shifted targets."""
+        # NOTE(review): target_mask defaults to None but is used in tensor
+        # arithmetic below — callers must always pass it or this raises.
+        target_mask = kwargs.get("target_mask", None)
+        stats = {}
+        
+        # 1. Forward decoder
+        decoder_out = self.model.decoder(
+            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+        )
+        
+        # 2. Compute attention loss: unmasked positions are replaced by -1
+        # (ignore_id) so label smoothing skips them; targets shift left by one.
+        mask = torch.ones_like(ys_pad) * (-1)
+        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
+        ys_pad_mask[ys_pad_mask == 0] = -1
+        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
+        
+        with torch.no_grad():
+            preds = torch.argmax(decoder_out, -1)
+            acc_att = compute_accuracy(preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id)
+        
+        return loss_att, acc_att, None, None
+    
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  **kwargs,
+                  ):
+        """Single-utterance decoding via whisper.decode (batch_size must be 1)."""
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+        
+        # Lazily build and cache a WhisperFrontend the first time none is given.
+        if frontend is None and not hasattr(self, "frontend"):
+            frontend_class = tables.frontend_classes.get("WhisperFrontend")
+            frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
+            self.frontend = frontend
+        else:
+            frontend = frontend if frontend is not None else self.frontend
+        
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in,
+                                                            fs=frontend.fs if hasattr(frontend, "fs") else 16000,
+                                                            audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
+            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
+        
+        # Keep only the first utterance (single-utterance decoding path).
+        speech = speech.to(device=kwargs["device"])[0, :, :]
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        
+        # Build the initial prompt "<|startoftranscript|><|task|>..." from the
+        # requested task(s).
+        DecodingOptions = kwargs.get("DecodingOptions", {})
+        task = DecodingOptions.get("task", "ASR")
+        if isinstance(task, str):
+            task = [task]
+        task = "".join([f"<|{x}|>" for x in task])
+        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
+        DecodingOptions["initial_prompt"] = initial_prompt
+        
+        # "auto" means let whisper detect the language.
+        language = DecodingOptions.get("language", None)
+        language = None if language == "auto" else language
+        DecodingOptions["language"] = language
+        
+        DecodingOptions["vocab_path"] = kwargs.get("vocab_path", None)
+        
+        if "without_timestamps" not in DecodingOptions:
+            DecodingOptions["without_timestamps"] = True
+        
+        options = whisper.DecodingOptions(**DecodingOptions)
+        
+        result = whisper.decode(self.model, speech, options)
+        text = f"{result.text}"
+        results = []
+        result_i = {"key": key[0], "text": text}
+        
+        results.append(result_i)
+        
+        return results, meta_data
diff --git a/funasr/models/sense_voice/rwkv_v6.py b/funasr/models/sense_voice/rwkv_v6.py
new file mode 100644
index 0000000..6eb53fc
--- /dev/null
+++ b/funasr/models/sense_voice/rwkv_v6.py
@@ -0,0 +1,425 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import os, math, gc, importlib
+import torch
+# torch._C._jit_set_profiling_executor(True)
+# torch._C._jit_set_profiling_mode(True)
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+
+def __nop(ob):
+	"""Identity function: no-op stand-in for the TorchScript decorators."""
+	return ob
+
+
+# Default to plain nn.Module / undecorated methods; switch to TorchScript
+# modules and script_method only when RWKV_JIT_ON=1 is set in the environment.
+MyModule = nn.Module
+MyFunction = __nop
+if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
+	MyModule = torch.jit.ScriptModule
+	MyFunction = torch.jit.script_method
+
+########################################################################################################
+# CUDA Kernel
+########################################################################################################
+
+# Lazily-compiled CUDA extension handle, populated by load_rwkv_kernel().
+wkv6_cuda = None
+
+def load_rwkv_kernel(HEAD_SIZE: int=64, RWKV_CTXLEN: int=512,):
+	"""JIT-compile the WKV v6 CUDA kernel once and cache it in ``wkv6_cuda``.
+
+	Args:
+		HEAD_SIZE: per-head channel count, baked into the kernel as _N_.
+		RWKV_CTXLEN: maximum context length, baked into the kernel as _T_.
+
+	NOTE(review): the kernel is compiled only on the first call; later calls
+	with different HEAD_SIZE/RWKV_CTXLEN silently reuse the first build.
+	"""
+	from torch.utils.cpp_extension import load
+	global wkv6_cuda
+	
+	
+	if wkv6_cuda is not None:
+		return
+	
+	# Resolve the .cpp/.cu sources relative to this file's directory.
+	absolute_file_path = os.path.abspath(__file__)
+	cur_dir = os.path.dirname(absolute_file_path)
+	wkv6_cuda = load(name="wkv6", sources=[f"{cur_dir}/cuda/wkv6_op.cpp", f"{cur_dir}/cuda/wkv6_cuda.cu"],
+	                 verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3",
+	                                                  "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}",
+	                                                  f"-D_T_={RWKV_CTXLEN}"])
+
+# dtype = torch.float
+# Dtype used for all kernel input/output buffers (kernel expects bf16).
+dtype = torch.bfloat16
+class WKV_6(torch.autograd.Function):
+	"""Autograd wrapper around the WKV v6 CUDA kernel.
+
+	forward/backward both run under no_grad and call the compiled
+	``wkv6_cuda`` extension directly; gradients are produced by the
+	hand-written CUDA backward, not by autograd tracing.
+	"""
+
+	@staticmethod
+	def forward(ctx, B, T, C, H, r, k, v, w, u):
+		with torch.no_grad():
+			# assert r.dtype == torch.bfloat16
+			# assert k.dtype == torch.bfloat16
+			# assert v.dtype == torch.bfloat16
+			# assert w.dtype == torch.bfloat16
+			# assert u.dtype == torch.bfloat16
+			# assert HEAD_SIZE == C // H
+			ctx.B = B
+			ctx.T = T
+			ctx.C = C
+			ctx.H = H
+			assert r.is_contiguous()
+			assert k.is_contiguous()
+			assert v.is_contiguous()
+			assert w.is_contiguous()
+			assert u.is_contiguous()
+			# Decay is passed to the kernel as -exp(w) (always negative).
+			ew = (-torch.exp(w.float())).contiguous()
+			ctx.save_for_backward(r, k, v, ew, u)
+			y = torch.empty((B, T, C), device=r.device, dtype=dtype,
+			                memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
+			wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y)
+			return y
+	
+	@staticmethod
+	def backward(ctx, gy):
+		with torch.no_grad():
+			# assert gy.dtype == torch.bfloat16
+			B = ctx.B
+			T = ctx.T
+			C = ctx.C
+			H = ctx.H
+			assert gy.is_contiguous()
+			r, k, v, ew, u = ctx.saved_tensors
+			# Pre-allocate gradient buffers; the kernel fills them in place.
+			gr = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
+			                 memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
+			gk = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
+			                 memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
+			gv = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
+			                 memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
+			gw = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
+			                 memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
+			gu = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=dtype,
+			                 memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
+			wkv6_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu)
+			# u's grad is summed over the batch and reshaped per head.
+			gu = torch.sum(gu, 0).view(H, C // H)
+			# First four Nones: B, T, C, H are plain ints with no gradient.
+			return (None, None, None, None, gr, gk, gv, gw, gu)
+
+
+def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
+	"""Convenience wrapper invoking the WKV_6 autograd Function."""
+	return WKV_6.apply(B, T, C, H, r, k, v, w, u)
+
+
+class RWKV_Tmix_x060(MyModule):
+	"""RWKV v6 time-mixing (attention-like) block.
+
+	Projects token-shift mixed inputs to r/k/v/g, computes per-channel
+	data-dependent decay, and runs the fused WKV v6 CUDA kernel.
+	"""
+
+	def __init__(self, args, layer_id):
+		super().__init__()
+		self.args = args
+		
+		# Ensure the CUDA kernel is compiled for this head size / ctx length.
+		load_rwkv_kernel(args.head_size_a, args.ctx_len)
+		
+		self.layer_id = layer_id
+		
+		self.head_size = args.head_size_a
+		self.n_head = args.dim_att // self.head_size
+		assert args.dim_att % self.n_head == 0
+		
+		with torch.no_grad():
+			# Layer-depth dependent init ratios, as in upstream RWKV-LM.
+			ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
+			ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
+			ddd = torch.ones(1, 1, args.n_embd)
+			for i in range(args.n_embd):
+				ddd[0, 0, i] = i / args.n_embd
+			
+			# fancy time_mix
+			self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+			self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+			self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+			self.time_maa_v = nn.Parameter(1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1))
+			self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+			self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+			
+			D_MIX_LORA = 32  # generate TIME_MIX for w,k,v,r,g
+			self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
+			self.time_maa_w2 = nn.Parameter(torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01))
+			
+			# fancy time_decay
+			decay_speed = torch.ones(args.dim_att)
+			for n in range(args.dim_att):
+				decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+			self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
+			
+			D_DECAY_LORA = 64
+			self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
+			self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01))
+			
+			# Per-channel "bonus" u for the current token, zigzag-perturbed.
+			tmp = torch.zeros(args.dim_att)
+			for n in range(args.dim_att):
+				zigzag = ((n + 1) % 3 - 1) * 0.1
+				tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
+			
+			self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
+		
+		# ZeroPad2d((0,0,1,-1)) shifts the sequence one step back in time.
+		self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+		self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
+		self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
+		
+		self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
+		self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
+		self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
+		self.ln_x = nn.GroupNorm(self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor ** 2))
+	
+	@MyFunction
+	def jit_func(self, x):
+		# Compute r/k/v/g projections and the data-dependent decay w.
+		B, T, C = x.size()
+		
+		# Token-shift delta: previous timestep minus current.
+		xx = self.time_shift(x) - x
+		
+		# Low-rank (LoRA-style) generation of the five per-channel mix vectors.
+		xxx = x + xx * self.time_maa_x
+		xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
+		xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
+		mw, mk, mv, mr, mg = xxx.unbind(dim=0)
+		
+		xw = x + xx * (self.time_maa_w + mw)
+		xk = x + xx * (self.time_maa_k + mk)
+		xv = x + xx * (self.time_maa_v + mv)
+		xr = x + xx * (self.time_maa_r + mr)
+		xg = x + xx * (self.time_maa_g + mg)
+		
+		r = self.receptance(xr)
+		k = self.key(xk)
+		v = self.value(xv)
+		g = F.silu(self.gate(xg))
+		
+		# Data-dependent decay: static time_decay plus low-rank correction.
+		ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
+		w = self.time_decay + ww
+		
+		return r, k, v, g, w
+	
+	@MyFunction
+	def jit_func_2(self, x, g):
+		# Group-norm over heads, then gated output projection.
+		B, T, C = x.size()
+		x = x.view(B * T, C)
+		
+		x = self.ln_x(x).view(B, T, C)
+		x = self.output(x * g)
+		return x
+	
+	def forward(self, x):
+		B, T, C = x.size()
+		H = self.n_head
+		
+		r, k, v, g, w = self.jit_func(x)
+		# Fused WKV v6 kernel; u (time_faaaa) is the current-token bonus.
+		x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
+		
+		return self.jit_func_2(x, g)
+
+class RWKV_CMix_x060(MyModule):
+	"""RWKV v6 channel-mixing (FFN-like) block with token shift and a
+	squared-ReLU key activation gated by a sigmoid receptance."""
+
+	def __init__(self, args, layer_id):
+		super().__init__()
+		self.args = args
+		self.layer_id = layer_id
+		# ZeroPad2d((0,0,1,-1)) shifts the sequence one step back in time.
+		self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+		
+		with torch.no_grad():  # fancy init of time_mix
+			ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
+			ddd = torch.ones(1, 1, args.n_embd)
+			for i in range(args.n_embd):
+				ddd[0, 0, i] = i / args.n_embd
+			self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+			self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+		
+		self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+		self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
+		self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
+	
+	@MyFunction
+	def forward(self, x):
+		# Token-shift mixing of current and previous timestep.
+		xx = self.time_shift(x) - x
+		xk = x + xx * self.time_maa_k
+		xr = x + xx * self.time_maa_r
+		
+		k = self.key(xk)
+		# Squared ReLU activation (RWKV convention).
+		k = torch.relu(k) ** 2
+		kv = self.value(k)
+		return torch.sigmoid(self.receptance(xr)) * kv
+
+
+class Block(nn.Module):
+	"""Standard RWKV residual block: time-mix (att) + channel-mix (ffn),
+	each behind a LayerNorm, with optional dropout on the residuals.
+	Layer 0 additionally applies an input LayerNorm (ln0)."""
+
+	def __init__(self, args, layer_id):
+		super().__init__()
+		self.args = args
+		self.layer_id = layer_id
+		
+		self.ln1 = nn.LayerNorm(args.n_embd)
+		self.ln2 = nn.LayerNorm(args.n_embd)
+		
+		if self.layer_id == 0:
+			self.ln0 = nn.LayerNorm(args.n_embd)
+
+
+		self.att = RWKV_Tmix_x060(args, layer_id)
+		
+		self.ffn = RWKV_CMix_x060(args, layer_id)
+
+
+		if args.dropout > 0:
+			self.drop0 = nn.Dropout(p=args.dropout)
+			self.drop1 = nn.Dropout(p=args.dropout)
+	
+	def forward(self, x, x_emb=None):
+		args = self.args
+		B, T, C = x.size()
+		if self.layer_id == 0:
+			x = self.ln0(x)
+
+		
+		# NOTE(review): self.ffnPre is referenced below but never defined in
+		# __init__ — the pre_ffn > 0 path raises AttributeError if taken.
+		if self.args.dropout == 0:
+			if self.layer_id == 0 and args.pre_ffn > 0:
+				x = x + self.ffnPre(self.ln1(x))
+			else:
+				x = x + self.att(self.ln1(x))
+			x = x + self.ffn(self.ln2(x))
+		else:
+			if self.layer_id == 0 and args.pre_ffn > 0:
+				x = self.drop0(x + self.ffnPre(self.ln1(x)))
+			else:
+				x = self.drop0(x + self.att(self.ln1(x)))
+			x = self.drop1(x + self.ffn(self.ln2(x)))
+		
+		return x
+
+
+class RWKVLayer(nn.Module):
+	"""Single RWKV block with configurable norms, intended for embedding
+	into another encoder/decoder stack.
+
+	NOTE(review): this class calls ``args.get(...)`` while RWKV_Tmix_x060
+	uses attribute access on the same ``args`` — so ``args`` must be a
+	dict-like config object that supports both (confirm against callers).
+	"""
+
+	def __init__(self, args, layer_id):
+		super().__init__()
+		self.args = args
+		self.layer_id = layer_id
+		if args.dim_ffn is None:
+			# Default FFN width: 3.5x embedding, rounded down to a multiple of 32.
+			args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
+		self.ln0 = None
+		if self.layer_id == 0 and args.get("ln0", True):
+			self.ln0 = nn.LayerNorm(args.n_embd)
+		
+		self.ln1 = None
+		if args.get("ln1", True):
+			self.ln1 = nn.LayerNorm(args.n_embd)
+		self.ln2 = nn.LayerNorm(args.n_embd)
+		
+
+		self.att = RWKV_Tmix_x060(args, layer_id)
+		
+		self.ffn = RWKV_CMix_x060(args, layer_id)
+		
+		if args.dropout > 0:
+			self.drop0 = nn.Dropout(p=args.dropout)
+			self.drop1 = nn.Dropout(p=args.dropout)
+		
+		# init: upstream RWKV initialization — orthogonal projections, zeroed
+		# outputs, and depth-scaled LayerNorm weights.
+		if args.get("init_rwkv", True):
+			print("init_rwkv")
+			nn.init.orthogonal_(self.att.receptance.weight, gain=1)
+			nn.init.orthogonal_(self.att.key.weight, gain=0.1)
+			nn.init.orthogonal_(self.att.value.weight, gain=1)
+			nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
+			nn.init.zeros_(self.att.output.weight)
+			
+			nn.init.orthogonal_(self.ffn.key.weight, gain=1)
+			nn.init.zeros_(self.ffn.value.weight)
+			nn.init.zeros_(self.ffn.receptance.weight)
+			scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+			nn.init.constant_(self.ln2.weight, scale)
+			if self.ln0 is not None:
+				nn.init.constant_(self.ln0.weight, scale)
+			if self.ln1 is not None:
+				nn.init.constant_(self.ln1.weight, scale)
+		
+	def forward(self, x, x_emb=None, mask=None, **kwargs):
+		
+		args = self.args
+		# Run the block in bf16 and cast back to fp32 on exit (default).
+		if args.get("datatype", "bf16") == "bf16":
+			x = x.bfloat16()
+		B, T, C = x.size()
+		if self.layer_id == 0 and self.ln0 is not None:
+			x = self.ln0(x)
+		
+		if self.args.dropout == 0:
+			if self.ln1 is None:
+				x = x + self.att(x)
+			else:
+				x = x + self.att(self.ln1(x))
+			x = x + self.ffn(self.ln2(x))
+		else:
+			if self.ln1 is None:
+				x = self.drop0(x + self.att(x))
+			else:
+				x = self.drop0(x + self.att(self.ln1(x)))
+			x = self.drop1(x + self.ffn(self.ln2(x)))
+		
+		if args.get("datatype", "bf16") == "bf16":
+			x = x.to(torch.float32)
+		return x
+
+
+class RWKV(nn.Module):
+	"""Full RWKV language model: embedding, a stack of Blocks, LayerNorm,
+	and a vocabulary head with an optional head-QK copy mechanism.
+
+	NOTE(review): ``forward`` references ``deepspeed`` (grad_cp == 1 path)
+	but this module never imports it — that path raises NameError as-is.
+	"""
+
+	def __init__(self, args):
+		super().__init__()
+		self.args = args
+		if not hasattr(args, 'dim_att'):
+			args.dim_att = args.n_embd
+		if not hasattr(args, 'dim_ffn'):
+			# NOTE(review): os.environ["RWKV_MY_TESTING"] raises KeyError when
+			# the variable is unset; consider os.environ.get(...).
+			if '-f4' in os.environ["RWKV_MY_TESTING"]:
+				args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
+			else:
+				args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size
+		if not hasattr(args, 'tiny_att_layer'):
+			args.tiny_att_layer = -1
+		if not hasattr(args, 'tiny_att_dim'):
+			args.tiny_att_dim = -1
+		assert args.n_embd % 32 == 0
+		assert args.dim_att % 32 == 0
+		assert args.dim_ffn % 32 == 0
+		
+		self.emb = nn.Embedding(args.vocab_size, args.n_embd)
+		
+		self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
+		
+		self.ln_out = nn.LayerNorm(args.n_embd)
+		self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
+
+
+		if args.dropout > 0:
+			self.drop0 = nn.Dropout(p=args.dropout)
+
+
+	def forward(self, idx):
+		args = self.args
+		B, T = idx.size()
+		assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
+		
+		x = self.emb(idx)
+		x_emb = x
+		
+		if args.dropout > 0:
+			x = self.drop0(x)
+		# With tiny attention enabled, blocks also receive the raw embedding.
+		if args.tiny_att_dim > 0:
+			for block in self.blocks:
+				if args.grad_cp == 1:
+					x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
+				else:
+					x = block(x, x_emb)
+		else:
+			for block in self.blocks:
+				if args.grad_cp == 1:
+					x = deepspeed.checkpointing.checkpoint(block, x)
+				else:
+					x = block(x)
+		
+		x = self.ln_out(x)
+		
+		# NOTE(review): head_q/head_k/copy_mask are referenced below but never
+		# defined in __init__ — the head_qk > 0 path raises AttributeError.
+		if args.head_qk > 0:
+			q = self.head_q(x)[:, :T, :]
+			k = self.head_k(x)[:, :T, :]
+			c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
+			c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
+			
+			# NOTE(review): os.environ["RWKV_FLOAT_MODE"] raises KeyError if unset.
+			if "32" in os.environ["RWKV_FLOAT_MODE"]:
+				c = c @ F.one_hot(idx, num_classes=args.vocab_size)
+			elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+				c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
+			elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+				c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
+			
+			x = self.head(x) + c
+		else:
+			x = self.head(x)
+		
+		return x
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index ca960f1..5f7caeb 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -261,7 +261,9 @@
             self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
         )
         all_heads[self.dims.n_text_layer // 2 :] = True
-        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+        # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+        # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
+        # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
 
     def set_alignment_heads(self, dump: bytes):
         array = np.frombuffer(

--
Gitblit v1.9.1