zhifu gao
2024-04-23 0a4a1d5257dace9561d95b38a9386539908dcd5e
Dev gzf exp (#1645)

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* sensevoice finetune

* bugfix

* update with main (#1631)

* update seaco finetune

* v1.0.24

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* sensevoice

* sensevoice

* update with main (#1638)

* update seaco finetune

* v1.0.24

* update rwkv template

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sensevoice

* sense voice

* sense voice

* sense voice

* sense voice

---------

Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
6个文件已修改
10个文件已添加
2394 ■■■■■ 已修改文件
examples/aishell/conformer/conf/conformer_rwkv.yaml 124 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/espnet_samplers.py 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/audio_datasets/index_ds.py 202 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/conformer_rwkv/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/conformer_rwkv/decoder.py 507 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/conformer_rwkv/model.py 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/conformer_rwkv/template.yaml 123 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/cuda/wkv5_cuda.cu 202 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/cuda/wkv5_op.cpp 22 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/cuda/wkv6_cuda.cu 243 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/cuda/wkv6_op.cpp 23 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/decoder.py 273 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/model.py 211 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/rwkv_v6.py 425 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/whisper_lib/model.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/conformer/conf/conformer_rwkv.yaml
New file
@@ -0,0 +1,124 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: Conformer
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
# encoder
encoder: ConformerEncoder
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder architecture type
    normalize_before: true
    pos_enc_layer_type: rel_pos
    selfattention_layer_type: rel_selfattn
    activation_type: swish
    macaron_style: true
    use_cnn_module: true
    cnn_module_kernel: 15
# decoder
decoder: TransformerRWKVDecoder
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
    input_layer: embed
    rwkv_cfg:
      n_embd: 256
      dropout: 0
      head_size_a: 64
      ctx_len: 512
      dim_att: 256 # intended to equal n_embd; as an interpolation: ${decoder_conf.rwkv_cfg.n_embd}
      dim_ffn: null
      head_size_divisor: 4
      n_layer: 6
      pre_ffn: 0
      ln0: false
      ln1: false
      init_rwkv: true
# frontend related
frontend: WavFrontend
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 1
    lfr_n: 1
specaug: SpecAug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2
train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  keep_nbest_models: 10
  log_interval: 50
optim: adam
optim_conf:
   lr: 0.0005
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: EspnetStyleBatchSampler
    batch_type: length # example or length
    batch_size: 25000 # if batch_type is example, batch_size is the number of samples per batch; if length, batch_size is measured in source_token_len+target_token_len
    max_token_length: 2048 # filter out samples whose source_token_len+target_token_len exceeds max_token_length
    buffer_size: 1024
    shuffle: True
    num_workers: 4
    preprocessor_speech: SpeechPreprocessSpeedPerturb
    preprocessor_speech_conf:
      speed_perturb: [0.9, 1.0, 1.1]
tokenizer: CharTokenizer
tokenizer_conf:
  unk_symbol: <unk>
ctc_conf:
    dropout_rate: 0.0
    ctc_type: builtin
    reduce: true
    ignore_nan_grad: true
normalize: null
funasr/bin/train.py
@@ -90,7 +90,8 @@
    # freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        freeze_param = eval(freeze_param)
        if "," in freeze_param:
            freeze_param = eval(freeze_param)
        if isinstance(freeze_param, Sequence):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
@@ -104,7 +105,7 @@
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(model, device_ids=[local_rank],
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True))
                    find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
    elif use_fsdp:
        # model = FSDP(model).cuda(local_rank)
funasr/datasets/audio_datasets/espnet_samplers.py
@@ -32,8 +32,9 @@
    def __init__(self, dataset,
                 batch_size,
                 batch_type="token",
                 num_replicas=None,
                 rank=None,
                 num_replicas=None,
                 rank_split=False,
                 shuffle=True,
                 drop_last=False,
                 is_training: bool = True,
@@ -45,6 +46,10 @@
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
        except:
            rank = 0
            num_replicas = 1
        if rank_split:
            logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank")
            rank = 0
            num_replicas = 1
        self.rank = rank
@@ -65,8 +70,8 @@
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        super().__init__(dataset, num_replicas=num_replicas, rank=rank,
                         shuffle=shuffle, drop_last=drop_last)
        # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
        #                  shuffle=shuffle, drop_last=drop_last)
    def __iter__(self):
        if self.shuffle:
            g = torch.Generator()
funasr/datasets/audio_datasets/index_ds.py
@@ -9,66 +9,66 @@
from funasr.register import tables
@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
    """Index dataset over a single jsonl file, sharded across ranks.

    Each non-empty line of ``path`` is parsed as JSON:

    * a ``"text"`` key contributes the raw ``data["text"]`` value (sft);
    * a ``"source"`` key contributes a dict with ``source``, ``prompt``,
      ``target``, ``source_len`` and ``target_len`` (speech pretraining).

    The collected entries are cut into ``world_size`` contiguous shards of
    ``len(entries) // world_size`` items and this rank keeps its own shard.
    The last ``len(entries) % world_size`` entries are not assigned to any
    rank, so every rank holds the same number of samples. Without an
    initialized process group the whole file is kept (rank 0 of 1).

    NOTE(review): ``get_source_len`` indexes dict entries, so entries coming
    from "text"-only lines (raw values) cannot be length-bucketed — confirm
    this path is still used.
    NOTE(review): this module defines ``IndexDSJsonlRankSplit`` again further
    down with the same registry key; the later class statement rebinds the
    name (and presumably replaces the registry entry).
    """

    def __init__(self, path):
        super().__init__()
        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    # Tolerate blank lines instead of failing inside json.loads.
                    continue
                data = json.loads(line)
                if "text" in data:  # for sft
                    # Collect into the local list: self.contents is only
                    # assigned after sharding below.
                    contents.append(data['text'])
                if "source" in data:  # for speech lab pretrain
                    contents.append({"source": data["source"],
                                     "prompt": data["prompt"],
                                     "target": data["target"],
                                     "source_len": data["source_len"],
                                     "target_len": data["target_len"],
                                     })
        total_num = len(contents)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            # No (initialized) process group: behave as a single shard.
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        num_per_rank = total_num // world_size
        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
        logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        # An out-of-range index raises IndexError (the Dataset protocol); the
        # previous bare except printed the index and then failed with
        # UnboundLocalError on the never-assigned ``data``.
        return self.contents[index]

    def get_source_len(self, data_dict):
        """Return the ``source_len`` of a sample dict (KeyError if absent)."""
        return data_dict["source_len"]

    def get_target_len(self, data_dict):
        """Return the ``target_len`` of a sample dict, or 0 if absent."""
        return data_dict["target_len"] if "target_len" in data_dict else 0
# @tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
# class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
#
#     def __init__(self, path):
#         super().__init__()
#
#         contents = []
#         with open(path, encoding='utf-8') as fin:
#             for line in fin:
#                 data = json.loads(line.strip())
#                 if "text" in data:  # for sft
#                     self.contents.append(data['text'])
#                 if "source" in data:  # for speech lab pretrain
#                     prompt = data["prompt"]
#                     source = data["source"]
#                     target = data["target"]
#                     source_len = data["source_len"]
#                     target_len = data["target_len"]
#
#                     contents.append({"source": source,
#                                      "prompt": prompt,
#                                      "target": target,
#                                      "source_len": source_len,
#                                      "target_len": target_len,
#                                      }
#                                     )
#
#         self.contents = []
#         total_num = len(contents)
#         try:
#             rank = dist.get_rank()
#             world_size = dist.get_world_size()
#         except:
#             rank = 0
#             world_size = 1
#             logging.warning("distributed is not initialized, only single shard")
#         num_per_rank = total_num // world_size
#
#         # rank = 0
#         # import ipdb; ipdb.set_trace()
#         self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
#
#         logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
#
#     def __len__(self):
#         return len(self.contents)
#
#     def __getitem__(self, index):
#         try:
#             data = self.contents[index]
#         except:
#             print(index)
#         return data
#
#     def get_source_len(self, data_dict):
#         return data_dict["source_len"]
#
#     def get_target_len(self, data_dict):
#
#         return data_dict["target_len"] if "target_len" in data_dict else 0
@tables.register("index_ds_classes", "IndexDSJsonl")
@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
@@ -143,3 +143,85 @@
    def get_target_len(self, data_dict):
        
        return data_dict.get("target_len", 0)
@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
    """Index dataset built from a list of jsonl files, sharded by file across ranks.

    ``path`` is a text file with one jsonl path per line. Blank lines are
    ignored. The file list is cut into ``world_size`` contiguous shards of
    ``len(files) // world_size`` files. This rank only opens and parses its
    own shard. The trailing ``len(files) % world_size`` files are not read by
    any rank, and a warning is logged.

    In each jsonl file, every non-empty line is parsed as JSON:

    * a ``"text"`` key contributes the raw ``data["text"]`` value (sft);
    * a ``"source"`` key contributes a dict with:
      - ``source``;
      - ``prompt`` (default ``"<ASR>"``);
      - ``target``;
      - ``source_len`` (default 1);
      - ``target_len`` (default 0);
      - ``text_language`` / ``audio_language`` when present and not None.

      Records whose lengths fall outside ``[min_*_length, max_*_length]``
      are skipped.

    Keyword Args:
        max_source_length (int): default 2048.
        min_source_length (int): default 0.
        max_target_length (int): default 2048.
        min_target_length (int): default 0.
        source_path_map (Mapping[str, str] | None): substring replacements
            applied with ``str.replace`` (in mapping order) to every
            ``source`` path. Defaults to the previously hard-coded
            ``{"/cpfs01": "/cpfs_speech/data"}``; pass ``{}`` or ``None`` to
            keep paths unchanged.

    NOTE(review): ``get_source_len`` / ``get_target_len`` call ``.get`` on the
    entry, which fails for raw "text" entries — confirm that path is used.
    """

    def __init__(self, path: str, **kwargs):
        super().__init__()
        self.max_source_length = kwargs.get("max_source_length", 2048)
        self.min_source_length = kwargs.get("min_source_length", 0)
        self.max_target_length = kwargs.get("max_target_length", 2048)
        self.min_target_length = kwargs.get("min_target_length", 0)
        self.source_path_map = dict(
            kwargs.get("source_path_map", {"/cpfs01": "/cpfs_speech/data"}) or {}
        )

        with open(path, encoding='utf-8') as fin:
            # Drop blank lines so they are never passed to open() below.
            file_list = [line.strip() for line in fin if line.strip()]
        total_num = len(file_list)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            # No (initialized) process group: behave as a single shard.
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        num_per_rank = total_num // world_size
        if num_per_rank * world_size < total_num:
            logging.warning(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
        if num_per_rank == 0:
            logging.warning(f"Warning, rank {rank} got no jsonl file: {total_num} files for world_size {world_size}, {path}")
        file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]

        contents = []
        for file_json in file_list_rank:
            with open(file_json, encoding='utf-8') as fin:
                for line in fin:
                    line = line.strip()
                    if not line:
                        # Tolerate blank lines instead of failing in json.loads.
                        continue
                    contents.extend(self._parse_record(json.loads(line)))
        self.contents = contents
        logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}")

    def _parse_record(self, data):
        """Return the dataset entries (zero, one or two) produced by one json record."""
        entries = []
        if "text" in data:  # for sft
            entries.append(data['text'])
        if "source" in data:  # for speech lab pretrain
            prompt = data.get("prompt", "<ASR>")
            source = data["source"]
            for old, new in self.source_path_map.items():
                source = source.replace(old, new)
            target = data["target"]
            source_len = data.get("source_len", 1)
            target_len = data.get("target_len", 0)
            if source_len < self.min_source_length or source_len > self.max_source_length:
                return entries
            if target_len < self.min_target_length or target_len > self.max_target_length:
                return entries
            entry = {"source": source,
                     "prompt": prompt,
                     "target": target,
                     "source_len": source_len,
                     "target_len": target_len,
                     }
            # Optional language tags are only copied when present and not None.
            for key in ("text_language", "audio_language"):
                value = data.get(key, None)
                if value is not None:
                    entry[key] = value
            entries.append(entry)
        return entries

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        # An out-of-range index raises IndexError (the Dataset protocol); the
        # previous bare except printed the index and then failed with
        # UnboundLocalError on the never-assigned ``data``.
        return self.contents[index]

    def get_source_len(self, data_dict):
        """Return the sample's ``source_len``, defaulting to 1."""
        return data_dict.get("source_len", 1)

    def get_target_len(self, data_dict):
        """Return the sample's ``target_len``, defaulting to 0."""
        return data_dict.get("target_len", 0)
funasr/models/conformer_rwkv/__init__.py
funasr/models/conformer_rwkv/decoder.py
New file
@@ -0,0 +1,507 @@
# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""Decoder definition."""
from typing import Any
from typing import List
from typing import Sequence
from typing import Tuple
import torch
from torch import nn
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.lightconv import LightweightConvolution
from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
from funasr.models.transformer.utils.mask import subsequent_mask
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
from omegaconf import OmegaConf
from funasr.register import tables
class DecoderLayer(nn.Module):
    """Single decoder layer: token mixing, source attention, feed-forward.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Token-mixing module (an ``RWKVLayer`` in
            this file). It is cast to ``torch.bfloat16``. It must expose an
            ``ln2`` LayerNorm when ``init_rwkv`` is enabled. It is called as
            ``self_attn(x, mask=...)``.
        src_attn (torch.nn.Module): Source (cross) attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Stored only; ``forward`` always applies the
            layer norms before each sub-block.
        concat_after (bool): Stored only. When True, ``concat_linear1/2`` are
            created but ``forward`` never uses them, so they are unused
            parameters (relevant e.g. for DDP ``find_unused_parameters``).
        layer_id (int): Index of this layer; drives the RWKV init scale.
        args (Mapping, optional): RWKV config. Keys read here:
            - ``ln0``;
            - ``n_embd`` (defaults to ``size``);
            - ``init_rwkv`` (default True);
            - ``n_layer`` (required when ``init_rwkv`` is enabled).

    Raises:
        ValueError: if ``init_rwkv`` is enabled but ``n_layer`` or
            ``layer_id`` is missing.
    """
    def __init__(
            self,
            size,
            self_attn,
            src_attn,
            feed_forward,
            dropout_rate,
            normalize_before=True,
            concat_after=False,
            layer_id=None,
            args=None,
    ):
        """Construct an DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        # Avoid a shared mutable default; an empty mapping means "all defaults".
        if args is None:
            args = {}
        self.size = size
        # NOTE(review): the token mixer is kept in bfloat16, presumably for the
        # RWKV CUDA kernel — confirm.
        self.self_attn = self_attn.to(torch.bfloat16)
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)
        self.layer_id = layer_id

        init_rwkv = args.get("init_rwkv", True)
        n_layer = args.get("n_layer")
        if init_rwkv and (n_layer is None or layer_id is None):
            raise ValueError(
                "DecoderLayer: 'n_layer' in args and 'layer_id' are required when init_rwkv is enabled"
            )

        self.ln0 = None
        # NOTE(review): the extra input norm is created when ``ln0`` is False,
        # which reads inverted; kept unchanged — confirm intent.
        if self.layer_id == 0 and not args.get("ln0", True):
            self.ln0 = LayerNorm(args.get("n_embd", size))
            if init_rwkv:
                print("init_rwkv")
                # layer_id is 0 here, so the scale is (1 / n_layer) ** 0.7.
                scale = ((1 + layer_id) / n_layer) ** 0.7
                nn.init.constant_(self.ln0.weight, scale)
        # RWKV-style init: pre-norm weights grow with depth as
        # ((1 + layer_id) / n_layer) ** 0.7.
        if init_rwkv:
            print("init_rwkv")
            scale = ((1 + layer_id) / n_layer) ** 0.7
            nn.init.constant_(self.norm1.weight, scale)
            nn.init.constant_(self.self_attn.ln2.weight, scale)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor; passed to
                ``self_attn`` only when ``cache`` is None.
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (torch.Tensor, optional): This layer's output from the
                previous step, shape (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size). With a
                cache, the first maxlen_out - 1 positions are ``cache`` itself.
            torch.Tensor: ``tgt_mask`` unchanged.
            torch.Tensor: ``memory`` unchanged.
            torch.Tensor: ``memory_mask`` unchanged.
        """
        if self.layer_id == 0 and self.ln0 is not None:
            tgt = self.ln0(tgt)
        residual = tgt
        tgt = self.norm1(tgt)
        if cache is None:
            x = residual + self.dropout(self.self_attn(tgt, mask=tgt_mask))
        else:
            # Incremental decoding: mix the whole prefix without a mask, then
            # keep only the newest position. Slicing with ``-1:`` keeps the
            # time axis, so src_attn sees (#batch, 1, size) and the concat with
            # ``cache`` below lines up on dim 1 (``-1`` dropped the axis and
            # made torch.cat fail).
            x = residual + self.dropout(self.self_attn(tgt, mask=None))
            x = x[:, -1:, :]
        residual = x
        x = self.norm2(x)
        x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        residual = x
        x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, memory, memory_mask
class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
    """Base class of Transformer decoder modules.

    Builds the input layer (``self.embed``), the optional final layer norm
    (``self.after_norm``) and the optional output projection
    (``self.output_layer``). Subclasses must assign ``self.decoders``. It must
    be a stack of decoder layers that is callable as a whole (used by
    ``forward``) and also iterable with ``len()`` (used by
    ``forward_one_step`` and ``batch_score``).

    Args:
        vocab_size: output dim (number of tokens)
        encoder_output_size: dimension of attention (model width)
        dropout_rate: dropout rate inside the "linear" input layer
        positional_dropout_rate: dropout rate of the positional encoding
        input_layer: input layer type, "embed" or "linear"
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to apply ``after_norm`` to the decoder output

    Raises:
        ValueError: if ``input_layer`` is neither "embed" nor "linear".
    """
    def __init__(
            self,
            vocab_size: int,
            encoder_output_size: int,
            dropout_rate: float = 0.1,
            positional_dropout_rate: float = 0.1,
            input_layer: str = "embed",
            use_output_layer: bool = True,
            pos_enc_class=PositionalEncoding,
            normalize_before: bool = True,
    ):
        super().__init__()
        attention_dim = encoder_output_size
        if input_layer == "embed":
            # Token ids -> embeddings -> positional encoding.
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            # Dense inputs whose last dim is vocab_size -> attention_dim features.
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None
        # Must be set by subclasses (see forward / forward_one_step).
        self.decoders = None
    def forward(
            self,
            hs_pad: torch.Tensor,
            hlens: torch.Tensor,
            ys_in_pad: torch.Tensor,
            ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder (teacher forcing over the whole target sequence).

        Args:
            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                otherwise a dense tensor whose last dim is vocab_size
                (it is fed to ``Linear(vocab_size, ...)``)
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:
            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: ``tgt_mask.sum(1)``. NOTE(review): tgt_mask is
                (batch, maxlen_out, maxlen_out) at that point, so this is
                (batch, maxlen_out) rather than (batch,) — do not rely on it
                as per-utterance lengths.
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L) — padding mask combined with the causal mask.
        tgt_mask = tgt_mask & m
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        return x, olens
    def forward_one_step(
            self,
            tgt: torch.Tensor,
            tgt_mask: torch.Tensor,
            memory: torch.Tensor,
            cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step of incremental decoding.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask,  (batch, maxlen_out)
                      dtype=torch.uint8 in PyTorch 1.2-
                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32  (batch, maxlen_in, feat)
            cache: one entry per decoder layer holding that layer's output from
                the previous step (batch, max_time_out-1, size); None on the
                first step.
        Returns:
            y, cache: scores for the last position and the new per-layer
            cache (each layer's full output). ``y`` has shape
            (batch, token) and holds log-probabilities when an output layer
            is configured.
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        # The memory mask is not used during incremental decoding (None).
        for c, decoder in zip(cache, self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache
    def score(self, ys, state, x):
        """Score one hypothesis: ``ys`` (ylen,) against encoder output ``x`` (xlen, feat).

        Adds a batch dimension of 1, runs ``forward_one_step`` and removes it
        again. Returns (log-probabilities for the next token, new state).
        """
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state
    def batch_score(
            self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.
        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens, one list of
                per-layer tensors per hypothesis (or None entries on the first
                step).
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]
        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
@tables.register("decoder_classes", "TransformerRWKVDecoder")
class TransformerRWKVDecoder(BaseTransformerDecoder):
    """Transformer decoder that uses an RWKV layer in place of self-attention.

    Every block is a ``DecoderLayer`` built from three parts:
    - an ``RWKVLayer`` (token mixing);
    - a ``MultiHeadedAttention`` over the encoder output;
    - a ``PositionwiseFeedForward``.

    ``forward``, ``forward_one_step``, ``score`` and ``batch_score`` are
    inherited from ``BaseTransformerDecoder``.

    Args:
        vocab_size: output dim
        encoder_output_size: model width; default for ``rwkv_cfg.n_embd``
        attention_heads: number of heads of the source attention
        linear_units: hidden units of the position-wise feed forward
        num_blocks: number of decoder blocks; default for ``rwkv_cfg.n_layer``
        dropout_rate: dropout rate of each block
        positional_dropout_rate: dropout rate of the positional encoding
        self_attention_dropout_rate: unused (kept for config compatibility)
        src_attention_dropout_rate: dropout rate of the source attention
        input_layer: "embed" or "linear"
        use_output_layer: whether to use the output projection
        pos_enc_class: positional encoding class
        normalize_before: forwarded to the base class and each DecoderLayer
        concat_after: forwarded to each DecoderLayer
        **kwargs: ``rwkv_cfg`` (mapping, may be None) is forwarded to
            ``RWKVLayer`` and ``DecoderLayer``; other keys are ignored.
    """
    def __init__(
            self,
            vocab_size: int,
            encoder_output_size: int,
            attention_heads: int = 4,
            linear_units: int = 2048,
            num_blocks: int = 6,
            dropout_rate: float = 0.1,
            positional_dropout_rate: float = 0.1,
            self_attention_dropout_rate: float = 0.0,
            src_attention_dropout_rate: float = 0.0,
            input_layer: str = "embed",
            use_output_layer: bool = True,
            pos_enc_class=PositionalEncoding,
            normalize_before: bool = True,
            concat_after: bool = False,
            **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
        attention_dim = encoder_output_size
        # Defaults for the dimensions the layers need; values given explicitly
        # in rwkv_cfg take precedence. A missing/None rwkv_cfg is treated as {}.
        args = OmegaConf.merge(
            {"n_embd": attention_dim, "n_layer": num_blocks},
            kwargs.get("rwkv_cfg") or {},
        )
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                RWKVLayer(args=args, layer_id=lnum),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
                lnum,
                args=args,
            ),
        )
        # RWKV-style init: a tiny uniform range for the input embedding/projection.
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.uniform_(self.embed[0].weight, a=-1e-4, b=1e-4)
funasr/models/conformer_rwkv/model.py
New file
@@ -0,0 +1,19 @@
import logging
import torch
from funasr.models.transformer.model import Transformer
from funasr.register import tables
@tables.register("model_classes", "Conformer")
class Conformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model.

    Every constructor argument is forwarded unchanged to ``Transformer``; the
    conformer/RWKV flavour comes entirely from the ``encoder``/``decoder``
    entries of the config (e.g. ConformerEncoder + TransformerRWKVDecoder in
    the conformer_rwkv template).

    NOTE(review): this class is registered under the key "Conformer". If another
    module also registers "Conformer", whichever import runs last presumably
    wins -- confirm this shadowing is intended.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
funasr/models/conformer_rwkv/template.yaml
New file
@@ -0,0 +1,123 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: Conformer
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false
# encoder
encoder: ConformerEncoder
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder architecture type
    normalize_before: true
    pos_enc_layer_type: rel_pos
    selfattention_layer_type: rel_selfattn
    activation_type: swish
    macaron_style: true
    use_cnn_module: true
    cnn_module_kernel: 15
# decoder
decoder: TransformerRWKVDecoder
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
    input_layer: embed
    rwkv_cfg:
      n_embd: 256
      dropout: 0
      head_size_a: 64
      ctx_len: 512
      dim_att: 256 # keep equal to n_embd
      dim_ffn: null
      head_size_divisor: 4
      n_layer: 6
      pre_ffn: 0
      ln0: false
      ln1: false
# frontend related
frontend: WavFrontend
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 1
    lfr_n: 1
specaug: SpecAug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2
train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 150
  keep_nbest_models: 10
  log_interval: 50
optim: adam
optim_conf:
   lr: 0.0005
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 30000
dataset: AudioDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: EspnetStyleBatchSampler
    batch_type: length # example or length
    batch_size: 25000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len;
    max_token_length: 2048 # filter out samples whose source_token_len+target_token_len > max_token_length
    buffer_size: 1024
    shuffle: True
    num_workers: 4
    preprocessor_speech: SpeechPreprocessSpeedPerturb
    preprocessor_speech_conf:
      speed_perturb: [0.9, 1.0, 1.1]
tokenizer: CharTokenizer
tokenizer_conf:
  unk_symbol: <unk>
ctc_conf:
    dropout_rate: 0.0
    ctc_type: builtin
    reduce: true
    ignore_nan_grad: true
normalize: null
funasr/models/sense_voice/cuda/wkv5_cuda.cu
New file
@@ -0,0 +1,202 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
// WKV5 forward: one thread block per (batch, head) pair, one thread per channel
// of that head. Thread i keeps a private row of the recurrent state (state[j]
// for every key channel j of the head) and, for each time step, computes
//   y[i]     = sum_j r[j] * (u[j] * k[j] * v[i] + state[j])
//   state[j] = state[j] * w[j] + k[j] * v[i]
// where w (float) and u are per-head vectors that do not change over time.
// NOTE(review): the index arithmetic assumes r/k/v/y are dense (B, T, C)
// buffers with C == H*_N_, and the float4 casts assume _N_ % 4 == 0 and
// 16-byte-aligned arrays -- confirm.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    // blockIdx.x enumerates (b, h) pairs; threadIdx.x is the channel inside head h.
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    // Point the per-head parameter vectors at head h.
    _w += h*_N_;
    _u += h*_N_;
    // Shared copies so every thread can read all _N_ channels of r/k/u/w.
    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
    float state[_N_] = {0};
    __syncthreads();
    w[i] = _w[i];
    u[i] = float(_u[i]);
    __syncthreads();
    // t indexes element (b, time, h*_N_ + i); stepping by C advances one time step.
    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        // Wait until every thread has finished reading the previous step's r/k.
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        // Make this step's r/k visible to the whole block.
        __syncthreads();
        const float v = float(_v[t]);
        float y = 0;
        // Four key channels per iteration through float4 views of the arrays.
        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            // x = k[j] * v[i]: this step's contribution to the state row.
            float4 x;
            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;
            // Output: bonus u on the current step plus the (not yet updated) past state.
            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);
            // Then decay the state by w and fold in the current step.
            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
}
// WKV5 backward: same launch geometry as kernel_forward (one block per
// (batch, head), one thread per channel of the head). Four sequential sweeps:
//   1. forward in time  -> gr (per step) and gu (summed over time)
//   2. forward in time  -> gw (summed over time, then scaled by ww)
//   3. backward in time -> gk (per step)
//   4. backward in time -> gv (per step)
// In sweeps 1-3 thread i plays the role of key channel i (its own k/r/w/u and
// all value channels v[j] from shared memory); in sweep 4 it is value channel i.
// NOTE(review): gu and gw are written once per (b, h, i) at b*C + h*_N_ + i,
// i.e. per-batch partials -- presumably reduced over B by the caller; confirm.
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
    F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    // Point the per-head parameter vectors (w, ww, u) at head h.
    _w += h*_N_;
    _u += h*_N_;
    __w += h*_N_;
    __shared__ float w_[_N_], u_[_N_];
    __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
    __syncthreads();
    w_[i] = _w[i];
    u_[i] = float(_u[i]);
    __syncthreads();
    // This thread's own decay / bonus; ww is the extra factor applied to gw below.
    // NOTE(review): ww presumably is dw/d(parameter) from the Python side -- confirm.
    const float w = w_[i];
    const float ww = __w[i];
    const float u = u_[i];
    // One running accumulator row per sweep (state: 1, saaaa/sbbbb: 2, scccc: 3, sdddd: 4).
    float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
    float gw = 0, gu = 0;
    // First element, one-past-last element, and "last minus two steps" of this thread's sequence.
    const int t000 = b*T*C + h*_N_ + i;
    const int t111 = (b+1)*T*C + h*_N_ + i;
    const int t222 = t111 - 2*C;
    // Sweep 1: replay the forward recurrence to get gr and accumulate gu.
    for (int t = t000; t < t111; t += C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();
        const float k = float(_k[t]);
        float gr = 0, gu_ = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            float x = k * v[j];
            gr += (u * x + s) * gy[j];
            gu_ += x * gy[j];
            s = s * w + x;
        }
        _gr[t] = F(gr);
        gu += float(_r[t]) * gu_;
    }
    _gu[b*C + h*_N_ + i] = F(gu);
    // Sweep 2: gradient w.r.t. the decay. k/v at step t are paired with r/gy at
    // step t+2; runs T-2 iterations (none when T < 3).
    for (int t = t000; t < t222; t += C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t + 2*C]);
        __syncthreads();
        const float k = float(_k[t]);
        float gw_ = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = saaaa[j];
            float& s2 = sbbbb[j];
            float x = k * v[j];
            float tmp = w * (x + s);
            s = tmp;
            s2 = tmp + w * s2;
            gw_ += s2 * gy[j];
        }
        gw += float(_r[t + 2*C]) * gw_;
    }
    _gw[b*C + h*_N_ + i] = F(ww * gw);
    // Sweep 3: walk time backwards, accumulating r*gy into scccc to get gk.
    for (int t = t111 - C; t >= t000; t -= C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();
        const float rr = float(_r[t]);
        float gk = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            float x = rr * gy[j];
            gk += (u * x + s) * v[j];
            s = x + s * w;
        }
        _gk[t] = F(gk);
    }
    // Sweep 4: walk time backwards as value channel i (r/k of all key channels
    // from shared memory, per-key-channel u_[j]/w_[j]) to get gv.
    for (int t = t111 - C; t >= t000; t -= C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();
        const float gyy = float(_gy[t]);
        float gv = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = sdddd[j];
            float x = gyy * r[j];
            gv += (u_[j] * x + s) * k[j];
            s = x + s * w_[j];
        }
        _gv[t] = F(gv);
    }
}
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    // Each head spans exactly _N_ channels, and the kernel reads its arrays
    // four floats at a time.
    assert(H*_N_ == C);
    assert(_N_%4 == 0);
    // One block per (batch, head) pair, _N_ threads (one per channel) per block.
    const dim3 grid(B * H);
    const dim3 threads(_N_);
    kernel_forward<<<grid, threads>>>(B, T, C, H, r, k, v, w, u, y);
}
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
{
    // Same preconditions and geometry as cuda_forward.
    assert(H*_N_ == C);
    assert(_N_%4 == 0);
    const dim3 grid(B * H);
    const dim3 threads(_N_);
    // A single kernel produces all five gradients (gr, gk, gv, gw, gu).
    kernel_backward<<<grid, threads>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
}
funasr/models/sense_voice/cuda/wkv5_op.cpp
New file
@@ -0,0 +1,22 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
}
// Python bindings for the extension module (loaded under TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv5 forward");
    m.def("backward", &backward, "wkv5 backward");
}
// Also registers the same two entry points as TorchScript-visible ops in the
// "wkv5" namespace (torch.ops.wkv5.forward / torch.ops.wkv5.backward).
TORCH_LIBRARY(wkv5, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
funasr/models/sense_voice/cuda/wkv6_cuda.cu
New file
@@ -0,0 +1,243 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;
// typedef float bf16;
// WKV6 forward: like WKV5 but the decay is data dependent. Each time step
// loads _w[t] and uses exp(_w[t]) as the decay, so _w has the same (B, T, C)
// layout as r/k/v. Launch geometry: one block per (batch, head), one thread
// per channel of the head. Per step, thread i computes
//   y[i]     = sum_j r[j] * (u[j] * k[j] * v[i] + state[j])
//   state[j] = state[j] * exp(w[t, j]) + k[j] * v[i]
// NOTE(review): exp() is applied to the raw input, so callers presumably pass
// a log-space decay (e.g. negative values) -- confirm. The float4 casts assume
// _N_ % 4 == 0 and 16-byte-aligned arrays.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    // Only u is per-head and time-invariant; w is read per time step below.
    _u += h*_N_;
    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
    float state[_N_] = {0};
    __syncthreads();
    u[i] = float(_u[i]);
    __syncthreads();
    // t indexes element (b, time, h*_N_ + i); stepping by C advances one time step.
    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        // Previous step's shared reads must finish before w/r/k are overwritten.
        __syncthreads();
        w[i] = exp(_w[t]);
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();
        const float v = float(_v[t]);
        float y = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            // x = k[j] * v[i]
            float4 x;
            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;
            // Output reads the state before this step's decay/update.
            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);
            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
}
// WKV6 backward, part 1: gradients for r, k, v and u (the decay gradient is
// computed separately by kernel_backward_222). Launch geometry matches the
// forward kernel. Three sweeps:
//   1. forward in time  -> gr per step, gu summed over time (thread i = key channel)
//   2. backward in time -> gk per step (thread i = key channel)
//   3. backward in time -> gv per step (thread i = value channel)
// The decay at step t is exp(_w[t]), mirroring kernel_forward.
// NOTE(review): gu is written per (b, h, i) at b*C + h*_N_ + i, i.e. a per-batch
// partial -- presumably reduced over B by the caller; confirm.
template <typename F>
__global__ void kernel_backward_111(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
    F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _u += h*_N_;
    __shared__ float u_[_N_];
    __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
    __syncthreads();
    u_[i] = float(_u[i]);
    __syncthreads();
    const float u = u_[i];
    // Running accumulators for sweeps 1, 2 and 3 respectively.
    float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
    // First, last and one-past-last element of this thread's time series.
    const int t_0 = b*T*C + h*_N_ + i;
    const int t_T_1 = t_0 + (T-1)*C;
    const int t_T = t_0 + T*C;
    float gu = 0;
    // Sweep 1: replay the forward recurrence to get gr and accumulate gu.
    for (int t = t_0; t < t_T; t += C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();
        const float k = float(_k[t]);
        const float w = exp(_w[t]);
        float gr = 0, gu_ = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            float x = k * v[j];
            gr += (u * x + s) * gy[j];
            gu_ += x * gy[j];
            s = s * w + x;
        }
        _gr[t] = F(gr);
        gu += float(_r[t]) * gu_;
    }
    _gu[b*C + h*_N_ + i] = F(gu);
    // Sweep 2: walk time backwards, accumulating r*gy into scccc to get gk.
    for (int t = t_T_1; t >= t_0; t -= C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();
        const float rr = float(_r[t]);
        const float w = exp(_w[t]);
        float gk = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            float x = rr * gy[j];
            gk += (u * x + s) * v[j];
            s = x + s * w;
        }
        _gk[t] = F(gk);
    }
    // Sweep 3: walk time backwards as value channel i; r/k/decay of every key
    // channel come from shared memory.
    for (int t = t_T_1; t >= t_0; t -= C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        w_[i] = exp(_w[t]);
        __syncthreads();
        const float gyy = float(_gy[t]);
        float gv = 0;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = sdddd[j];
            float x = gyy * r[j];
            gv += (u_[j] * x + s) * k[j];
            s = x + s * w_[j];
        }
        _gv[t] = F(gv);
    }
}
// WKV6 backward, part 2: gradient of the per-time-step decay input _w.
// Launched right after kernel_backward_111 with the same geometry (one block per
// (batch, head), thread i handles channel h*_N_ + i).
// The first and last time steps get a zero gradient. Every other step is
// written as F(sss * _w[t]), i.e. scaled by the raw _w value.
// NOTE(review): the forward uses exp(_w[t]), so that scaling is presumably the
// chain-rule factor of the caller's parameterisation -- confirm.
// Constraints:
//   * sbbbb is a per-thread array of compile-time size _T_-2 indexed up to T-3,
//     so T must not exceed _T_ (there is no runtime check for this).
//   * T < 2 is handled by the early exit below.
template <typename F>
__global__ void kernel_backward_222(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
    F *__restrict__ const _gw)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    __shared__ float v[_N_], gy[_N_];
    float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};
    // Element indices of this thread's sequence at time 0, 1, 2 and T-1.
    const int t_0 = b*T*C + h*_N_ + i;
    const int t_1 = t_0 + C;
    const int t_2 = t_0 + 2*C;
    const int t_T_1 = t_0 + (T-1)*C;
    // With fewer than two steps the general path would read _w[t_1] and write
    // _gw[t_1], one time step past this sequence (out of bounds for the last
    // batch element). In the forward pass the decay of step t only affects
    // outputs after t, so a single step's decay gradient is 0. T is the same
    // for every thread of the block, so all threads exit together and no
    // __syncthreads() is reached divergently.
    if (T < 2)
    {
        if (T == 1)
            _gw[t_0] = 0;
        return;
    }
    // Backward sweep: sbbbb[(t - t_2)/C] collects, for each step t >= 2, the
    // decayed sum of future r*gy contracted with v and k at step t-2.
    for (int t = t_T_1; t > t_1; t -= C)
    {
        __syncthreads();
        gy[i] = float(_gy[t]);
        v[i] = float(_v[t-2*C]);
        __syncthreads();
        const float r = float(_r[t]);
        const float w = exp(_w[t-C]);
        float sum = 0.0f;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = saaaa[j];
            float x = r * gy[j];
            s = (s + x) * w;
            sum += s * v[j];
        }
        sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
    }
    // Forward sweep: running sum sss turns the per-step terms into the decay
    // gradient, removing the contributions that no longer apply past step t.
    float sss = sbbbb[0];
    _gw[t_0] = 0;
    _gw[t_1] = F(sss * _w[t_1]);
    for (int t = t_2; t < t_T_1; t += C)
    {
        __syncthreads();
        gy[i] = float(_gy[t]);
        v[i] = float(_v[t-2*C]);
        __syncthreads();
        const float w = exp(_w[t-C]);
        const float k = float(_k[t-2*C]);
        float sum = 0.0f;
        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            float x = k * v[j];
            s = (s + x) * w;
            sum += s * gy[j];
        }
        sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
        _gw[t] = F(sss * _w[t]);
    }
    _gw[t_T_1] = 0;
}
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    // Heads must tile the channel dimension exactly; the kernel reads four
    // floats at a time.
    assert(H*_N_ == C);
    assert(_N_%4 == 0);
    // One block per (batch, head) pair, one thread per channel of the head.
    const dim3 grid(B * H);
    const dim3 threads(_N_);
    kernel_forward<<<grid, threads>>>(B, T, C, H, r, k, v, w, u, y);
}
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
{
    // Same preconditions and launch geometry as cuda_forward.
    assert(H*_N_ == C);
    assert(_N_%4 == 0);
    const dim3 grid(B * H);
    const dim3 threads(_N_);
    // Two launches back to back: gr/gk/gv/gu first, then the decay gradient gw.
    kernel_backward_111<<<grid, threads>>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu);
    kernel_backward_222<<<grid, threads>>>(B, T, C, H, r, k, v, w, u, gy, gw);
}
funasr/models/sense_voice/cuda/wkv6_op.cpp
New file
@@ -0,0 +1,23 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
 typedef at::BFloat16 bf16;
