From 0a4a1d5257dace9561d95b38a9386539908dcd5e Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Tue, 23 Apr 2024 12:48:52 +0800
Subject: [PATCH] Dev gzf exp (#1645)
---
funasr/models/conformer_rwkv/decoder.py | 507 ++++++++++++
funasr/models/sense_voice/decoder.py | 273 ++++++
funasr/datasets/audio_datasets/espnet_samplers.py | 11
funasr/models/conformer_rwkv/model.py | 19
funasr/models/sense_voice/rwkv_v6.py | 425 ++++++++++
funasr/models/conformer_rwkv/__init__.py | 0
funasr/datasets/audio_datasets/index_ds.py | 202 +++-
funasr/models/sense_voice/model.py | 211 +++++
funasr/models/conformer_rwkv/template.yaml | 123 +++
funasr/models/sense_voice/cuda/wkv5_op.cpp | 22
funasr/models/sense_voice/cuda/wkv6_op.cpp | 23
funasr/bin/train.py | 5
funasr/models/sense_voice/cuda/wkv5_cuda.cu | 202 +++++
examples/aishell/conformer/conf/conformer_rwkv.yaml | 124 +++
funasr/models/sense_voice/whisper_lib/model.py | 4
funasr/models/sense_voice/cuda/wkv6_cuda.cu | 243 ++++++
 16 files changed, 2321 insertions(+), 73 deletions(-)
diff --git a/examples/aishell/conformer/conf/conformer_rwkv.yaml b/examples/aishell/conformer/conf/conformer_rwkv.yaml
new file mode 100644
index 0000000..4742838
--- /dev/null
+++ b/examples/aishell/conformer/conf/conformer_rwkv.yaml
@@ -0,0 +1,124 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: Conformer
+model_conf:
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+# encoder
+encoder: ConformerEncoder
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+# decoder
+decoder: TransformerRWKVDecoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+ input_layer: embed
+ rwkv_cfg:
+ n_embd: 256
+ dropout: 0
+ head_size_a: 64
+ ctx_len: 512
+ dim_att: 256 #${model_conf.rwkv_cfg.n_embd}
+ dim_ffn: null
+ head_size_divisor: 4
+ n_layer: 6
+ pre_ffn: 0
+ ln0: false
+ ln1: false
+ init_rwkv: true
+
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+specaug: SpecAug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 150
+ keep_nbest_models: 10
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: EspnetStyleBatchSampler
+ batch_type: length # example or length
+ batch_size: 25000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+ max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+ buffer_size: 1024
+ shuffle: True
+ num_workers: 4
+ preprocessor_speech: SpeechPreprocessSpeedPerturb
+ preprocessor_speech_conf:
+ speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+normalize: null
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 4ab2d8a..ab49c82 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -90,7 +90,8 @@
# freeze_param
freeze_param = kwargs.get("freeze_param", None)
if freeze_param is not None:
- freeze_param = eval(freeze_param)
+ if "," in freeze_param:
+ freeze_param = eval(freeze_param)
if isinstance(freeze_param, Sequence):
freeze_param = (freeze_param,)
logging.info("freeze_param is not None: %s", freeze_param)
@@ -104,7 +105,7 @@
if use_ddp:
model = model.cuda(local_rank)
model = DDP(model, device_ids=[local_rank],
- find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True))
+ find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
elif use_fsdp:
# model = FSDP(model).cuda(local_rank)
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index bca0753..4bb34f3 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -32,8 +32,9 @@
def __init__(self, dataset,
batch_size,
batch_type="token",
- num_replicas=None,
rank=None,
+ num_replicas=None,
+ rank_split=False,
shuffle=True,
drop_last=False,
is_training: bool = True,
@@ -45,6 +46,10 @@
rank = dist.get_rank()
num_replicas = dist.get_world_size()
except:
+ rank = 0
+ num_replicas = 1
+ if rank_split:
+ logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank")
rank = 0
num_replicas = 1
self.rank = rank
@@ -65,8 +70,8 @@
self.length_scale_source = kwargs.get("length_scale_source", 1.0)
- super().__init__(dataset, num_replicas=num_replicas, rank=rank,
- shuffle=shuffle, drop_last=drop_last)
+ # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
+ # shuffle=shuffle, drop_last=drop_last)
def __iter__(self):
if self.shuffle:
g = torch.Generator()
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 53419e8..3270531 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -9,66 +9,66 @@
from funasr.register import tables
-@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
-class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
-
- def __init__(self, path):
- super().__init__()
-
- contents = []
- with open(path, encoding='utf-8') as fin:
- for line in fin:
- data = json.loads(line.strip())
- if "text" in data: # for sft
- self.contents.append(data['text'])
- if "source" in data: # for speech lab pretrain
- prompt = data["prompt"]
- source = data["source"]
- target = data["target"]
- source_len = data["source_len"]
- target_len = data["target_len"]
-
- contents.append({"source": source,
- "prompt": prompt,
- "target": target,
- "source_len": source_len,
- "target_len": target_len,
- }
- )
-
- self.contents = []
- total_num = len(contents)
- try:
- rank = dist.get_rank()
- world_size = dist.get_world_size()
- except:
- rank = 0
- world_size = 1
- logging.warning("distributed is not initialized, only single shard")
- num_per_rank = total_num // world_size
-
- # rank = 0
- # import ipdb; ipdb.set_trace()
- self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
-
- logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
-
- def __len__(self):
- return len(self.contents)
-
- def __getitem__(self, index):
- try:
- data = self.contents[index]
- except:
- print(index)
- return data
-
- def get_source_len(self, data_dict):
- return data_dict["source_len"]
-
- def get_target_len(self, data_dict):
-
- return data_dict["target_len"] if "target_len" in data_dict else 0
+# @tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
+# class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
+#
+# def __init__(self, path):
+# super().__init__()
+#
+# contents = []
+# with open(path, encoding='utf-8') as fin:
+# for line in fin:
+# data = json.loads(line.strip())
+# if "text" in data: # for sft
+# self.contents.append(data['text'])
+# if "source" in data: # for speech lab pretrain
+# prompt = data["prompt"]
+# source = data["source"]
+# target = data["target"]
+# source_len = data["source_len"]
+# target_len = data["target_len"]
+#
+# contents.append({"source": source,
+# "prompt": prompt,
+# "target": target,
+# "source_len": source_len,
+# "target_len": target_len,
+# }
+# )
+#
+# self.contents = []
+# total_num = len(contents)
+# try:
+# rank = dist.get_rank()
+# world_size = dist.get_world_size()
+# except:
+# rank = 0
+# world_size = 1
+# logging.warning("distributed is not initialized, only single shard")
+# num_per_rank = total_num // world_size
+#
+# # rank = 0
+# # import ipdb; ipdb.set_trace()
+# self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
+#
+# logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
+#
+# def __len__(self):
+# return len(self.contents)
+#
+# def __getitem__(self, index):
+# try:
+# data = self.contents[index]
+# except:
+# print(index)
+# return data
+#
+# def get_source_len(self, data_dict):
+# return data_dict["source_len"]
+#
+# def get_target_len(self, data_dict):
+#
+# return data_dict["target_len"] if "target_len" in data_dict else 0
@tables.register("index_ds_classes", "IndexDSJsonl")
@tables.register("index_ds_classes", "IndexDSJsonlRankFull")
@@ -143,3 +143,85 @@
def get_target_len(self, data_dict):
return data_dict.get("target_len", 0)
+
+
+@tables.register("index_ds_classes", "IndexDSJsonlRankSplit")
+class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
+
+ def __init__(self, path: str, **kwargs):
+ super().__init__()
+ self.max_source_length = kwargs.get("max_source_length", 2048)
+ self.min_source_length = kwargs.get("min_source_length", 0)
+ self.max_target_length = kwargs.get("max_target_length", 2048)
+ self.min_target_length = kwargs.get("min_target_length", 0)
+
+ with open(path, encoding='utf-8') as fin:
+ file_list = fin.readlines()
+
+ total_num = len(file_list)
+ try:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ except:
+ rank = 0
+ world_size = 1
+ logging.warning("distributed is not initialized, only single shard")
+ num_per_rank = total_num // world_size
+ if num_per_rank * world_size < total_num:
+ logging.warning(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")
+
+ file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]
+
+ contents = []
+ for file_json in file_list_rank:
+
+ with open(file_json.strip(), encoding='utf-8') as fin:
+ for line in fin:
+ data = json.loads(line.strip())
+ if "text" in data: # for sft
+ contents.append(data['text'])
+ if "source" in data: # for speech lab pretrain
+ prompt = data.get("prompt", "<ASR>")
+ source = data["source"].replace("/cpfs01", "/cpfs_speech/data")
+ target = data["target"]
+ source_len = data.get("source_len", 1)
+ target_len = data.get("target_len", 0)
+
+ if source_len < self.min_source_length or source_len > self.max_source_length:
+ continue
+ if target_len < self.min_target_length or target_len > self.max_target_length:
+ continue
+ contents_i = {"source": source,
+ "prompt": prompt,
+ "target": target,
+ "source_len": source_len,
+ "target_len": target_len,
+ }
+ text_language = data.get("text_language", None)
+ if text_language is not None:
+ contents_i["text_language"] = text_language
+ audio_language = data.get("audio_language", None)
+ if audio_language is not None:
+ contents_i["audio_language"] = audio_language
+ contents.append(contents_i)
+
+ self.contents = contents
+
+ logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}")
+
+ def __len__(self):
+ return len(self.contents)
+
+ def __getitem__(self, index):
+ try:
+ data = self.contents[index]
+ except:
+ print(index)
+ return data
+
+ def get_source_len(self, data_dict):
+ return data_dict.get("source_len", 1)
+
+ def get_target_len(self, data_dict):
+
+ return data_dict.get("target_len", 0)
diff --git a/funasr/models/conformer_rwkv/__init__.py b/funasr/models/conformer_rwkv/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/conformer_rwkv/__init__.py
diff --git a/funasr/models/conformer_rwkv/decoder.py b/funasr/models/conformer_rwkv/decoder.py
new file mode 100644
index 0000000..d7f113d
--- /dev/null
+++ b/funasr/models/conformer_rwkv/decoder.py
@@ -0,0 +1,507 @@
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Decoder definition."""
+from typing import Any
+from typing import List
+from typing import Sequence
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
+from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.lightconv import LightweightConvolution
+from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
+from funasr.models.transformer.utils.mask import subsequent_mask
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.positionwise_feed_forward import (
+ PositionwiseFeedForward, # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
+from omegaconf import OmegaConf
+from funasr.register import tables
+
+class DecoderLayer(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ src_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ layer_id=None,
+ args={},
+ ):
+ """Construct an DecoderLayer object."""
+ super(DecoderLayer, self).__init__()
+ self.size = size
+ self.self_attn = self_attn.to(torch.bfloat16)
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ self.norm2 = LayerNorm(size)
+ self.norm3 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear1 = nn.Linear(size + size, size)
+ self.concat_linear2 = nn.Linear(size + size, size)
+ self.layer_id = layer_id
+ self.ln0 = None
+ if self.layer_id == 0 and not args.get("ln0", True):
+ self.ln0 = LayerNorm(args.n_embd)
+ if args.get("init_rwkv", True):
+ print("init_rwkv")
+ layer_id = 0
+ scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+ nn.init.constant_(self.ln0.weight, scale)
+
+ # init
+ if args.get("init_rwkv", True):
+ print("init_rwkv")
+ scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+ nn.init.constant_(self.norm1.weight, scale)
+ nn.init.constant_(self.self_attn.ln2.weight, scale)
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+
+ if self.layer_id == 0 and self.ln0 is not None:
+ tgt = self.ln0(tgt)
+
+ residual = tgt
+
+
+ tgt = self.norm1(tgt)
+
+ if cache is None:
+
+ x = residual + self.dropout(self.self_attn(tgt, mask=tgt_mask))
+ else:
+
+ # tgt_q = tgt[:, -1:, :]
+ # residual_q = residual[:, -1:, :]
+ tgt_q_mask = None
+
+ x = residual + self.dropout(self.self_attn(tgt, mask=tgt_q_mask))
+ x = x[:, -1, :]
+
+
+
+ # x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+
+
+ residual = x
+ x = self.norm2(x)
+ x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
+ residual = x
+ x = self.norm3(x)
+ x = residual + self.dropout(self.feed_forward(x))
+
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, memory, memory_mask
+
+
+
+class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
+ """Base class of Transfomer decoder module.
+
+ Args:
+ vocab_size: output dim
+ encoder_output_size: dimension of attention
+ attention_heads: the number of heads of multi head attention
+ linear_units: the number of units of position-wise feed forward
+ num_blocks: the number of decoder blocks
+ dropout_rate: dropout rate
+ self_attention_dropout_rate: dropout rate for attention
+ input_layer: input layer type
+ use_output_layer: whether to use output layer
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+ normalize_before: whether to use layer_norm before the first block
+ concat_after: whether to concat attention layer's input and output
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied.
+ i.e. x -> x + att(x)
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ ):
+ super().__init__()
+ attention_dim = encoder_output_size
+
+ if input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(vocab_size, attention_dim),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ elif input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(vocab_size, attention_dim),
+ torch.nn.LayerNorm(attention_dim),
+ torch.nn.Dropout(dropout_rate),
+ torch.nn.ReLU(),
+ pos_enc_class(attention_dim, positional_dropout_rate),
+ )
+ else:
+ raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+ self.normalize_before = normalize_before
+ if self.normalize_before:
+ self.after_norm = LayerNorm(attention_dim)
+ if use_output_layer:
+ self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+ else:
+ self.output_layer = None
+
+ # Must set by the inheritance
+ self.decoders = None
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+
+ Args:
+ hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
+ hlens: (batch)
+ ys_in_pad:
+ input token ids, int64 (batch, maxlen_out)
+ if input_layer == "embed"
+ input tensor (batch, maxlen_out, #mels) in the other cases
+ ys_in_lens: (batch)
+ Returns:
+ (tuple): tuple containing:
+
+ x: decoded token score before softmax (batch, maxlen_out, token)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ tgt = ys_in_pad
+ # tgt_mask: (B, 1, L)
+ tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+ # m: (1, L, L)
+ m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+ # tgt_mask: (B, L, L)
+ tgt_mask = tgt_mask & m
+
+ memory = hs_pad
+ memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
+ memory.device
+ )
+ # Padding for Longformer
+ if memory_mask.shape[-1] != memory.shape[1]:
+ padlen = memory.shape[1] - memory_mask.shape[-1]
+ memory_mask = torch.nn.functional.pad(
+ memory_mask, (0, padlen), "constant", False
+ )
+
+ x = self.embed(tgt)
+ x, tgt_mask, memory, memory_mask = self.decoders(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.output_layer is not None:
+ x = self.output_layer(x)
+
+ olens = tgt_mask.sum(1)
+ return x, olens
+
+ def forward_one_step(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ cache: List[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+
+ Args:
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ y.shape` is (batch, maxlen_out, token)
+ """
+ x = self.embed(tgt)
+ if cache is None:
+ cache = [None] * len(self.decoders)
+ new_cache = []
+ for c, decoder in zip(cache, self.decoders):
+ x, tgt_mask, memory, memory_mask = decoder(
+ x, tgt_mask, memory, None, cache=c
+ )
+ new_cache.append(x)
+
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.output_layer is not None:
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+ return y, new_cache
+
+ def score(self, ys, state, x):
+ """Score."""
+ ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
+ logp, state = self.forward_one_step(
+ ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
+ )
+ return logp.squeeze(0), state
+
+ def batch_score(
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+ ) -> Tuple[torch.Tensor, List[Any]]:
+ """Score new token batch.
+
+ Args:
+ ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+ states (List[Any]): Scorer states for prefix tokens.
+ xs (torch.Tensor):
+ The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+ Returns:
+ tuple[torch.Tensor, List[Any]]: Tuple of
+ batchfied scores for next token with shape of `(n_batch, n_vocab)`
+ and next state list for ys.
+
+ """
+ # merge states
+ n_batch = len(ys)
+ n_layers = len(self.decoders)
+ if states[0] is None:
+ batch_state = None
+ else:
+ # transpose state of [batch, layer] into [layer, batch]
+ batch_state = [
+ torch.stack([states[b][i] for b in range(n_batch)])
+ for i in range(n_layers)
+ ]
+
+ # batch decoding
+ ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
+ logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
+
+ # transpose state of [layer, batch] into [batch, layer]
+ state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+ return logp, state_list
+
+@tables.register("decoder_classes", "TransformerRWKVDecoder")
+class TransformerRWKVDecoder(BaseTransformerDecoder):
+ def __init__(
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ dropout_rate=dropout_rate,
+ positional_dropout_rate=positional_dropout_rate,
+ input_layer=input_layer,
+ use_output_layer=use_output_layer,
+ pos_enc_class=pos_enc_class,
+ normalize_before=normalize_before,
+ )
+ from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+ rwkv_cfg = kwargs.get("rwkv_cfg", {})
+ args = OmegaConf.create(rwkv_cfg)
+ # self.attn = RWKVLayer(args=args, layer_id=layer_id)
+ attention_dim = encoder_output_size
+ self.decoders = repeat(
+ num_blocks,
+ lambda lnum: DecoderLayer(
+ attention_dim,
+ RWKVLayer(args=args, layer_id=lnum),
+ MultiHeadedAttention(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ ),
+ PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ lnum,
+ args=args,
+ ),
+ )
+
+ # init
+ if args.get("init_rwkv", True):
+ print("init_rwkv")
+ nn.init.uniform_(self.embed[0].weight, a=-1e-4, b=1e-4)
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward decoder.
+
+ Args:
+ hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
+ hlens: (batch)
+ ys_in_pad:
+ input token ids, int64 (batch, maxlen_out)
+ if input_layer == "embed"
+ input tensor (batch, maxlen_out, #mels) in the other cases
+ ys_in_lens: (batch)
+ Returns:
+ (tuple): tuple containing:
+
+ x: decoded token score before softmax (batch, maxlen_out, token)
+ if use_output_layer is True,
+ olens: (batch, )
+ """
+ tgt = ys_in_pad
+ # tgt_mask: (B, 1, L)
+ tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+ # m: (1, L, L)
+ m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+ # tgt_mask: (B, L, L)
+ tgt_mask = tgt_mask & m
+
+ memory = hs_pad
+ memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
+ memory.device
+ )
+ # Padding for Longformer
+ if memory_mask.shape[-1] != memory.shape[1]:
+ padlen = memory.shape[1] - memory_mask.shape[-1]
+ memory_mask = torch.nn.functional.pad(
+ memory_mask, (0, padlen), "constant", False
+ )
+
+ x = self.embed(tgt)
+ x, tgt_mask, memory, memory_mask = self.decoders(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.normalize_before:
+ x = self.after_norm(x)
+ if self.output_layer is not None:
+ x = self.output_layer(x)
+
+ olens = tgt_mask.sum(1)
+ return x, olens
+
+ def forward_one_step(
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ cache: List[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+ """Forward one step.
+
+ Args:
+ tgt: input token ids, int64 (batch, maxlen_out)
+ tgt_mask: input token mask, (batch, maxlen_out)
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+ memory: encoded memory, float32 (batch, maxlen_in, feat)
+ cache: cached output list of (batch, max_time_out-1, size)
+ Returns:
+ y, cache: NN output value and cache per `self.decoders`.
+ y.shape` is (batch, maxlen_out, token)
+ """
+ x = self.embed(tgt)
+ if cache is None:
+ cache = [None] * len(self.decoders)
+ new_cache = []
+ for c, decoder in zip(cache, self.decoders):
+ x, tgt_mask, memory, memory_mask = decoder(
+ x, tgt_mask, memory, None, cache=c
+ )
+ new_cache.append(x)
+
+ if self.normalize_before:
+ y = self.after_norm(x[:, -1])
+ else:
+ y = x[:, -1]
+ if self.output_layer is not None:
+ y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+ return y, new_cache
\ No newline at end of file
diff --git a/funasr/models/conformer_rwkv/model.py b/funasr/models/conformer_rwkv/model.py
new file mode 100644
index 0000000..171014b
--- /dev/null
+++ b/funasr/models/conformer_rwkv/model.py
@@ -0,0 +1,19 @@
+import logging
+
+import torch
+
+from funasr.models.transformer.model import Transformer
+from funasr.register import tables
+
+@tables.register("model_classes", "Conformer")
+class Conformer(Transformer):
+ """CTC-attention hybrid Encoder-Decoder model"""
+
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+
+ super().__init__(*args, **kwargs)
diff --git a/funasr/models/conformer_rwkv/template.yaml b/funasr/models/conformer_rwkv/template.yaml
new file mode 100644
index 0000000..cd71105
--- /dev/null
+++ b/funasr/models/conformer_rwkv/template.yaml
@@ -0,0 +1,123 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: Conformer
+model_conf:
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+# encoder
+encoder: ConformerEncoder
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 2048 # the number of units of position-wise feed forward
+ num_blocks: 12 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.0
+ input_layer: conv2d # encoder architecture type
+ normalize_before: true
+ pos_enc_layer_type: rel_pos
+ selfattention_layer_type: rel_selfattn
+ activation_type: swish
+ macaron_style: true
+ use_cnn_module: true
+ cnn_module_kernel: 15
+
+# decoder
+decoder: TransformerRWKVDecoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+ input_layer: embed
+ rwkv_cfg:
+ n_embd: 256
+ dropout: 0
+ head_size_a: 64
+ ctx_len: 512
+ dim_att: 256 #${model_conf.rwkv_cfg.n_embd}
+ dim_ffn: null
+ head_size_divisor: 4
+ n_layer: 6
+ pre_ffn: 0
+ ln0: false
+ ln1: false
+
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+specaug: SpecAug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 150
+ keep_nbest_models: 10
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: EspnetStyleBatchSampler
+ batch_type: length # example or length
+ batch_size: 25000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+ max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+ buffer_size: 1024
+ shuffle: True
+ num_workers: 4
+ preprocessor_speech: SpeechPreprocessSpeedPerturb
+ preprocessor_speech_conf:
+ speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+normalize: null
diff --git a/funasr/models/sense_voice/cuda/wkv5_cuda.cu b/funasr/models/sense_voice/cuda/wkv5_cuda.cu
new file mode 100644
index 0000000..3e6b859
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv5_cuda.cu
@@ -0,0 +1,202 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+// RWKV-5 WKV forward kernel.
+// Launch shape (see cuda_forward below): grid = B*H blocks, block = _N_ threads.
+// Each block handles one (batch b, head h) pair; thread i owns channel i of the
+// head. r/k/v/y are (B, T, C) activations with C == H*_N_; _w and _u are
+// per-head, per-channel parameters. NOTE(review): _w is applied directly as the
+// decay multiplier (s = s*w + x), so it is presumably already exponentiated on
+// the host side — confirm against the Python wrapper.
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H,
+                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
+                               F *__restrict__ const _y)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    // Advance the per-head parameter pointers to this head's slice.
+    _w += h*_N_;
+    _u += h*_N_;
+
+    // r/k of the current timestep are staged in shared memory so every thread
+    // can read all _N_ channels of the head; w and u are loop-invariant.
+    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+    // Thread-private recurrent state: row i of the head's _N_ x _N_ kv-state.
+    float state[_N_] = {0};
+
+    __syncthreads();
+    w[i] = _w[i];
+    u[i] = float(_u[i]);
+    __syncthreads();
+
+    // t walks the flat (b, t, h, i) index across all T timesteps (stride C).
+    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float v = float(_v[t]);
+        float y = 0;
+
+        // Vectorized scan over the head's channels, 4 at a time
+        // (requires _N_ % 4 == 0, asserted in cuda_forward).
+        #pragma unroll
+        for (int j = 0; j < _N_; j+=4)
+        {
+            const float4& r_ = (float4&)(r[j]);
+            const float4& k_ = (float4&)(k[j]);
+            const float4& w_ = (float4&)(w[j]);
+            const float4& u_ = (float4&)(u[j]);
+            float4& s = (float4&)(state[j]);
+            float4 x;
+
+            // x = k_j * v_i: one column of the outer product k v^T.
+            x.x = k_.x * v;
+            x.y = k_.y * v;
+            x.z = k_.z * v;
+            x.w = k_.w * v;
+
+            // y += r . (u * x + state): the current token contributes through
+            // the "bonus" u, past tokens through the decayed state.
+            y += r_.x * (u_.x * x.x + s.x);
+            y += r_.y * (u_.y * x.y + s.y);
+            y += r_.z * (u_.z * x.z + s.z);
+            y += r_.w * (u_.w * x.w + s.w);
+
+            // Decay-and-accumulate the state for the next timestep.
+            s.x = s.x * w_.x + x.x;
+            s.y = s.y * w_.y + x.y;
+            s.z = s.z * w_.z + x.z;
+            s.w = s.w * w_.w + x.w;
+        }
+        _y[t] = F(y);
+    }
+}
+
+// RWKV-5 WKV backward kernel: given the upstream gradient _gy, produces
+// gradients for r (_gr), k (_gk), v (_gv), the per-head decay (_gw) and
+// bonus (_gu). Same launch shape as the forward kernel: one block per
+// (batch, head), one thread per channel. __w is a second decay-related
+// per-head tensor used only to scale the w-gradient; NOTE(review): its exact
+// parameterization (chain-rule factor) lives in the Python wrapper — confirm.
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C, const int H,
+                                const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
+                                F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    // Per-head parameter slices.
+    _w += h*_N_;
+    _u += h*_N_;
+    __w += h*_N_;
+
+    __shared__ float w_[_N_], u_[_N_];
+    __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
+    __syncthreads();
+    w_[i] = _w[i];
+    u_[i] = float(_u[i]);
+    __syncthreads();
+
+    // Channel-i copies of the per-head parameters.
+    const float w = w_[i];
+    const float ww = __w[i];
+    const float u = u_[i];
+
+    // Independent running states, one per gradient scan below.
+    float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
+
+    float gw = 0, gu = 0;
+    // Flat index bounds: first timestep, one-past-last, and last-minus-two.
+    const int t000 = b*T*C + h*_N_ + i;
+    const int t111 = (b+1)*T*C + h*_N_ + i;
+    const int t222 = t111 - 2*C;
+
+    // Pass 1 (forward in time): per-timestep d/dr, and accumulate d/du.
+    for (int t = t000; t < t111; t += C)
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        // NOTE: this local scalar k shadows the shared array k[] above.
+        const float k = float(_k[t]);
+        float gr = 0, gu_ = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = state[j];
+            float x = k * v[j];
+
+            gr += (u * x + s) * gy[j];
+            gu_ += x * gy[j];
+            s = s * w + x;
+        }
+        _gr[t] = F(gr);
+        gu += float(_r[t]) * gu_;
+    }
+    _gu[b*C + h*_N_ + i] = F(gu);
+
+    // Pass 2 (forward in time, stops two steps early): accumulate d/dw.
+    // saaaa carries the once-decayed state, sbbbb its w-weighted running sum;
+    // gy and r are read two timesteps ahead because the decay only affects
+    // outputs at later timesteps.
+    for (int t = t000; t < t222; t += C)
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t + 2*C]);
+        __syncthreads();
+
+        const float k = float(_k[t]);
+        float gw_ = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = saaaa[j];
+            float& s2 = sbbbb[j];
+            float x = k * v[j];
+
+            float tmp = w * (x + s);
+            s = tmp;
+            s2 = tmp + w * s2;
+            gw_ += s2 * gy[j];
+        }
+        gw += float(_r[t + 2*C]) * gw_;
+    }
+    // Scale by the auxiliary factor ww before writing the decay gradient.
+    _gw[b*C + h*_N_ + i] = F(ww * gw);
+
+    // Pass 3 (backward in time): d/dk via the time-reversed state scccc.
+    for (int t = t111 - C; t >= t000; t -= C)
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        const float rr = float(_r[t]);
+        float gk = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = scccc[j];
+            float x = rr * gy[j];
+
+            gk += (u * x + s) * v[j];
+            s = x + s * w;
+        }
+        _gk[t] = F(gk);
+    }
+
+    // Pass 4 (backward in time): d/dv. Per-channel w_[j]/u_[j] are needed
+    // here because the reduction runs across channels j, not within channel i.
+    for (int t = t111 - C; t >= t000; t -= C)
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float gyy = float(_gy[t]);
+        float gv = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = sdddd[j];
+            float x = gyy * r[j];
+
+            gv += (u_[j] * x + s) * k[j];
+            s = x + s * w_[j];
+        }
+        _gv[t] = F(gv);
+    }
+}
+
+// Host-side launcher for the wkv5 forward kernel: one block per (batch, head),
+// _N_ threads per block. The asserts pin the compile-time contract:
+// C == H*_N_ and _N_ divisible by 4 (for the float4 vectorization).
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
+{
+    assert(H*_N_ == C);
+    assert(_N_%4 == 0);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
+}
+
+// Host-side launcher for the wkv5 backward kernel; same launch shape and
+// compile-time contract as cuda_forward above.
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
+{
+    assert(H*_N_ == C);
+    assert(_N_%4 == 0);
+    kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
+}
diff --git a/funasr/models/sense_voice/cuda/wkv5_op.cpp b/funasr/models/sense_voice/cuda/wkv5_op.cpp
new file mode 100644
index 0000000..4c9ece1
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv5_op.cpp
@@ -0,0 +1,22 @@
+// PyTorch C++ binding for the RWKV-5 WKV CUDA kernels (wkv5_cuda.cu).
+// Exposes forward/backward both as a pybind11 extension module and via
+// TORCH_LIBRARY, so torch.ops.wkv5.* also works.
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
+
+// Thin unwrappers: pull raw data pointers out of the tensors and dispatch to
+// the CUDA launchers. No shape/contiguity/device checks are done here — the
+// Python wrapper is responsible (r/k/v/u and all gradients are bf16; w/ww float32).
+void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
+    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
+}
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv5 forward");
+    m.def("backward", &backward, "wkv5 backward");
+}
+
+TORCH_LIBRARY(wkv5, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
diff --git a/funasr/models/sense_voice/cuda/wkv6_cuda.cu b/funasr/models/sense_voice/cuda/wkv6_cuda.cu
new file mode 100644
index 0000000..d98f57f
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv6_cuda.cu
@@ -0,0 +1,243 @@
+#include <stdio.h>
+#include <assert.h>
+#include "ATen/ATen.h"
+typedef at::BFloat16 bf16;
+// typedef float bf16;
+
+// RWKV-6 WKV forward kernel. Structure matches the v5 kernel, with one key
+// difference: the decay is data-dependent — _w is a per-timestep log-decay of
+// shape (B, T, C) and is exponentiated here at every step (w[i] = exp(_w[t])),
+// so only _u needs the per-head pointer offset.
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C, const int H,
+                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
+                               F *__restrict__ const _y)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    _u += h*_N_;
+
+    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
+    // Thread-private recurrent state: row i of the head's _N_ x _N_ kv-state.
+    float state[_N_] = {0};
+
+    __syncthreads();
+    u[i] = float(_u[i]);
+    __syncthreads();
+
+    // t walks the flat (b, t, h, i) index across all T timesteps (stride C).
+    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
+    {
+        __syncthreads();
+        // v6: decay is refreshed from the per-timestep log-decay input.
+        w[i] = exp(_w[t]);
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        __syncthreads();
+
+        const float v = float(_v[t]);
+        float y = 0;
+
+        // Vectorized scan over the head's channels, 4 at a time
+        // (requires _N_ % 4 == 0, asserted in cuda_forward).
+        #pragma unroll
+        for (int j = 0; j < _N_; j+=4)
+        {
+            const float4& r_ = (float4&)(r[j]);
+            const float4& k_ = (float4&)(k[j]);
+            const float4& w_ = (float4&)(w[j]);
+            const float4& u_ = (float4&)(u[j]);
+            float4& s = (float4&)(state[j]);
+            float4 x;
+
+            // x = k_j * v_i: one column of the outer product k v^T.
+            x.x = k_.x * v;
+            x.y = k_.y * v;
+            x.z = k_.z * v;
+            x.w = k_.w * v;
+
+            // y += r . (u * x + state).
+            y += r_.x * (u_.x * x.x + s.x);
+            y += r_.y * (u_.y * x.y + s.y);
+            y += r_.z * (u_.z * x.z + s.z);
+            y += r_.w * (u_.w * x.w + s.w);
+
+            // Decay-and-accumulate the state for the next timestep.
+            s.x = s.x * w_.x + x.x;
+            s.y = s.y * w_.y + x.y;
+            s.z = s.z * w_.z + x.z;
+            s.w = s.w * w_.w + x.w;
+        }
+        _y[t] = F(y);
+    }
+}
+
+// RWKV-6 backward, part 1 of 2: gradients for r (_gr), k (_gk), v (_gv) and
+// the bonus u (_gu). The gradient for the time-varying decay w is handled
+// separately by kernel_backward_222. One block per (batch, head), one thread
+// per channel, as in the forward kernel.
+template <typename F>
+__global__ void kernel_backward_111(const int B, const int T, const int C, const int H,
+                                    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
+                                    F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+    _u += h*_N_;
+
+    __shared__ float u_[_N_];
+    __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
+    __syncthreads();
+    u_[i] = float(_u[i]);
+    __syncthreads();
+
+    const float u = u_[i];
+
+    // One running state per scan: forward (state), and two reverse (scccc, sdddd).
+    float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
+
+    // Flat index bounds: first timestep, last timestep, one-past-last.
+    const int t_0 = b*T*C + h*_N_ + i;
+    const int t_T_1 = t_0 + (T-1)*C;
+    const int t_T = t_0 + T*C;
+
+    // Pass 1 (forward in time): per-timestep d/dr, and accumulate d/du.
+    float gu = 0;
+    for (int t = t_0; t < t_T; t += C)
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        // NOTE: this local scalar k shadows the shared array k[] above.
+        const float k = float(_k[t]);
+        const float w = exp(_w[t]);
+        float gr = 0, gu_ = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = state[j];
+            float x = k * v[j];
+
+            gr += (u * x + s) * gy[j];
+            gu_ += x * gy[j];
+            s = s * w + x;
+        }
+        _gr[t] = F(gr);
+        gu += float(_r[t]) * gu_;
+    }
+    _gu[b*C + h*_N_ + i] = F(gu);
+
+    // Pass 2 (backward in time): d/dk via the time-reversed state scccc.
+    for (int t = t_T_1; t >= t_0; t -= C)
+    {
+        __syncthreads();
+        v[i] = float(_v[t]);
+        gy[i] = float(_gy[t]);
+        __syncthreads();
+
+        const float rr = float(_r[t]);
+        const float w = exp(_w[t]);
+        float gk = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = scccc[j];
+            float x = rr * gy[j];
+
+            gk += (u * x + s) * v[j];
+            s = x + s * w;
+        }
+        _gk[t] = F(gk);
+    }
+
+    // Pass 3 (backward in time): d/dv. Per-channel u_[j]/w_[j] are needed
+    // here because the reduction runs across channels j, not within channel i.
+    for (int t = t_T_1; t >= t_0; t -= C)
+    {
+        __syncthreads();
+        r[i] = float(_r[t]);
+        k[i] = float(_k[t]);
+        w_[i] = exp(_w[t]);
+        __syncthreads();
+
+        const float gyy = float(_gy[t]);
+        float gv = 0;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = sdddd[j];
+            float x = gyy * r[j];
+
+            gv += (u_[j] * x + s) * k[j];
+            s = x + s * w_[j];
+        }
+        _gv[t] = F(gv);
+    }
+}
+
+// RWKV-6 backward, part 2 of 2: gradient for the per-timestep log-decay _w.
+// Two-phase scheme: a reverse sweep fills sbbbb[] with per-timestep partial
+// sums, then a forward sweep turns them into a running total sss and writes
+// _gw. The first and last timesteps get zero gradient (w there never affects
+// any output). NOTE(review): sbbbb is a thread-local array of _T_-2 floats,
+// where _T_ is the compile-time RWKV_CTXLEN macro — nothing here asserts
+// T <= _T_, so the caller must guarantee it; confirm in the Python wrapper.
+// NOTE(review): the final multiply by _w[t] is a chain-rule factor tied to the
+// host-side parameterization of the decay — confirm there as well.
+template <typename F>
+__global__ void kernel_backward_222(const int B, const int T, const int C, const int H,
+                                    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
+                                    F *__restrict__ const _gw)
+{
+    const int b = blockIdx.x / H;
+    const int h = blockIdx.x % H;
+    const int i = threadIdx.x;
+
+    __shared__ float v[_N_], gy[_N_];
+    float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};
+
+    // Flat indices of timesteps 0, 1, 2 and T-1 for this (b, h, i).
+    const int t_0 = b*T*C + h*_N_ + i;
+    const int t_1 = t_0 + C;
+    const int t_2 = t_0 + 2*C;
+    const int t_T_1 = t_0 + (T-1)*C;
+
+    // Phase 1 (backward in time): for each t, accumulate the decayed influence
+    // of r[t]*gy[t] on earlier states and stash the k-weighted sum in sbbbb.
+    for (int t = t_T_1; t > t_1; t -= C)
+    {
+        __syncthreads();
+        gy[i] = float(_gy[t]);
+        v[i] = float(_v[t-2*C]);
+        __syncthreads();
+
+        const float r = float(_r[t]);
+        const float w = exp(_w[t-C]);
+        float sum = 0.0f;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = saaaa[j];
+            float x = r * gy[j];
+            s = (s + x) * w;
+            sum += s * v[j];
+        }
+        sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
+    }
+
+    // Boundary conditions: no gradient flows to w at the first and last step.
+    float sss = sbbbb[0];
+    _gw[t_0] = 0;
+    _gw[t_1] = F(sss * _w[t_1]);
+
+    // Phase 2 (forward in time): maintain the running total sss, subtracting
+    // the contributions already "consumed" by earlier r[t] terms.
+    for (int t = t_2; t < t_T_1; t += C)
+    {
+        __syncthreads();
+        gy[i] = float(_gy[t]);
+        v[i] = float(_v[t-2*C]);
+        __syncthreads();
+
+        const float w = exp(_w[t-C]);
+        const float k = float(_k[t-2*C]);
+        float sum = 0.0f;
+
+        #pragma unroll
+        for (int j = 0; j < _N_; j++)
+        {
+            float& s = scccc[j];
+            float x = k * v[j];
+            s = (s + x) * w;
+            sum += s * gy[j];
+        }
+        sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
+        _gw[t] = F(sss * _w[t]);
+    }
+    _gw[t_T_1] = 0;
+}
+
+// Host-side launcher for the wkv6 forward kernel: one block per (batch, head),
+// _N_ threads per block; asserts the compile-time contract C == H*_N_ and
+// _N_ % 4 == 0 (float4 vectorization).
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
+{
+    assert(H*_N_ == C);
+    assert(_N_%4 == 0);
+    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
+}
+
+// Host-side launcher for the wkv6 backward pass: two sequential kernels —
+// _111 for the r/k/v/u gradients, _222 for the time-varying decay gradient.
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
+{
+    assert(H*_N_ == C);
+    assert(_N_%4 == 0);
+    kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu);
+    kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gw);
+}
diff --git a/funasr/models/sense_voice/cuda/wkv6_op.cpp b/funasr/models/sense_voice/cuda/wkv6_op.cpp
new file mode 100644
index 0000000..22da520
--- /dev/null
+++ b/funasr/models/sense_voice/cuda/wkv6_op.cpp
@@ -0,0 +1,23 @@
+// PyTorch C++ binding for the RWKV-6 WKV CUDA kernels (wkv6_cuda.cu); loaded
+// at runtime by load_rwkv_kernel() in rwkv_v6.py. Exposed both as a pybind11
+// module and via TORCH_LIBRARY (torch.ops.wkv6.*).
+#include <torch/extension.h>
+#include "ATen/ATen.h"
+ typedef at::BFloat16 bf16;
+//typedef float bf16;
+
+void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
+void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
+
+// Thin unwrappers: pull raw data pointers out of the tensors and dispatch to
+// the CUDA launchers. No shape/contiguity/device checks are done here — the
+// Python autograd wrapper is responsible (w is float32; everything else bf16).
+void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
+    cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
+}
+void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
+    cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
+}
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv6 forward");
+    m.def("backward", &backward, "wkv6 backward");
+}
+
+TORCH_LIBRARY(wkv6, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index bae2832..9087ea1 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -5,6 +5,32 @@
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.register import tables
+import base64
+import gzip
+from dataclasses import dataclass
+from typing import Dict, Iterable, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class LayerNorm(nn.LayerNorm):
+    """LayerNorm that normalizes in float32 (for numerical stability under
+    fp16/bf16 activations) and casts the result back to the input dtype."""
+
+    def forward(self, x: Tensor) -> Tensor:
+        return super().forward(x.float()).type(x.dtype)
+
+
+class Linear(nn.Linear):
+    """Linear layer that casts its weight (and bias, if any) to the input's
+    dtype at call time, so mixed-precision inputs work without pre-casting."""
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.linear(
+            x,
+            self.weight.to(x.dtype),
+            None if self.bias is None else self.bias.to(x.dtype),
+        )
+
+
def sense_voice_decode_forward(
self,
@@ -38,10 +64,10 @@
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
tgt, memory = x, xa
- tgt[tgt==-1] = 0
+ tgt[tgt == -1] = 0
tgt = (
self.token_embedding(tgt)
- + self.positional_embedding[offset : offset + tgt.size(1)]
+ + self.positional_embedding[offset: offset + tgt.size(1)]
)
# tgt = self.dropout(tgt)
@@ -54,13 +80,248 @@
for layer, block in enumerate(self.blocks):
x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
-
-
+
x = self.ln(x)
x = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
-
return x
-
\ No newline at end of file
+
+
+class MultiHeadAttention(nn.Module):
+    """Whisper-style multi-head attention with an optional kv-cache and an
+    optional padding-mask mode (is_pad_mask) on top of the causal-mask mode."""
+
+    def __init__(self, n_state: int, n_head: int):
+        super().__init__()
+        self.n_head = n_head
+        self.query = Linear(n_state, n_state)
+        self.key = Linear(n_state, n_state, bias=False)
+        self.value = Linear(n_state, n_state)
+        self.out = Linear(n_state, n_state)
+
+    def forward(
+        self,
+        x: Tensor,
+        xa: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[dict] = None,
+        **kwargs,
+    ):
+        """Self-attention when xa is None, cross-attention onto xa otherwise.
+
+        Returns a tuple (output, detached qk logits)."""
+        is_pad_mask = kwargs.get("is_pad_mask", False)
+
+        q = self.query(x)
+
+        if kv_cache is None or xa is None or self.key not in kv_cache:
+            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+            # otherwise, perform key/value projections for self- or cross-attention as usual.
+            k = self.key(x if xa is None else xa)
+            v = self.value(x if xa is None else xa)
+        else:
+            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
+
+        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
+        return self.out(wv), qk
+
+    def qkv_attention(
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
+    ):
+        """Scaled dot-product attention over per-head views of q/k/v.
+
+        mask semantics depend on is_pad_mask: False -> additive causal mask
+        (sliced to the current context); True -> boolean padding mask whose
+        zero entries are filled with dtype-min before the softmax."""
+        is_pad_mask = kwargs.get("is_pad_mask", False)
+        n_batch, n_ctx, n_state = q.shape
+        # Split the sqrt(d_k) scaling between q and k: (d_k)^-0.25 each.
+        scale = (n_state // self.n_head) ** -0.25
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+        qk = q @ k
+        if mask is not None:
+            if not is_pad_mask:
+                qk = qk + mask[:n_ctx, :n_ctx]
+            else:
+                mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+                min_value = float(
+                    np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
+                )
+                qk = qk.masked_fill(mask, min_value)
+
+        # Softmax in float32, then cast back to the query dtype.
+        qk = qk.float()
+
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        if mask is not None and is_pad_mask:
+            # Zero out any residual weight on padded positions.
+            w = w.masked_fill(mask, 0.0)
+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+
+
+from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+from omegaconf import OmegaConf
+
+
+class ResidualAttentionBlockRWKV(nn.Module):
+    """Decoder block where self-attention is replaced by an RWKV-v6 layer,
+    with optional Whisper-style cross-attention and a 4x GELU MLP.
+
+    rwkv_cfg (from kwargs) configures the RWKV layer; ln0/ln1 control extra
+    LayerNorms around the RWKV path.
+    """
+
+    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, layer_id=0, **kwargs):
+        super().__init__()
+
+        rwkv_cfg = kwargs.get("rwkv_cfg", {})
+        args = OmegaConf.create(rwkv_cfg)
+        self.attn = RWKVLayer(args=args, layer_id=layer_id)
+        if args.get("datatype", "bf16") == "bf16":
+            self.attn.to(torch.bfloat16)
+
+        # NOTE(review): the guards below create the LayerNorms when the config
+        # flag is *false* (`not args.get("ln0", True)`); the template yaml sets
+        # ln0/ln1 to false, so the norms are built — confirm the inversion is
+        # intentional.
+        self.ln0 = None
+        if layer_id == 0 and not args.get("ln0", True):
+            self.ln0 = LayerNorm(args.n_embd)
+            if args.get("init_rwkv", True):
+                print("init_rwkv")
+                # NOTE(review): this reassignment is redundant — the guard
+                # above already ensures layer_id == 0 here.
+                layer_id = 0
+                # Depth-dependent init scale for the norm weight.
+                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+                nn.init.constant_(self.ln0.weight, scale)
+        self.layer_id = layer_id
+        self.args = args
+
+        self.ln1 = None
+        if not args.get("ln1", True):
+            self.ln1 = LayerNorm(args.n_embd)
+            # init
+            if args.get("init_rwkv", True):
+                print("init_rwkv")
+                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+                nn.init.constant_(self.ln1.weight, scale)
+
+        self.cross_attn = (
+            MultiHeadAttention(n_state, n_head) if cross_attention else None
+        )
+        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+        n_mlp = n_state * 4
+        self.mlp = nn.Sequential(
+            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+        )
+        self.mlp_ln = LayerNorm(n_state)
+
+    def forward(
+        self,
+        x: Tensor,
+        xa: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[dict] = None,
+        **kwargs,
+    ):
+        """Residual RWKV time-mix, optional cross-attention on xa, then MLP.
+
+        NOTE(review): a `memory_mask` kwarg passed by SenseVoiceDecoder lands
+        in **kwargs and is never forwarded to cross_attn (mask stays None),
+        so cross-attention runs without the encoder padding mask — confirm
+        this is intended.
+        """
+        is_pad_mask = kwargs.get("is_pad_mask", False)
+        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
+
+        # Optional pre-norm applied only on the first layer.
+        if self.layer_id == 0 and self.ln0 is not None:
+            x = self.ln0(x)
+
+        if self.ln1 is None:
+            x = x + self.attn(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+        else:
+            x = x + self.attn(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+
+        if self.cross_attn:
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
+        x = x + self.mlp(self.mlp_ln(x))
+
+        return x
+
+@tables.register("decoder_classes", "SenseVoiceDecoder")
+class SenseVoiceDecoder(nn.Module):
+    """Text decoder: token + positional embeddings, a stack of
+    ResidualAttentionBlockRWKV layers with cross-attention onto the encoder
+    output, and a tied-embedding output projection."""
+
+    def __init__(
+        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs
+    ):
+        super().__init__()
+
+        self.token_embedding = nn.Embedding(n_vocab, n_state)
+        # NOTE(review): allocated with torch.empty and never initialized here —
+        # presumably overwritten by a checkpoint load; confirm.
+        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+
+        self.blocks = nn.ModuleList(
+            [
+                ResidualAttentionBlockRWKV(n_state, n_head, cross_attention=True, layer_id=i, **kwargs)
+                for i in range(n_layer)
+            ]
+        )
+        self.ln = LayerNorm(n_state)
+
+        # Causal mask: -inf above the diagonal, added to qk logits.
+        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+        self.register_buffer("mask", mask, persistent=False)
+
+        self.use_padmask = kwargs.get("use_padmask", True)
+    # def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+    #     """
+    #     x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+    #         the text tokens
+    #     xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+    #         the encoded audio features to be attended on
+    #     """
+    #     offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+    #     x = (
+    #             self.token_embedding(x)
+    #             + self.positional_embedding[offset: offset + x.shape[-1]]
+    #     )
+    #     x = x.to(xa.dtype)
+    #
+    #     for block in self.blocks:
+    #         x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+    #
+    #     x = self.ln(x)
+    #     logits = (
+    #             x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+    #     ).float()
+    #
+    #     return logits
+
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        xa: torch.Tensor,
+        kv_cache: Optional[dict] = None,
+        **kwargs,
+    ):
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        # import pdb;pdb.set_trace()
+        use_padmask = self.use_padmask
+        hlens = kwargs.get("hlens", None)
+
+        ys_in_lens = kwargs.get("ys_in_lens", None)
+
+        # Positional offset for incremental decoding with a kv-cache.
+        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+        tgt, memory = x, xa
+        # Map padding id (-1) to token 0 before embedding.
+        # NOTE(review): this mutates the caller's tensor in place — confirm ok.
+        tgt[tgt == -1] = 0
+        tgt = (
+            self.token_embedding(tgt)
+            + self.positional_embedding[offset: offset + tgt.size(1)]
+        )
+        # tgt = self.dropout(tgt)
+
+        x = tgt.to(memory.dtype)
+
+        # Boolean non-padding mask over the encoder memory.
+        if use_padmask and hlens is not None:
+            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
+        else:
+            memory_mask = None
+
+        for layer, block in enumerate(self.blocks):
+            x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
+
+        x = self.ln(x)
+        # Tied-weight output projection: logits = x @ E^T, in float32.
+        x = (
+            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+        ).float()
+
+        return x
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index b5272a1..fa1c047 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -226,4 +226,213 @@
results.append(result_i)
return results, meta_data
-
\ No newline at end of file
+
+
+@tables.register("model_classes", "SenseVoiceRWKV")
+class SenseVoiceRWKV(nn.Module):
+    """Whisper-based ASR model whose decoder is replaced by an RWKV decoder
+    (registered under "SenseVoiceDecoder" by default).
+
+    kwargs of interest: dims (whisper ModelDimensions fields), decoder /
+    decoder_conf, downsample_rate, use_padmask, activation_checkpoint,
+    ignore_id, vocab_size, lsm_weight, length_normalized_loss, specaug /
+    specaug_conf.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+        dims = kwargs.get("dims", {})
+        dims = whisper.model.ModelDimensions(**dims)
+        model = whisper.model.Whisper(dims=dims)
+
+        # encoder
+        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
+        model.encoder.use_padmask = kwargs.get("use_padmask", True)
+        from .encoder import sense_voice_encode_forward
+        # Patch the stock whisper encoder's forward with the SenseVoice one.
+        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)
+
+        # decoder
+        del model.decoder
+        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
+        decoder_class = tables.decoder_classes.get(decoder)
+        decoder = decoder_class(n_vocab=dims.n_vocab,
+                                n_ctx=dims.n_text_ctx,
+                                n_state=dims.n_text_state,
+                                n_head=dims.n_text_head,
+                                n_layer=dims.n_text_layer,
+                                **kwargs.get("decoder_conf"))
+        model.decoder = decoder
+
+        self.model = model
+
+        self.encoder_output_size = self.model.dims.n_audio_state
+
+        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
+        self.ignore_id = kwargs.get("ignore_id", -1)
+        self.vocab_size = kwargs.get("vocab_size", -1)
+        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
+        self.criterion_att = LabelSmoothingLoss(
+            size=self.vocab_size,
+            padding_idx=self.ignore_id,
+            smoothing=kwargs.get("lsm_weight", 0.0),
+            normalize_length=self.length_normalized_loss,
+        )
+
+        specaug = kwargs.get("specaug", None)
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**kwargs.get("specaug_conf", {}))
+        self.specaug = specaug
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        """Training forward: encode speech, compute attention (CE) loss.
+
+        Returns (loss, stats, weight) suitable for force_gatherable."""
+        target_mask = kwargs.get("target_mask", None)
+
+        # import pdb;
+        # pdb.set_trace()
+        # Lengths may arrive as (batch, 1); squeeze to (batch,).
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size = speech.shape[0]
+
+        if self.activation_checkpoint:
+            from torch.utils.checkpoint import checkpoint
+            encoder_out, encoder_out_lens = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False)
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+            encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
+        )
+        loss = loss_att
+        stats = {}
+        stats["acc"] = acc_att
+        stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        # Weight by total token count (+1 per utt) when length-normalizing.
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ):
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+            # Forward encoder; whisper expects (batch, n_mels, frames).
+            encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)
+
+        return encoder_out, encoder_out_lens
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+        **kwargs,
+    ):
+        """Decoder forward + label-smoothed CE with teacher forcing.
+
+        target_mask selects which positions count as targets; everything else
+        (and token id 0) is mapped to ignore_id (-1). Returns
+        (loss, accuracy, None, None)."""
+        target_mask = kwargs.get("target_mask", None)
+        stats = {}
+
+        # 1. Forward decoder
+        decoder_out = self.model.decoder(
+            x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
+        )
+
+        # 2. Compute attention loss
+        mask = torch.ones_like(ys_pad) * (-1)
+        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
+        ys_pad_mask[ys_pad_mask == 0] = -1
+        # Shift by one: predict token t+1 from prefix up to t.
+        loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
+
+        with torch.no_grad():
+            preds = torch.argmax(decoder_out, -1)
+            acc_att = compute_accuracy(preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id)
+
+        return loss_att, acc_att, None, None
+
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  **kwargs,
+                  ):
+        """Single-utterance decoding via whisper.decode.
+
+        Returns ([{"key", "text"}], meta_data timing dict)."""
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+
+        # Lazily build and cache a WhisperFrontend if none was supplied.
+        if frontend is None and not hasattr(self, "frontend"):
+            frontend_class = tables.frontend_classes.get("WhisperFrontend")
+            frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
+            self.frontend = frontend
+        else:
+            frontend = frontend if frontend is not None else self.frontend
+
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                # NOTE(review): this assigns a plain int, but .to(device=...)
+                # is called on speech_lengths below — would fail on this path;
+                # confirm it is unreachable in practice.
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in,
+                                                            fs=frontend.fs if hasattr(frontend, "fs") else 16000,
+                                                            audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
+            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
+
+        # Drop the batch dim: whisper.decode takes a single mel tensor.
+        speech = speech.to(device=kwargs["device"])[0, :, :]
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        # Build the initial prompt from the requested task tag(s).
+        DecodingOptions = kwargs.get("DecodingOptions", {})
+        task = DecodingOptions.get("task", "ASR")
+        if isinstance(task, str):
+            task = [task]
+        task = "".join([f"<|{x}|>" for x in task])
+        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
+        DecodingOptions["initial_prompt"] = initial_prompt
+
+        # "auto" means let whisper detect the language.
+        language = DecodingOptions.get("language", None)
+        language = None if language == "auto" else language
+        DecodingOptions["language"] = language
+
+        DecodingOptions["vocab_path"] = kwargs.get("vocab_path", None)
+
+        if "without_timestamps" not in DecodingOptions:
+            DecodingOptions["without_timestamps"] = True
+
+        options = whisper.DecodingOptions(**DecodingOptions)
+
+        result = whisper.decode(self.model, speech, options)
+        text = f"{result.text}"
+        results = []
+        result_i = {"key": key[0], "text": text}
+
+        results.append(result_i)
+
+        return results, meta_data
diff --git a/funasr/models/sense_voice/rwkv_v6.py b/funasr/models/sense_voice/rwkv_v6.py
new file mode 100644
index 0000000..6eb53fc
--- /dev/null
+++ b/funasr/models/sense_voice/rwkv_v6.py
@@ -0,0 +1,425 @@
+########################################################################################################
+# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
+########################################################################################################
+
+import os, math, gc, importlib
+import torch
+# torch._C._jit_set_profiling_executor(True)
+# torch._C._jit_set_profiling_mode(True)
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+
+def __nop(ob):
+    # Identity passthrough: used as a no-op stand-in for torch.jit.script_method
+    # when TorchScript JIT compilation is disabled.
+    return ob
+
+
+# Default to plain nn.Module / no-op decorator; setting the environment variable
+# RWKV_JIT_ON=1 switches the model classes below to TorchScript compilation.
+MyModule = nn.Module
+MyFunction = __nop
+if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
+    MyModule = torch.jit.ScriptModule
+    MyFunction = torch.jit.script_method
+
+########################################################################################################
+# CUDA Kernel
+########################################################################################################
+
+# Lazily-built handle to the compiled WKV v6 CUDA extension (set once below).
+wkv6_cuda = None
+
+def load_rwkv_kernel(HEAD_SIZE: int=64, RWKV_CTXLEN: int=512,):
+    """JIT-compile and cache the fused WKV v6 CUDA kernel.
+
+    HEAD_SIZE and RWKV_CTXLEN are baked into the kernel as compile-time
+    constants (-D_N_ / -D_T_), so the first call fixes them for the process;
+    later calls with different values are silently ignored (early return).
+    """
+    from torch.utils.cpp_extension import load
+    global wkv6_cuda
+
+
+    # Compile only once per process.
+    if wkv6_cuda is not None:
+        return
+
+    # Kernel sources live in the cuda/ directory next to this file.
+    absolute_file_path = os.path.abspath(__file__)
+    cur_dir = os.path.dirname(absolute_file_path)
+    wkv6_cuda = load(name="wkv6", sources=[f"{cur_dir}/cuda/wkv6_op.cpp", f"{cur_dir}/cuda/wkv6_cuda.cu"],
+                     verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3",
+                                                      "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}",
+                                                      f"-D_T_={RWKV_CTXLEN}"])
+
+# Working dtype for kernel I/O buffers. The commented-out asserts in forward()
+# indicate the CUDA kernel was written for bfloat16 tensors.
+# dtype = torch.float
+dtype = torch.bfloat16
+class WKV_6(torch.autograd.Function):
+    """Autograd wrapper around the fused WKV v6 CUDA kernel.
+
+    Shapes (from the buffer allocations below): r/k/v/w are (B, T, C),
+    u is per-head (reshaped to (H, C//H) in backward), output y is (B, T, C).
+    """
+    @staticmethod
+    def forward(ctx, B, T, C, H, r, k, v, w, u):
+        # B/T/C/H: batch, sequence length, channels, head count (ints, no grad).
+        with torch.no_grad():
+            # assert r.dtype == torch.bfloat16
+            # assert k.dtype == torch.bfloat16
+            # assert v.dtype == torch.bfloat16
+            # assert w.dtype == torch.bfloat16
+            # assert u.dtype == torch.bfloat16
+            # assert HEAD_SIZE == C // H
+            ctx.B = B
+            ctx.T = T
+            ctx.C = C
+            ctx.H = H
+            # The kernel reads raw pointers, so contiguity is mandatory.
+            assert r.is_contiguous()
+            assert k.is_contiguous()
+            assert v.is_contiguous()
+            assert w.is_contiguous()
+            assert u.is_contiguous()
+            # Decay is passed to the kernel as -exp(w); the raw w tensor is NOT
+            # saved, so the kernel's backward presumably folds the exp() chain
+            # rule into gw itself — TODO confirm against wkv6_cuda.cu.
+            ew = (-torch.exp(w.float())).contiguous()
+            ctx.save_for_backward(r, k, v, ew, u)
+            y = torch.empty((B, T, C), device=r.device, dtype=dtype,
+                            memory_format=torch.contiguous_format) # .uniform_(-100, 100)
+            wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y)
+            return y
+
+    @staticmethod
+    def backward(ctx, gy):
+        with torch.no_grad():
+            # assert gy.dtype == torch.bfloat16
+            B = ctx.B
+            T = ctx.T
+            C = ctx.C
+            H = ctx.H
+            assert gy.is_contiguous()
+            r, k, v, ew, u = ctx.saved_tensors
+            # Uninitialized output buffers; the kernel writes every element.
+            gr = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
+                             memory_format=torch.contiguous_format) # .uniform_(-100, 100)
+            gk = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
+                             memory_format=torch.contiguous_format) # .uniform_(-100, 100)
+            gv = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
+                             memory_format=torch.contiguous_format) # .uniform_(-100, 100)
+            gw = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
+                             memory_format=torch.contiguous_format) # .uniform_(-100, 100)
+            gu = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=dtype,
+                             memory_format=torch.contiguous_format) # .uniform_(-100, 100)
+            wkv6_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu)
+            # u's gradient is summed over the batch and reshaped to match
+            # time_faaaa's (n_head, head_size) parameter layout.
+            gu = torch.sum(gu, 0).view(H, C // H)
+            # First four Nones correspond to the non-tensor B, T, C, H inputs.
+            return (None, None, None, None, gr, gk, gv, gw, gu)
+
+
+def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
+    # Thin convenience wrapper so callers don't invoke the Function directly.
+    return WKV_6.apply(B, T, C, H, r, k, v, w, u)
+
+
+class RWKV_Tmix_x060(MyModule):
+    """RWKV v6 time-mixing block (the attention-like component).
+
+    Combines data-dependent token-shift mixing (low-rank "maa" projections)
+    with the fused WKV CUDA kernel. `args` must provide head_size_a, ctx_len,
+    dim_att, n_embd, n_layer and head_size_divisor.
+    """
+    def __init__(self, args, layer_id):
+        super().__init__()
+        self.args = args
+
+        # Ensure the CUDA kernel is compiled with matching head size / ctx len.
+        load_rwkv_kernel(args.head_size_a, args.ctx_len)
+
+        self.layer_id = layer_id
+
+        self.head_size = args.head_size_a
+        self.n_head = args.dim_att // self.head_size
+        assert args.dim_att % self.n_head == 0
+
+        with torch.no_grad():
+            # Depth-dependent interpolation factors used throughout the init:
+            # deeper layers mix tokens differently from shallow ones.
+            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
+            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
+            # ddd ramps 0 -> ~1 across the embedding channels.
+            ddd = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                ddd[0, 0, i] = i / args.n_embd
+
+            # fancy time_mix
+            self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+            self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+            self.time_maa_v = nn.Parameter(1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1))
+            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+            self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
+
+            # Low-rank (LoRA-style) projections producing per-token deltas for
+            # the five mixing coefficients (w, k, v, r, g).
+            D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g
+            self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
+            self.time_maa_w2 = nn.Parameter(torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01))
+
+            # fancy time_decay
+            decay_speed = torch.ones(args.dim_att)
+            for n in range(args.dim_att):
+                decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+            self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
+
+            # Low-rank data-dependent correction added to the static decay.
+            D_DECAY_LORA = 64
+            self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
+            self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01))
+
+            # Per-head "bonus" (u) init with a small zigzag perturbation.
+            tmp = torch.zeros(args.dim_att)
+            for n in range(args.dim_att):
+                zigzag = ((n + 1) % 3 - 1) * 0.1
+                tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
+
+            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
+
+        # ZeroPad2d((0,0,1,-1)) shifts the sequence one step forward in time
+        # (prepends a zero row, drops the last), i.e. x[t] sees x[t-1].
+        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
+        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
+
+        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
+        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
+        self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
+        # Per-head GroupNorm; eps scaled by head_size_divisor^2.
+        self.ln_x = nn.GroupNorm(self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor ** 2))
+
+    @MyFunction
+    def jit_func(self, x):
+        # Compute r, k, v, gate and data-dependent decay from input (B, T, C).
+        B, T, C = x.size()
+
+        # Difference between previous-token and current-token representations.
+        xx = self.time_shift(x) - x
+
+        # One shared low-rank pass producing five per-token mixing deltas.
+        xxx = x + xx * self.time_maa_x
+        xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
+        xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
+        mw, mk, mv, mr, mg = xxx.unbind(dim=0)
+
+        # Blend current and shifted tokens per projection (static + dynamic mix).
+        xw = x + xx * (self.time_maa_w + mw)
+        xk = x + xx * (self.time_maa_k + mk)
+        xv = x + xx * (self.time_maa_v + mv)
+        xr = x + xx * (self.time_maa_r + mr)
+        xg = x + xx * (self.time_maa_g + mg)
+
+        r = self.receptance(xr)
+        k = self.key(xk)
+        v = self.value(xv)
+        g = F.silu(self.gate(xg))
+
+        # Data-dependent decay: static time_decay plus low-rank correction.
+        ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
+        w = self.time_decay + ww
+
+        return r, k, v, g, w
+
+    @MyFunction
+    def jit_func_2(self, x, g):
+        # Per-head normalization then gated output projection.
+        B, T, C = x.size()
+        x = x.view(B * T, C)
+
+        x = self.ln_x(x).view(B, T, C)
+        x = self.output(x * g)
+        return x
+
+    def forward(self, x):
+        # x: (B, T, C) -> (B, T, C) via the fused WKV v6 CUDA kernel.
+        B, T, C = x.size()
+        H = self.n_head
+
+        r, k, v, g, w = self.jit_func(x)
+        x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
+
+        return self.jit_func_2(x, g)
+
+class RWKV_CMix_x060(MyModule):
+    """RWKV v6 channel-mixing block (the FFN-like component).
+
+    Token-shift mixing followed by a squared-ReLU feed-forward, gated by a
+    sigmoid receptance. Maps (B, T, n_embd) -> (B, T, n_embd).
+    """
+    def __init__(self, args, layer_id):
+        super().__init__()
+        self.args = args
+        self.layer_id = layer_id
+        # Shift the sequence one step forward in time (zero-pad first row).
+        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+        with torch.no_grad():  # fancy init of time_mix
+            # Deeper layers mix less of the previous token (ratio -> ~0).
+            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
+            ddd = torch.ones(1, 1, args.n_embd)
+            for i in range(args.n_embd):
+                ddd[0, 0, i] = i / args.n_embd
+            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
+
+        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
+        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
+        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
+
+    @MyFunction
+    def forward(self, x):
+        # Blend each token with its predecessor using learned mixing weights.
+        xx = self.time_shift(x) - x
+        xk = x + xx * self.time_maa_k
+        xr = x + xx * self.time_maa_r
+
+        # Squared-ReLU activation, then sigmoid-gated value projection.
+        k = self.key(xk)
+        k = torch.relu(k) ** 2
+        kv = self.value(k)
+        return torch.sigmoid(self.receptance(xr)) * kv
+
+
+class Block(nn.Module):
+    """One RWKV transformer-style block: time-mix (att) + channel-mix (ffn),
+    each with a pre-LayerNorm residual connection, optional dropout.
+    """
+    def __init__(self, args, layer_id):
+        super().__init__()
+        self.args = args
+        self.layer_id = layer_id
+
+        self.ln1 = nn.LayerNorm(args.n_embd)
+        self.ln2 = nn.LayerNorm(args.n_embd)
+
+        # Extra input LayerNorm applied only by the very first block.
+        if self.layer_id == 0:
+            self.ln0 = nn.LayerNorm(args.n_embd)
+
+
+        self.att = RWKV_Tmix_x060(args, layer_id)
+
+        self.ffn = RWKV_CMix_x060(args, layer_id)
+
+
+        if args.dropout > 0:
+            self.drop0 = nn.Dropout(p=args.dropout)
+            self.drop1 = nn.Dropout(p=args.dropout)
+
+    def forward(self, x, x_emb=None):
+        # x: (B, T, n_embd); x_emb is accepted but unused in this block.
+        args = self.args
+        B, T, C = x.size()
+        if self.layer_id == 0:
+            x = self.ln0(x)
+
+
+        if self.args.dropout == 0:
+            # NOTE(review): self.ffnPre is never defined in this class, so the
+            # pre_ffn > 0 path would raise AttributeError — confirm pre_ffn is
+            # always 0/absent in the configs that instantiate Block.
+            if self.layer_id == 0 and args.pre_ffn > 0:
+                x = x + self.ffnPre(self.ln1(x))
+            else:
+                x = x + self.att(self.ln1(x))
+            x = x + self.ffn(self.ln2(x))
+        else:
+            if self.layer_id == 0 and args.pre_ffn > 0:
+                x = self.drop0(x + self.ffnPre(self.ln1(x)))
+            else:
+                x = self.drop0(x + self.att(self.ln1(x)))
+            x = self.drop1(x + self.ffn(self.ln2(x)))
+
+        return x
+
+
+class RWKVLayer(nn.Module):
+    """Standalone RWKV layer for embedding into an encoder/decoder stack.
+
+    Like Block, but with configurable LayerNorm placement (ln0/ln1 can be
+    disabled via args), orthogonal/zero init of the sub-module weights, and
+    optional bfloat16 casting around the computation.
+
+    NOTE(review): args is accessed via .get(...), so it must be a dict-like
+    config (e.g. OmegaConf DictConfig) — unlike RWKV below, which uses
+    hasattr(). Confirm the two call sites pass compatible config objects.
+    """
+    def __init__(self, args, layer_id):
+        super().__init__()
+        self.args = args
+        self.layer_id = layer_id
+        # Default FFN width: 3.5x embedding size, rounded down to /32.
+        if args.dim_ffn is None:
+            args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
+        self.ln0 = None
+        if self.layer_id == 0 and args.get("ln0", True):
+            self.ln0 = nn.LayerNorm(args.n_embd)
+
+        self.ln1 = None
+        if args.get("ln1", True):
+            self.ln1 = nn.LayerNorm(args.n_embd)
+        self.ln2 = nn.LayerNorm(args.n_embd)
+
+
+        self.att = RWKV_Tmix_x060(args, layer_id)
+
+        self.ffn = RWKV_CMix_x060(args, layer_id)
+
+        if args.dropout > 0:
+            self.drop0 = nn.Dropout(p=args.dropout)
+            self.drop1 = nn.Dropout(p=args.dropout)
+
+        # init
+        if args.get("init_rwkv", True):
+            # RWKV-style init: orthogonal projections, zeroed output/value/
+            # receptance so each layer starts close to identity.
+            print("init_rwkv")
+            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
+            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
+            nn.init.orthogonal_(self.att.value.weight, gain=1)
+            nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
+            nn.init.zeros_(self.att.output.weight)
+
+            nn.init.orthogonal_(self.ffn.key.weight, gain=1)
+            nn.init.zeros_(self.ffn.value.weight)
+            nn.init.zeros_(self.ffn.receptance.weight)
+            # Depth-scaled LayerNorm gain: deeper layers start closer to 1.
+            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+            nn.init.constant_(self.ln2.weight, scale)
+            if self.ln0 is not None:
+                nn.init.constant_(self.ln0.weight, scale)
+            if self.ln1 is not None:
+                nn.init.constant_(self.ln1.weight, scale)
+
+    def forward(self, x, x_emb=None, mask=None, **kwargs):
+        # x: (B, T, n_embd). x_emb, mask and extra kwargs are accepted for
+        # interface compatibility but unused here.
+        args = self.args
+        # Optionally run the block in bfloat16 and cast back afterwards.
+        if args.get("datatype", "bf16") == "bf16":
+            x = x.bfloat16()
+        B, T, C = x.size()
+        if self.layer_id == 0 and self.ln0 is not None:
+            x = self.ln0(x)
+
+        if self.args.dropout == 0:
+            # When ln1 is disabled the attention residual skips normalization.
+            if self.ln1 is None:
+                x = x + self.att(x)
+            else:
+                x = x + self.att(self.ln1(x))
+            x = x + self.ffn(self.ln2(x))
+        else:
+            if self.ln1 is None:
+                x = self.drop0(x + self.att(x))
+            else:
+                x = self.drop0(x + self.att(self.ln1(x)))
+            x = self.drop1(x + self.ffn(self.ln2(x)))
+
+        if args.get("datatype", "bf16") == "bf16":
+            x = x.to(torch.float32)
+        return x
+
+
+class RWKV(nn.Module):
+    """Full RWKV language model: embedding -> n_layer Blocks -> LN -> LM head.
+
+    NOTE(review): forward() references deepspeed (grad_cp path) and
+    head_q/head_k/copy_mask (head_qk path) that are never imported/defined in
+    this file — those paths would raise NameError/AttributeError. Confirm they
+    are dead code for the configs used here, or import/define them.
+    """
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+        # Fill in derived defaults when the config omits them.
+        if not hasattr(args, 'dim_att'):
+            args.dim_att = args.n_embd
+        if not hasattr(args, 'dim_ffn'):
+            # NOTE(review): raises KeyError when RWKV_MY_TESTING is unset;
+            # os.environ.get("RWKV_MY_TESTING", "") would be safer — confirm
+            # the env var is always exported by the training scripts.
+            if '-f4' in os.environ["RWKV_MY_TESTING"]:
+                args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
+            else:
+                args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size
+        if not hasattr(args, 'tiny_att_layer'):
+            args.tiny_att_layer = -1
+        if not hasattr(args, 'tiny_att_dim'):
+            args.tiny_att_dim = -1
+        assert args.n_embd % 32 == 0
+        assert args.dim_att % 32 == 0
+        assert args.dim_ffn % 32 == 0
+
+        self.emb = nn.Embedding(args.vocab_size, args.n_embd)
+
+        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
+
+        self.ln_out = nn.LayerNorm(args.n_embd)
+        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
+
+
+        if args.dropout > 0:
+            self.drop0 = nn.Dropout(p=args.dropout)
+
+
+    def forward(self, idx):
+        # idx: (B, T) token ids; returns (B, T, vocab_size) logits.
+        args = self.args
+        B, T = idx.size()
+        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
+
+        x = self.emb(idx)
+        x_emb = x
+
+        if args.dropout > 0:
+            x = self.drop0(x)
+        # tiny-attention variant passes the raw embedding to each block;
+        # grad_cp == 1 enables (deepspeed) activation checkpointing.
+        if args.tiny_att_dim > 0:
+            for block in self.blocks:
+                if args.grad_cp == 1:
+                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
+                else:
+                    x = block(x, x_emb)
+        else:
+            for block in self.blocks:
+                if args.grad_cp == 1:
+                    x = deepspeed.checkpointing.checkpoint(block, x)
+                else:
+                    x = block(x)
+
+        x = self.ln_out(x)
+
+        # Optional head-QK trick: add a masked token-copy term to the logits,
+        # cast to match the training float mode.
+        if args.head_qk > 0:
+            q = self.head_q(x)[:, :T, :]
+            k = self.head_k(x)[:, :T, :]
+            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
+            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
+
+            if "32" in os.environ["RWKV_FLOAT_MODE"]:
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
+            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
+            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
+                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
+
+            x = self.head(x) + c
+        else:
+            x = self.head(x)
+
+        return x
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index ca960f1..5f7caeb 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -261,7 +261,9 @@
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
- self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+ # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+ # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
+ # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(
--
Gitblit v1.9.1