//typedef float bf16;
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
}
// Python bindings for the extension module (loaded under TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv6 forward");
    m.def("backward", &backward, "wkv6 backward");
}
// Also registers the same two entry points as TorchScript-visible ops in the
// "wkv6" namespace (torch.ops.wkv6.forward / torch.ops.wkv6.backward).
TORCH_LIBRARY(wkv6, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
funasr/models/sense_voice/decoder.py
@@ -5,6 +5,32 @@
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` evaluated in float32, returning the input's dtype.

    Normalising in float32 keeps half-precision activations numerically stable.
    """

    def forward(self, x: Tensor) -> Tensor:
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
class Linear(nn.Linear):
    """``nn.Linear`` whose weight and bias are cast to the input's dtype per call."""

    def forward(self, x: Tensor) -> Tensor:
        target_dtype = x.dtype
        weight = self.weight.to(target_dtype)
        bias = self.bias.to(target_dtype) if self.bias is not None else None
        return F.linear(x, weight, bias)
def sense_voice_decode_forward(
    self,
@@ -38,10 +64,10 @@
    
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    tgt, memory = x, xa
    tgt[tgt==-1] = 0
    tgt[tgt == -1] = 0
    tgt = (
        self.token_embedding(tgt)
        + self.positional_embedding[offset : offset + tgt.size(1)]
        + self.positional_embedding[offset: offset + tgt.size(1)]
    )
    # tgt = self.dropout(tgt)
    
@@ -54,13 +80,248 @@
    
    for layer, block in enumerate(self.blocks):
        x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
    x = self.ln(x)
    x = (
        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    ).float()
    
    return x
class MultiHeadAttention(nn.Module):
    """Whisper-style multi-head attention usable for self- or cross-attention.

    Projections use the dtype-following ``Linear``; the key projection has no bias.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Attend from ``x`` to itself (``xa is None``) or to ``xa``.

        Args:
            x: query input.
            xa: optional cross-attention source.
            mask: additive mask, or a padding mask when ``is_pad_mask=True``
                (see ``qkv_attention``).
            kv_cache: optional dict keyed by the key/value projection modules.
                For cross-attention, cached tensors are reused when present.
            **kwargs: ``is_pad_mask`` (bool, default False).
        Returns:
            (output projection of the attended values, detached float32 logits).
        """
        is_pad_mask = kwargs.get("is_pad_mask", False)
        q = self.query(x)
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
    ):
        """Scaled dot-product attention over ``self.n_head`` heads.

        Args:
            q: (batch, n_ctx, n_state); k, v: (batch, n_kv, n_state).
            mask: with ``is_pad_mask=False``, an additive (>= n_ctx, >= n_ctx)
                mask sliced to ``[:n_ctx, :n_ctx]``. With ``is_pad_mask=True``,
                a (batch, 1, n_kv) mask whose zero/False entries are excluded
                from attention.
            **kwargs: ``is_pad_mask`` (bool, default False).
        Returns:
            (attended values, (batch, n_ctx, n_state)) and the detached float32
            attention logits.
        """
        is_pad_mask = kwargs.get("is_pad_mask", False)
        n_batch, n_ctx, n_state = q.shape
        # Scale q and k each by d_head ** -0.25, i.e. qk by d_head ** -0.5 overall.
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        qk = q @ k
        if mask is not None:
            if not is_pad_mask:
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                # torch.finfo handles every floating dtype, including bfloat16,
                # which the previous numpy-based lookup could not convert.
                min_value = torch.finfo(qk.dtype).min
                qk = qk.masked_fill(mask, min_value)
        qk = qk.float()
        w = F.softmax(qk, dim=-1).to(q.dtype)
        if mask is not None and is_pad_mask:
            # Force exact zeros on padded keys (rows that are fully padded would
            # otherwise get a uniform distribution).
            w = w.masked_fill(mask, 0.0)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
from omegaconf import OmegaConf
class ResidualAttentionBlockRWKV(nn.Module):
    """Decoder block: RWKV time-mixing layer, optional cross-attention, then an MLP.

    Args:
        n_state: width used by the cross-attention, its LayerNorm and the MLP.
        n_head: number of cross-attention heads.
        cross_attention: when True, add a ``MultiHeadAttention`` over ``xa``.
        layer_id: index of this block in the stack. It is passed to
            ``RWKVLayer`` and used for the LayerNorm weight initialisation.
        **kwargs: ``rwkv_cfg`` (mapping), converted with ``OmegaConf.create``
            and handed to ``RWKVLayer``.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, layer_id=0, **kwargs):
        super().__init__()
        rwkv_cfg = kwargs.get("rwkv_cfg", {})
        args = OmegaConf.create(rwkv_cfg)
        self.attn = RWKVLayer(args=args, layer_id=layer_id)
        # The RWKV layer is cast to bfloat16 unless datatype says otherwise.
        if args.get("datatype", "bf16") == "bf16":
            self.attn.to(torch.bfloat16)

        # A LayerNorm is created only when ``ln0`` / ``ln1`` is explicitly False
        # (both flags default to True, meaning "no extra norm").
        self.ln0 = None
        if layer_id == 0 and not args.get("ln0", True):
            self.ln0 = LayerNorm(args.n_embd)
            if args.get("init_rwkv", True):
                print("init_rwkv")
                # layer_id is 0 in this branch, so the scale is (1 / n_layer) ** 0.7.
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln0.weight, scale)
        self.layer_id = layer_id
        self.args = args
        self.ln1 = None
        if not args.get("ln1", True):
            self.ln1 = LayerNorm(args.n_embd)
            # init
            if args.get("init_rwkv", True):
                print("init_rwkv")
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln1.weight, scale)
        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Apply the block.

        Args:
            x: decoder hidden states.
            xa: encoder output (memory) for cross-attention.
            mask: self-attention mask forwarded to the RWKV layer.
            kv_cache: optional cache forwarded to both attention layers.
            **kwargs: ``is_pad_mask`` (self-attention), ``is_pad_memory_mask``
                and ``memory_mask`` (cross-attention padding mask, e.g.
                (batch, 1, n_frames) with True = valid frame).
        """
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        # Only a padding-style memory mask is meaningful for cross-attention;
        # passing it otherwise would be misread as an additive mask.
        memory_mask = kwargs.get("memory_mask", None) if is_pad_memory_mask else None

        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)
        attn_input = x if self.ln1 is None else self.ln1(x)
        x = x + self.attn(attn_input, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]

        if self.cross_attn:
            if memory_mask is not None and xa is not None and memory_mask.shape[-1] != xa.shape[1]:
                # Align the mask length with the memory length: extra frames are
                # marked invalid (False); a negative pad crops the mask.
                padlen = xa.shape[1] - memory_mask.shape[-1]
                memory_mask = F.pad(memory_mask, (0, padlen), "constant", False)
            x = x + self.cross_attn(
                self.cross_attn_ln(x),
                xa,
                mask=memory_mask,
                kv_cache=kv_cache,
                is_pad_mask=is_pad_memory_mask,
            )[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x
@tables.register("decoder_classes", "SenseVoiceDecoder")
class SenseVoiceDecoder(nn.Module):
    """Whisper-style text decoder built from ``ResidualAttentionBlockRWKV`` blocks.

    Args:
        n_vocab: vocabulary size. The output projection reuses the token
            embedding weights.
        n_ctx: maximum number of text positions (positional table and causal
            mask size).
        n_state: model width.
        n_head: cross-attention heads per block.
        n_layer: number of blocks.
        **kwargs: ``use_padmask`` (default True) plus everything forwarded to
            each block (e.g. ``rwkv_cfg``).
    """

    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks = nn.ModuleList(
            [
                ResidualAttentionBlockRWKV(n_state, n_head, cross_attention=True, layer_id=i, **kwargs)
                for i in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)
        # Additive causal mask: -inf strictly above the diagonal.
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
        self.use_padmask = kwargs.get("use_padmask", True)

    def forward(
        self,
        x: torch.Tensor,
        xa: torch.Tensor,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Decode token ids against the encoder output.

        Args:
            x: token ids, (batch, n_tokens). Entries equal to -1 are treated as
                padding and embedded as token 0; the caller's tensor is not modified.
            xa: encoder output (memory), presumably (batch, n_frames, feat) --
                confirm against the encoder.
            kv_cache: optional cache dict. When non-empty, dim 1 of its first
                entry is used as the positional offset.
            **kwargs: ``hlens`` -- encoder output lengths (batch,), used to build
                the memory padding mask when ``use_padmask`` is enabled.
        Returns:
            float32 logits, (batch, n_tokens, n_vocab).
        """
        hlens = kwargs.get("hlens", None)
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        memory = xa
        # Out-of-place replacement: the previous in-place assignment silently
        # rewrote the caller's token tensor.
        tokens = x.masked_fill(x == -1, 0)
        tgt = (
            self.token_embedding(tokens)
            + self.positional_embedding[offset: offset + tokens.size(1)]
        )
        hidden = tgt.to(memory.dtype)
        if self.use_padmask and hlens is not None:
            # (batch, 1, max(hlens)); True marks valid encoder frames.
            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
        else:
            memory_mask = None
        for block in self.blocks:
            hidden = block(hidden, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
        hidden = self.ln(hidden)
        logits = (
            hidden @ torch.transpose(self.token_embedding.weight.to(hidden.dtype), 0, 1)
        ).float()
        return logits
funasr/models/sense_voice/model.py
@@ -226,4 +226,213 @@
        results.append(result_i)
    
        return results, meta_data
@tables.register("model_classes", "SenseVoiceRWKV")
class SenseVoiceRWKV(nn.Module):
    """Whisper-based ASR model whose text decoder is a configurable class.

    The Whisper encoder's ``forward`` is replaced by
    ``sense_voice_encode_forward``. The Whisper decoder is deleted and
    rebuilt from ``tables.decoder_classes[kwargs["decoder"]]`` (default
    ``"SenseVoiceDecoder"``) using the Whisper text dimensions.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)

        # encoder
        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
        model.encoder.use_padmask = kwargs.get("use_padmask", True)
        from .encoder import sense_voice_encode_forward

        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)

        # decoder: drop Whisper's decoder and build the configured one with
        # the Whisper text dimensions.
        del model.decoder
        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            n_vocab=dims.n_vocab,
            n_ctx=dims.n_text_ctx,
            n_state=dims.n_text_state,
            n_head=dims.n_text_head,
            n_layer=dims.n_text_layer,
            # A missing/None `decoder_conf` would make `**` raise TypeError.
            **(kwargs.get("decoder_conf") or {}),
        )
        model.decoder = decoder

        self.model = model
        self.encoder_output_size = self.model.dims.n_audio_state

        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.ignore_id = kwargs.get("ignore_id", -1)
        self.vocab_size = kwargs.get("vocab_size", -1)
        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

        specaug = kwargs.get("specaug", None)
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**kwargs.get("specaug_conf", {}))
        self.specaug = specaug

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ):
        """Compute the attention-decoder training loss.

        Args:
            speech: input features, passed to ``encode``.
            speech_lengths: feature lengths. A 2-D tensor is reduced to its
                first column.
            text: token ids, used both as decoder input and as labels.
            text_lengths: token counts. A 2-D tensor is reduced to its first
                column.
            **kwargs: optional ``target_mask`` selecting which positions of
                ``text`` contribute to the loss (all positions when absent).

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``.
        """
        target_mask = kwargs.get("target_mask", None)

        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        if self.activation_checkpoint:
            from torch.utils.checkpoint import checkpoint

            encoder_out, encoder_out_lens = checkpoint(
                self.encode, speech, speech_lengths, use_reentrant=False
            )
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
        )
        loss = loss_att

        stats = {}
        stats["acc"] = acc_att
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ):
        """Run (optional) SpecAugment and the encoder.

        Args:
            speech: features; permuted with ``permute(0, 2, 1)`` before being
                handed to the Whisper encoder.
            speech_lengths: feature lengths.

        Returns:
            ``(encoder_out, encoder_out_lens)`` from the patched encoder.
        """
        with autocast(False):
            # Data augmentation (training only), outside mixed precision.
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Forward encoder
        encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)
        return encoder_out, encoder_out_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        **kwargs,
    ):
        """Teacher-forced decoder loss and accuracy.

        The decoder is fed ``ys_pad``; logits at step t are scored against
        token t+1. Positions where ``target_mask`` is 0 are replaced by -1.
        Token value 0 is also mapped to -1 (see NOTE below).

        Returns:
            ``(loss_att, acc_att, None, None)``.
        """
        target_mask = kwargs.get("target_mask", None)

        # 1. Forward decoder
        decoder_out = self.model.decoder(
            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
        )

        # 2. Compute attention loss
        if target_mask is None:
            # No mask supplied: every position is a training target
            # (previously `ys_pad * None` raised TypeError).
            target_mask = torch.ones_like(ys_pad)
        mask = torch.ones_like(ys_pad) * (-1)
        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
        # NOTE(review): id 0 is treated as ignore here, presumably because the
        # decoder embeds padding (-1) as 0 -- confirm token 0 is never a target.
        ys_pad_mask[ys_pad_mask == 0] = -1
        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])

        with torch.no_grad():
            preds = torch.argmax(decoder_out, -1)
            acc_att = compute_accuracy(
                preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
            )

        return loss_att, acc_att, None, None

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Decode one utterance with ``whisper.decode``.

        Args:
            data_in: fbank tensor (when ``data_type == "fbank"``) or anything
                accepted by ``load_audio_text_image_video``.
            data_lengths: feature lengths for fbank input (optional).
            key: utterance ids; ``key[0]`` labels the result.
            tokenizer: forwarded to the audio loader.
            frontend: feature extractor. Defaults to a cached
                ``WhisperFrontend``.
            **kwargs: ``device`` (required), ``DecodingOptions``,
                ``initial_prompt``, ``vocab_path``, ``fs``, ``data_type``, ...

        Returns:
            ``([{"key": key[0], "text": text}], meta_data)``.

        Raises:
            NotImplementedError: if ``batch_size > 1``.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # Must be a tensor: it is moved to the device below
                # (an int here previously raised AttributeError).
                speech_lengths = torch.full(
                    (speech.shape[0],), speech.shape[1], dtype=torch.int64
                )
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Copy so the caller's options mapping is not mutated across calls.
        DecodingOptions = dict(kwargs.get("DecodingOptions", {}))
        task = DecodingOptions.get("task", "ASR")
        if isinstance(task, str):
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
        DecodingOptions["initial_prompt"] = initial_prompt

        language = DecodingOptions.get("language", None)
        language = None if language == "auto" else language
        DecodingOptions["language"] = language

        DecodingOptions["vocab_path"] = kwargs.get("vocab_path", None)

        if "without_timestamps" not in DecodingOptions:
            DecodingOptions["without_timestamps"] = True

        options = whisper.DecodingOptions(**DecodingOptions)

        result = whisper.decode(self.model, speech, options)
        text = f"{result.text}"
        results = []
        result_i = {"key": key[0], "text": text}

        results.append(result_i)

        return results, meta_data
funasr/models/sense_voice/rwkv_v6.py
New file
@@ -0,0 +1,425 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, math, gc, importlib
import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
def __nop(ob):
    """Identity decorator: return *ob* unchanged (stand-in for ``script_method``)."""
    return ob


# TorchScript is opted into only when RWKV_JIT_ON is exactly "1"; otherwise
# plain nn.Module and a no-op decorator are used.
if os.environ.get("RWKV_JIT_ON") == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
else:
    MyModule = nn.Module
    MyFunction = __nop
########################################################################################################
# CUDA Kernel
########################################################################################################
# Handle to the JIT-compiled WKV6 extension; populated by load_rwkv_kernel().
wkv6_cuda = None
# (HEAD_SIZE, RWKV_CTXLEN) the loaded extension was compiled with.
_wkv6_build_config = None


def load_rwkv_kernel(HEAD_SIZE: int=64, RWKV_CTXLEN: int=512,):
    """Compile (once per process) and load the WKV6 CUDA extension.

    HEAD_SIZE and RWKV_CTXLEN are baked into the kernel as ``-D_N_`` /
    ``-D_T_``. Only the first call compiles; later calls reuse that build.
    If a later call asks for different sizes, a RuntimeWarning is emitted,
    because the already-loaded kernel keeps its original sizes.

    Args:
        HEAD_SIZE: per-head channel count compiled into the kernel.
        RWKV_CTXLEN: context length compiled into the kernel.
    """
    global wkv6_cuda, _wkv6_build_config
    requested = (HEAD_SIZE, RWKV_CTXLEN)
    if wkv6_cuda is not None:
        if _wkv6_build_config is not None and _wkv6_build_config != requested:
            import warnings

            warnings.warn(
                f"WKV6 kernel already loaded with HEAD_SIZE={_wkv6_build_config[0]}, "
                f"RWKV_CTXLEN={_wkv6_build_config[1]}; ignoring request for "
                f"HEAD_SIZE={HEAD_SIZE}, RWKV_CTXLEN={RWKV_CTXLEN}",
                RuntimeWarning,
            )
        return

    from torch.utils.cpp_extension import load

    absolute_file_path = os.path.abspath(__file__)
    cur_dir = os.path.dirname(absolute_file_path)
    wkv6_cuda = load(
        name="wkv6",
        sources=[f"{cur_dir}/cuda/wkv6_op.cpp", f"{cur_dir}/cuda/wkv6_cuda.cu"],
        verbose=True,
        extra_cuda_cflags=[
            "-res-usage",
            "--use_fast_math",
            "-O3",
            "-Xptxas -O3",
            "--extra-device-vectorization",
            f"-D_N_={HEAD_SIZE}",
            f"-D_T_={RWKV_CTXLEN}",
        ],
    )
    _wkv6_build_config = requested
# Storage dtype of the buffers the WKV6 kernel writes (output and gradients).
# dtype = torch.float
dtype = torch.bfloat16


class WKV_6(torch.autograd.Function):
    """Autograd bridge to the compiled ``wkv6_cuda`` forward/backward ops.

    Every output/gradient buffer is allocated contiguous with the module-level
    ``dtype``. The kernel is fed ``-exp(w)`` (computed in float32) rather
    than ``w`` itself.
    """

    @staticmethod
    def forward(ctx, B, T, C, H, r, k, v, w, u):
        with torch.no_grad():
            ctx.B, ctx.T, ctx.C, ctx.H = B, T, C, H
            # The CUDA kernel indexes raw memory, so inputs must be dense.
            for tensor in (r, k, v, w, u):
                assert tensor.is_contiguous()
            neg_exp_w = (-torch.exp(w.float())).contiguous()
            ctx.save_for_backward(r, k, v, neg_exp_w, u)
            out = torch.empty(
                (B, T, C),
                device=r.device,
                dtype=dtype,
                memory_format=torch.contiguous_format,
            )
            wkv6_cuda.forward(B, T, C, H, r, k, v, neg_exp_w, u, out)
            return out

    @staticmethod
    def backward(ctx, gy):
        with torch.no_grad():
            B, T, C, H = ctx.B, ctx.T, ctx.C, ctx.H
            assert gy.is_contiguous()
            r, k, v, ew, u = ctx.saved_tensors

            def _buffer(*shape):
                # Uninitialised gradient buffer; the kernel fills it.
                return torch.empty(
                    shape,
                    device=gy.device,
                    requires_grad=False,
                    dtype=dtype,
                    memory_format=torch.contiguous_format,
                )

            gr, gk, gv, gw = (_buffer(B, T, C) for _ in range(4))
            gu = _buffer(B, C)
            wkv6_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu)
            # u is shared across the batch: reduce and restore its (H, head_size) shape.
            gu = torch.sum(gu, 0).view(H, C // H)
            # No gradients for the integer size arguments B, T, C, H.
            return (None, None, None, None, gr, gk, gv, gw, gu)
def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
    """Apply the WKV6 recurrence through the autograd-aware ``WKV_6`` op."""
    kernel_args = (B, T, C, H, r, k, v, w, u)
    return WKV_6.apply(*kernel_args)
class RWKV_Tmix_x060(MyModule):
    """RWKV-6 ("x060") time-mix sub-layer.

    Implements a data-dependent token shift followed by the WKV6 recurrence.
    Reads ``head_size_a``, ``ctx_len``, ``dim_att``, ``n_embd``, ``n_layer``
    and ``head_size_divisor`` from ``args``. Construction loads (and on first
    use compiles) the WKV6 CUDA extension via ``load_rwkv_kernel``.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        load_rwkv_kernel(args.head_size_a, args.ctx_len)
        self.layer_id = layer_id

        self.head_size = args.head_size_a
        self.n_head = args.dim_att // self.head_size
        # dim_att must split into n_head heads of exactly head_size channels.
        # Checking `dim_att % n_head` would accept e.g. dim_att=100,
        # head_size=64 and only fail later at the time_faaaa reshape.
        assert args.dim_att % self.head_size == 0

        with torch.no_grad():
            # max(..., 1) keeps n_layer == 1 / dim_att == 1 from dividing by zero.
            ratio_0_to_1 = layer_id / max(args.n_layer - 1, 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            att_span = max(args.dim_att - 1, 1)
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            # fancy time_mix
            self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_v = nn.Parameter(1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1))
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
            self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))

            D_MIX_LORA = 32  # generate TIME_MIX for w,k,v,r,g
            self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
            self.time_maa_w2 = nn.Parameter(torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01))

            # fancy time_decay
            decay_speed = torch.ones(args.dim_att)
            for n in range(args.dim_att):
                decay_speed[n] = -6 + 5 * (n / att_span) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))

            D_DECAY_LORA = 64
            self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
            self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01))

            tmp = torch.zeros(args.dim_att)
            for n in range(args.dim_att):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / att_span)) + zigzag
            # Per-head "bonus" u passed to the WKV kernel, shape (n_head, head_size).
            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))

        # Shifts the sequence one step along dim -2 (zero-padded at the start).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
        self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor ** 2))

    @MyFunction
    def jit_func(self, x):
        """Project x (B, T, C) into r, k, v, gate g and decay w for the kernel."""
        B, T, C = x.size()

        xx = self.time_shift(x) - x

        # Low-rank, data-dependent mixing coefficients for w, k, v, r, g.
        xxx = x + xx * self.time_maa_x
        xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        xw = x + xx * (self.time_maa_w + mw)
        xk = x + xx * (self.time_maa_k + mk)
        xv = x + xx * (self.time_maa_v + mv)
        xr = x + xx * (self.time_maa_r + mr)
        xg = x + xx * (self.time_maa_g + mg)

        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))

        ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
        w = self.time_decay + ww

        return r, k, v, g, w

    @MyFunction
    def jit_func_2(self, x, g):
        """Group-normalize the kernel output per head, gate it, and project out."""
        B, T, C = x.size()
        x = x.view(B * T, C)

        x = self.ln_x(x).view(B, T, C)
        x = self.output(x * g)
        return x

    def forward(self, x):
        """Apply time mixing to x of shape (B, T, C); returns the same shape."""
        B, T, C = x.size()
        H = self.n_head

        r, k, v, g, w = self.jit_func(x)
        x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)

        return self.jit_func_2(x, g)
class RWKV_CMix_x060(MyModule):
    """RWKV-6 channel-mix (feed-forward) sub-layer.

    Computes ``sigmoid(receptance(xr)) * value(relu(key(xk)) ** 2)``, where
    xk / xr interpolate each token with its predecessor using learned
    per-channel weights.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Shifts the sequence one step along dim -2 (zero-padded at the start).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            # Depth-dependent exponent: 1 at the first layer, approaching 0 at the last.
            depth_exp = 1.0 - (layer_id / args.n_embd * args.n_embd / args.n_layer) if False else 1.0 - (layer_id / args.n_layer)
            channel_pos = torch.tensor(
                [i / args.n_embd for i in range(args.n_embd)]
            ).view(1, 1, args.n_embd)
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(channel_pos, depth_exp))
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(channel_pos, depth_exp))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        delta = self.time_shift(x) - x
        hidden = torch.relu(self.key(x + delta * self.time_maa_k)) ** 2
        gate = torch.sigmoid(self.receptance(x + delta * self.time_maa_r))
        return gate * self.value(hidden)
class Block(nn.Module):
    """Pre-LayerNorm RWKV-6 residual block: time-mix, then channel-mix.

    Layer 0 also normalizes its input with ``ln0``. When ``args.pre_ffn > 0``,
    layer 0 uses an extra channel-mix (``ffnPre``) in place of its time-mix.
    Dropout wrappers are applied when ``args.dropout`` is non-zero.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_Tmix_x060(args, layer_id)

        self.ffn = RWKV_CMix_x060(args, layer_id)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

        # `pre_ffn` is optional in the config. forward() previously read it
        # unconditionally and called an `ffnPre` that was never constructed.
        self._use_pre_ffn = self.layer_id == 0 and (getattr(args, "pre_ffn", 0) or 0) > 0
        if self._use_pre_ffn:
            self.ffnPre = RWKV_CMix_x060(args, 0)

    def forward(self, x, x_emb=None):
        """Apply the block to x; ``x_emb`` is accepted for API compatibility."""
        if self.layer_id == 0:
            x = self.ln0(x)

        mixer = self.ffnPre if self._use_pre_ffn else self.att
        if self.args.dropout == 0:
            x = x + mixer(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
        else:
            x = self.drop0(x + mixer(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))

        return x
class RWKVLayer(nn.Module):
    """Configurable RWKV-6 residual layer (time-mix + channel-mix).

    Options are read from ``args`` with ``args.get``:

    - ``ln0``: layer 0 only, default True.
    - ``ln1``: default True.
    - ``init_rwkv``: RWKV-style weight init, default True.
    - ``datatype``: default "bf16". With "bf16", activations are cast to
      bfloat16 inside the layer and returned as float32.

    If ``args.dim_ffn`` is None it is set in place to
    ``int((n_embd * 3.5) // 32 * 32)``.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        if args.dim_ffn is None:
            args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)

        width = args.n_embd
        self.ln0 = nn.LayerNorm(width) if (self.layer_id == 0 and args.get("ln0", True)) else None
        self.ln1 = nn.LayerNorm(width) if args.get("ln1", True) else None
        self.ln2 = nn.LayerNorm(width)

        self.att = RWKV_Tmix_x060(args, layer_id)
        self.ffn = RWKV_CMix_x060(args, layer_id)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

        if not args.get("init_rwkv", True):
            return

        # RWKV-style init; the call order is kept fixed because orthogonal_ consumes RNG.
        print("init_rwkv")
        orthogonal_targets = (
            (self.att.receptance.weight, 1),
            (self.att.key.weight, 0.1),
            (self.att.value.weight, 1),
            (self.att.gate.weight, 0.1),
        )
        for weight, gain in orthogonal_targets:
            nn.init.orthogonal_(weight, gain=gain)
        nn.init.zeros_(self.att.output.weight)

        nn.init.orthogonal_(self.ffn.key.weight, gain=1)
        nn.init.zeros_(self.ffn.value.weight)
        nn.init.zeros_(self.ffn.receptance.weight)

        scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
        for norm in (self.ln2, self.ln0, self.ln1):
            if norm is not None:
                nn.init.constant_(norm.weight, scale)

    def forward(self, x, x_emb=None, mask=None, **kwargs):
        """Apply the layer to x; ``x_emb`` and ``mask`` are accepted but unused."""
        cfg = self.args
        cast_bf16 = cfg.get("datatype", "bf16") == "bf16"
        if cast_bf16:
            x = x.bfloat16()
        _batch, _steps, _channels = x.size()

        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)

        att_in = x if self.ln1 is None else self.ln1(x)
        x = x + self.att(att_in)
        if cfg.dropout != 0:
            x = self.drop0(x)

        x = x + self.ffn(self.ln2(x))
        if cfg.dropout != 0:
            x = self.drop1(x)

        if cast_bf16:
            x = x.to(torch.float32)
        return x
class RWKV(nn.Module):
    """Stand-alone RWKV-6 language model.

    Pipeline: embedding -> ``Block`` stack -> LayerNorm -> vocabulary head,
    plus an optional copy-attention head when ``args.head_qk > 0``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        if not hasattr(args, 'dim_att'):
            args.dim_att = args.n_embd
        if not hasattr(args, 'dim_ffn'):
            # '-f4' in RWKV_MY_TESTING selects a 4x FFN; the variable may be unset.
            if '-f4' in os.environ.get("RWKV_MY_TESTING", ""):
                args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
            else:
                args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size
        if not hasattr(args, 'tiny_att_layer'):
            args.tiny_att_layer = -1
        if not hasattr(args, 'tiny_att_dim'):
            args.tiny_att_dim = -1
        # forward() reads these two; default them like the fields above.
        if not hasattr(args, 'head_qk'):
            args.head_qk = 0
        if not hasattr(args, 'grad_cp'):
            args.grad_cp = 0
        assert args.n_embd % 32 == 0
        assert args.dim_att % 32 == 0
        assert args.dim_ffn % 32 == 0

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.head_qk > 0:
            # Copy-attention head used by forward(); previously never built.
            self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False)
            self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)))

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)

    def forward(self, idx):
        """Return logits of shape (B, T, vocab_size) for token ids ``idx`` (B, T)."""
        args = self.args
        _, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        x_emb = x

        if args.dropout > 0:
            x = self.drop0(x)

        # Blocks receive x_emb only when tiny attention is configured.
        block_extra = (x_emb,) if args.tiny_att_dim > 0 else ()
        if args.grad_cp == 1:
            # Activation checkpointing. `deepspeed` was never imported here,
            # so use torch's equivalent.
            from torch.utils.checkpoint import checkpoint
        for block in self.blocks:
            if args.grad_cp == 1:
                x = checkpoint(block, x, *block_extra, use_reentrant=False)
            else:
                x = block(x, *block_extra)

        x = self.ln_out(x)

        if args.head_qk > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
            # one_hot yields int64; match c's dtype so the matmul is valid in
            # every precision (replaces the RWKV_FLOAT_MODE env lookups, whose
            # fp32 branch multiplied float by int64).
            c = c @ F.one_hot(idx, num_classes=args.vocab_size).to(c.dtype)
            x = self.head(x) + c
        else:
            x = self.head(x)

        return x
funasr/models/sense_voice/whisper_lib/model.py
@@ -261,7 +261,9 @@
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
        # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
        # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
        # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(