Dev gzf exp (#1645)
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* sensevoice finetune
* bugfix
* update with main (#1631)
* update seaco finetune
* v1.0.24
---------
Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
* sensevoice
* sensevoice
* sensevoice
* update with main (#1638)
* update seaco finetune
* v1.0.24
* update rwkv template
---------
Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sensevoice
* sense voice
* sense voice
* sense voice
* sense voice
---------
Co-authored-by: 维石 <shixian.shi@alibaba-inc.com>
| New file |
| | |
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.

# to print the register_table:
# from funasr.register import tables
# tables.print()

# network architecture
model: Conformer
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false

# encoder
encoder: ConformerEncoder
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder architecture type
    normalize_before: true
    pos_enc_layer_type: rel_pos
    selfattention_layer_type: rel_selfattn
    activation_type: swish
    macaron_style: true
    use_cnn_module: true
    cnn_module_kernel: 15

# decoder
decoder: TransformerRWKVDecoder
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
    input_layer: embed
    # RWKV hyperparameters consumed by TransformerRWKVDecoder (via rwkv_cfg).
    rwkv_cfg:
        n_embd: 256     # should match encoder output_size
        dropout: 0
        head_size_a: 64
        ctx_len: 512
        dim_att: 256 #${model_conf.rwkv_cfg.n_embd}
        dim_ffn: null
        head_size_divisor: 4
        n_layer: 6      # should match decoder_conf.num_blocks
        pre_ffn: 0
        ln0: false
        ln1: false
        init_rwkv: true


# frontend related
frontend: WavFrontend
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 1
    lfr_n: 1

specaug: SpecAug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2

train_conf:
    accum_grad: 1
    grad_clip: 5
    max_epoch: 150
    keep_nbest_models: 10
    log_interval: 50

optim: adam
optim_conf:
    lr: 0.0005
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 30000

dataset: AudioDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: EspnetStyleBatchSampler
    batch_type: length # example or length
    batch_size: 25000  # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
    max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
    buffer_size: 1024
    shuffle: True
    num_workers: 4
    preprocessor_speech: SpeechPreprocessSpeedPerturb
    preprocessor_speech_conf:
        speed_perturb: [0.9, 1.0, 1.1]

tokenizer: CharTokenizer
tokenizer_conf:
    unk_symbol: <unk>

ctc_conf:
    dropout_rate: 0.0
    ctc_type: builtin
    reduce: true
    ignore_nan_grad: true
normalize: null
| | |
| | | # freeze_param |
| | | freeze_param = kwargs.get("freeze_param", None) |
| | | if freeze_param is not None: |
| | | freeze_param = eval(freeze_param) |
| | | if "," in freeze_param: |
| | | freeze_param = eval(freeze_param) |
| | | if isinstance(freeze_param, Sequence): |
| | | freeze_param = (freeze_param,) |
| | | logging.info("freeze_param is not None: %s", freeze_param) |
| | |
| | | if use_ddp: |
| | | model = model.cuda(local_rank) |
| | | model = DDP(model, device_ids=[local_rank], |
| | | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True)) |
| | | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False)) |
| | | elif use_fsdp: |
| | | # model = FSDP(model).cuda(local_rank) |
| | | |
| | |
| | | def __init__(self, dataset, |
| | | batch_size, |
| | | batch_type="token", |
| | | num_replicas=None, |
| | | rank=None, |
| | | num_replicas=None, |
| | | rank_split=False, |
| | | shuffle=True, |
| | | drop_last=False, |
| | | is_training: bool = True, |
| | |
| | | rank = dist.get_rank() |
| | | num_replicas = dist.get_world_size() |
| | | except: |
| | | rank = 0 |
| | | num_replicas = 1 |
| | | if rank_split: |
| | | logging.info(f"Warning, rank_split: {rank_split}, batch and shuffle data in local rank") |
| | | rank = 0 |
| | | num_replicas = 1 |
| | | self.rank = rank |
| | |
| | | self.length_scale_source = kwargs.get("length_scale_source", 1.0) |
| | | |
| | | |
| | | super().__init__(dataset, num_replicas=num_replicas, rank=rank, |
| | | shuffle=shuffle, drop_last=drop_last) |
| | | # super().__init__(dataset, num_replicas=num_replicas, rank=rank, |
| | | # shuffle=shuffle, drop_last=drop_last) |
| | | def __iter__(self): |
| | | if self.shuffle: |
| | | g = torch.Generator() |
| | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("index_ds_classes", "IndexDSJsonlRankSplit") |
class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
    """Jsonl-backed index dataset that shards its samples across DDP ranks.

    Every rank parses the full jsonl file, then keeps only its own
    contiguous ``total // world_size`` slice, so ranks see disjoint subsets.
    When torch.distributed is not initialized, a single shard (all samples)
    is kept.

    Args:
        path: Path to a jsonl file. Each line is a JSON object carrying
            either a ``text`` field (sft data) or ``source``/``target``
            fields plus lengths (speech lab pretrain data).
    """

    def __init__(self, path):
        super().__init__()

        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
                data = json.loads(line.strip())
                if "text" in data:  # for sft
                    # Append to the local list; the original appended to
                    # self.contents before that attribute existed, which
                    # raised AttributeError on the first sft sample.
                    contents.append(data['text'])
                if "source" in data:  # for speech lab pretrain
                    contents.append({"source": data["source"],
                                     "prompt": data["prompt"],
                                     "target": data["target"],
                                     "source_len": data["source_len"],
                                     "target_len": data["target_len"],
                                     }
                                    )

        total_num = len(contents)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            # Default process group not initialized: run as a single shard.
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        num_per_rank = total_num // world_size

        # Keep only this rank's slice; any remainder (total_num % world_size)
        # is dropped on every rank so all shards stay the same length.
        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]

        logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        # The original swallowed the lookup error, printed the index, and
        # then crashed on an undefined local; log and re-raise instead.
        try:
            return self.contents[index]
        except IndexError:
            logging.error("index out of range: %s", index)
            raise

    def get_source_len(self, data_dict):
        """Source length used by length-based batch samplers."""
        return data_dict["source_len"]

    def get_target_len(self, data_dict):
        """Target length; 0 when absent (e.g. sft text samples)."""
        return data_dict["target_len"] if "target_len" in data_dict else 0
| | | # @tables.register("index_ds_classes", "IndexDSJsonlRankSplit") |
| | | # class IndexDSJsonlRankSplit(torch.utils.data.Dataset): |
| | | # |
| | | # def __init__(self, path): |
| | | # super().__init__() |
| | | # |
| | | # contents = [] |
| | | # with open(path, encoding='utf-8') as fin: |
| | | # for line in fin: |
| | | # data = json.loads(line.strip()) |
| | | # if "text" in data: # for sft |
| | | # self.contents.append(data['text']) |
| | | # if "source" in data: # for speech lab pretrain |
| | | # prompt = data["prompt"] |
| | | # source = data["source"] |
| | | # target = data["target"] |
| | | # source_len = data["source_len"] |
| | | # target_len = data["target_len"] |
| | | # |
| | | # contents.append({"source": source, |
| | | # "prompt": prompt, |
| | | # "target": target, |
| | | # "source_len": source_len, |
| | | # "target_len": target_len, |
| | | # } |
| | | # ) |
| | | # |
| | | # self.contents = [] |
| | | # total_num = len(contents) |
| | | # try: |
| | | # rank = dist.get_rank() |
| | | # world_size = dist.get_world_size() |
| | | # except: |
| | | # rank = 0 |
| | | # world_size = 1 |
| | | # logging.warning("distributed is not initialized, only single shard") |
| | | # num_per_rank = total_num // world_size |
| | | # |
| | | # # rank = 0 |
| | | # # import ipdb; ipdb.set_trace() |
| | | # self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank] |
| | | # |
| | | # logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents))) |
| | | # |
| | | # def __len__(self): |
| | | # return len(self.contents) |
| | | # |
| | | # def __getitem__(self, index): |
| | | # try: |
| | | # data = self.contents[index] |
| | | # except: |
| | | # print(index) |
| | | # return data |
| | | # |
| | | # def get_source_len(self, data_dict): |
| | | # return data_dict["source_len"] |
| | | # |
| | | # def get_target_len(self, data_dict): |
| | | # |
| | | # return data_dict["target_len"] if "target_len" in data_dict else 0 |
| | | |
| | | @tables.register("index_ds_classes", "IndexDSJsonl") |
| | | @tables.register("index_ds_classes", "IndexDSJsonlRankFull") |
| | |
| | | def get_target_len(self, data_dict): |
| | | |
| | | return data_dict.get("target_len", 0) |
| | | |
| | | |
| | | @tables.register("index_ds_classes", "IndexDSJsonlRankSplit") |
class IndexDSJsonlRankSplit(torch.utils.data.Dataset):
    """Rank-split index dataset over a *list* of jsonl files.

    ``path`` points at a text file whose lines are paths to jsonl files.
    The file list (not the individual samples) is sharded across DDP ranks;
    each rank then loads and length-filters only its own shard. Without an
    initialized process group a single shard (all files) is used.

    Keyword Args:
        max_source_length / min_source_length: keep samples whose
            ``source_len`` lies inside [min, max] (defaults 0 / 2048).
        max_target_length / min_target_length: same window for ``target_len``.
    """

    def __init__(self, path: str, **kwargs):
        super().__init__()
        # Per-sample length filters (in tokens/frames).
        self.max_source_length = kwargs.get("max_source_length", 2048)
        self.min_source_length = kwargs.get("min_source_length", 0)
        self.max_target_length = kwargs.get("max_target_length", 2048)
        self.min_target_length = kwargs.get("min_target_length", 0)

        with open(path, encoding='utf-8') as fin:
            file_list = fin.readlines()

        total_num = len(file_list)
        try:
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        except Exception:
            # Default process group not initialized: run as a single shard.
            rank = 0
            world_size = 1
            logging.warning("distributed is not initialized, only single shard")
        num_per_rank = total_num // world_size
        if num_per_rank * world_size < total_num:
            # Trailing files that don't divide evenly are dropped on all ranks.
            logging.warning(f"Warning, jsonl file:{total_num} could not be divided by world_size: {world_size}, {path}")

        file_list_rank = file_list[rank * num_per_rank:(rank + 1) * num_per_rank]

        contents = []
        for file_json in file_list_rank:

            with open(file_json.strip(), encoding='utf-8') as fin:
                for line in fin:
                    data = json.loads(line.strip())
                    if "text" in data:  # for sft
                        contents.append(data['text'])
                    if "source" in data:  # for speech lab pretrain
                        prompt = data.get("prompt", "<ASR>")
                        # NOTE(review): hard-coded cluster path rewrite kept
                        # for behavior compatibility; consider making it
                        # configurable.
                        source = data["source"].replace("/cpfs01", "/cpfs_speech/data")
                        target = data["target"]
                        source_len = data.get("source_len", 1)
                        target_len = data.get("target_len", 0)

                        # Skip samples outside the configured length windows.
                        if source_len < self.min_source_length or source_len > self.max_source_length:
                            continue
                        if target_len < self.min_target_length or target_len > self.max_target_length:
                            continue
                        contents_i = {"source": source,
                                      "prompt": prompt,
                                      "target": target,
                                      "source_len": source_len,
                                      "target_len": target_len,
                                      }
                        # Optional language tags are carried through when present.
                        text_language = data.get("text_language", None)
                        if text_language is not None:
                            contents_i["text_language"] = text_language
                        audio_language = data.get("audio_language", None)
                        if audio_language is not None:
                            contents_i["audio_language"] = audio_language
                        contents.append(contents_i)

        self.contents = contents

        logging.info(f"total_num: {len(self.contents)} of samplers in ranks: {rank}")

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        # The original swallowed the lookup error, printed the index, and
        # then crashed on an undefined local; log and re-raise instead.
        try:
            return self.contents[index]
        except IndexError:
            logging.error("index out of range: %s", index)
            raise

    def get_source_len(self, data_dict):
        """Source length used by length-based batch samplers (default 1)."""
        return data_dict.get("source_len", 1)

    def get_target_len(self, data_dict):
        """Target length used by length-based batch samplers (default 0)."""
        return data_dict.get("target_len", 0)
| New file |
| | |
| | | # Copyright 2019 Shigeki Karita |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | """Decoder definition.""" |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | from torch import nn |
| | | |
| | | |
| | | from funasr.models.transformer.attention import MultiHeadedAttention |
| | | from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution |
| | | from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D |
| | | from funasr.models.transformer.embedding import PositionalEncoding |
| | | from funasr.models.transformer.layer_norm import LayerNorm |
| | | from funasr.models.transformer.utils.lightconv import LightweightConvolution |
| | | from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D |
| | | from funasr.models.transformer.utils.mask import subsequent_mask |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.transformer.positionwise_feed_forward import ( |
| | | PositionwiseFeedForward, # noqa: H301 |
| | | ) |
| | | from funasr.models.transformer.utils.repeat import repeat |
| | | from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface |
| | | from omegaconf import OmegaConf |
| | | from funasr.register import tables |
| | | |
class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Composes three residual sub-blocks with pre-LayerNorm: self-attention
    (here typically an RWKV layer), source (cross) attention over the
    encoder memory, and a position-wise feed forward.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or an RWKV layer can be used as the argument.
        src_attn (torch.nn.Module): Source (cross) attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        layer_id (int): Index of this layer in the stack; layer 0 may carry
            an extra input LayerNorm (``ln0``) in RWKV style.
        args (dict-like): RWKV config, read via ``.get`` for ``ln0``,
            ``init_rwkv``, ``n_embd`` and ``n_layer``.
            NOTE(review): mutable default ``{}`` is shared across calls;
            harmless while callers always pass ``args``, but worth fixing.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        layer_id=None,
        args={},
    ):
        """Construct an DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        self.size = size
        # NOTE(review): the self-attention module is cast to bfloat16 here --
        # presumably required by the RWKV kernels; confirm the surrounding
        # model runs under bf16/autocast as well.
        self.self_attn = self_attn.to(torch.bfloat16)
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)
        self.layer_id = layer_id
        self.ln0 = None
        # RWKV-style extra LayerNorm on the very first layer only, enabled
        # when the config sets ln0: false (note the inverted flag).
        if self.layer_id == 0 and not args.get("ln0", True):
            self.ln0 = LayerNorm(args.n_embd)
            if args.get("init_rwkv", True):
                print("init_rwkv")
                layer_id = 0  # redundant: guarded by self.layer_id == 0 above
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln0.weight, scale)

        # Depth-dependent LayerNorm gain, RWKV-style initialization.
        if args.get("init_rwkv", True):
            print("init_rwkv")
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            nn.init.constant_(self.norm1.weight, scale)
            nn.init.constant_(self.self_attn.ln2.weight, scale)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor(#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).

        """

        # RWKV-style input LayerNorm, built only on the first layer.
        if self.layer_id == 0 and self.ln0 is not None:
            tgt = self.ln0(tgt)

        residual = tgt


        tgt = self.norm1(tgt)

        if cache is None:

            x = residual + self.dropout(self.self_attn(tgt, mask=tgt_mask))
        else:

            # Incremental decoding: run self-attention over the whole prefix
            # (mask=None) and keep only the newest position.
            # tgt_q = tgt[:, -1:, :]
            # residual_q = residual[:, -1:, :]
            tgt_q_mask = None

            x = residual + self.dropout(self.self_attn(tgt, mask=tgt_q_mask))
            # NOTE(review): x[:, -1, :] drops the time dimension, while the
            # torch.cat with the 3-D cache below uses dim=1 -- verify the
            # cached-decoding path against the RWKV layer's output shape.
            x = x[:, -1, :]



        # x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))


        residual = x
        x = self.norm2(x)
        x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        residual = x
        x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))


        # Extend the running cache with the newly decoded step.
        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
| | | |
| | | |
| | | |
class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
    """Base class of Transformer decoder module.

    Subclasses must build ``self.decoders`` (the stack of decoder layers);
    this base provides the input embedding, the final norm / output
    projection, the full-sequence ``forward``, the incremental
    ``forward_one_step``, and the beam-search scorer hooks.

    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            # Token ids -> embeddings (+ positional encoding).
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            # Continuous features -> projected, normalized embeddings.
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        # Must set by the inheritance (stack of decoder layers).
        self.decoders = None

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L) causal mask
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L) -- padding AND causality combined
        tgt_mask = tgt_mask & m

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        # Padding for Longformer: pad the mask out to the memory length.
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )

        x = self.embed(tgt)
        x, tgt_mask, memory, memory_mask = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        # Output lengths derived from the (B, L, L) mask.
        olens = tgt_mask.sum(1)
        return x, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
            y.shape` is (batch, maxlen_out, token)
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * len(self.decoders)
        new_cache = []
        # One cache entry per decoder layer; memory_mask is None during
        # incremental decoding.
        for c, decoder in zip(cache, self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            # Log-probabilities for the scorer interface.
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache

    def score(self, ys, state, x):
        """Score a single hypothesis (BatchScorerInterface hook)."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.decoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        # batch decoding
        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
| | | |
| | | @tables.register("decoder_classes", "TransformerRWKVDecoder") |
class TransformerRWKVDecoder(BaseTransformerDecoder):
    """Transformer decoder whose causal self-attention is an RWKV (v6) layer.

    Cross-attention over the encoder memory and the position-wise feed
    forward remain standard Transformer components; only the self-attention
    inside each :class:`DecoderLayer` is replaced by ``RWKVLayer``.

    Args mirror :class:`BaseTransformerDecoder`, plus:
        attention_heads: heads of the source (cross) attention.
        linear_units: hidden units of the position-wise feed forward.
        num_blocks: number of decoder layers.
        self_attention_dropout_rate: kept for signature compatibility
            (the RWKV layer carries its own dropout config).
        src_attention_dropout_rate: dropout of the cross attention.
        concat_after: see base class.
        **kwargs: may carry ``rwkv_cfg`` (dict) with the RWKV
            hyperparameters (``n_embd``, ``n_layer``, ``init_rwkv``, ...).
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        # Imported lazily so importing this module does not pull in the RWKV
        # kernels unless this decoder is actually constructed.
        from funasr.models.sense_voice.rwkv_v6 import RWKVLayer

        rwkv_cfg = kwargs.get("rwkv_cfg", {})
        args = OmegaConf.create(rwkv_cfg)
        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                RWKVLayer(args=args, layer_id=lnum),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
                lnum,
                args=args,
            ),
        )

        # RWKV-style init: tiny uniform range for the token embedding.
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.uniform_(self.embed[0].weight, a=-1e-4, b=1e-4)

    # forward() and forward_one_step() are intentionally NOT overridden:
    # the overrides previously defined here were line-for-line identical to
    # BaseTransformerDecoder's implementations, so they were removed to keep
    # a single copy and avoid the two drifting apart.
| New file |
| | |
| | | import logging |
| | | |
| | | import torch |
| | | |
| | | from funasr.models.transformer.model import Transformer |
| | | from funasr.register import tables |
| | | |
@tables.register("model_classes", "Conformer")
class Conformer(Transformer):
    """CTC-attention hybrid Encoder-Decoder model.

    Behaviourally identical to :class:`Transformer`; only the registered
    model name differs, so a yaml config selecting ``model: Conformer``
    (typically paired with a ConformerEncoder) resolves to this class.
    """

    def __init__(self, *args, **kwargs):
        # No Conformer-specific state: delegate everything to Transformer.
        super().__init__(*args, **kwargs)
| New file |
| | |
| | | # This is an example that demonstrates how to configure a model file. |
| | | # You can modify the configuration according to your own requirements. |
| | | |
| | | # to print the register_table: |
| | | # from funasr.register import tables |
| | | # tables.print() |
| | | |
| | | # network architecture |
| | | model: Conformer |
| | | model_conf: |
| | | ctc_weight: 0.3 |
| | | lsm_weight: 0.1 # label smoothing option |
| | | length_normalized_loss: false |
| | | |
| | | # encoder |
| | | encoder: ConformerEncoder |
| | | encoder_conf: |
| | | output_size: 256 # dimension of attention |
| | | attention_heads: 4 |
| | | linear_units: 2048 # the number of units of position-wise feed forward |
| | | num_blocks: 12 # the number of encoder blocks |
| | | dropout_rate: 0.1 |
| | | positional_dropout_rate: 0.1 |
| | | attention_dropout_rate: 0.0 |
| | | input_layer: conv2d # encoder architecture type |
| | | normalize_before: true |
| | | pos_enc_layer_type: rel_pos |
| | | selfattention_layer_type: rel_selfattn |
| | | activation_type: swish |
| | | macaron_style: true |
| | | use_cnn_module: true |
| | | cnn_module_kernel: 15 |
| | | |
| | | # decoder |
| | | decoder: TransformerRWKVDecoder |
| | | decoder_conf: |
| | | attention_heads: 4 |
| | | linear_units: 2048 |
| | | num_blocks: 6 |
| | | dropout_rate: 0.1 |
| | | positional_dropout_rate: 0.1 |
| | | self_attention_dropout_rate: 0.0 |
| | | src_attention_dropout_rate: 0.0 |
| | | input_layer: embed |
| | | rwkv_cfg: |
| | | n_embd: 256 |
| | | dropout: 0 |
| | | head_size_a: 64 |
| | | ctx_len: 512 |
| | | dim_att: 256 #${model_conf.rwkv_cfg.n_embd} |
| | | dim_ffn: null |
| | | head_size_divisor: 4 |
| | | n_layer: 6 |
| | | pre_ffn: 0 |
| | | ln0: false |
| | | ln1: false |
| | | |
| | | |
| | | # frontend related |
| | | frontend: WavFrontend |
| | | frontend_conf: |
| | | fs: 16000 |
| | | window: hamming |
| | | n_mels: 80 |
| | | frame_length: 25 |
| | | frame_shift: 10 |
| | | lfr_m: 1 |
| | | lfr_n: 1 |
| | | |
| | | specaug: SpecAug |
| | | specaug_conf: |
| | | apply_time_warp: true |
| | | time_warp_window: 5 |
| | | time_warp_mode: bicubic |
| | | apply_freq_mask: true |
| | | freq_mask_width_range: |
| | | - 0 |
| | | - 30 |
| | | num_freq_mask: 2 |
| | | apply_time_mask: true |
| | | time_mask_width_range: |
| | | - 0 |
| | | - 40 |
| | | num_time_mask: 2 |
| | | |
| | | train_conf: |
| | | accum_grad: 1 |
| | | grad_clip: 5 |
| | | max_epoch: 150 |
| | | keep_nbest_models: 10 |
| | | log_interval: 50 |
| | | |
| | | optim: adam |
| | | optim_conf: |
| | | lr: 0.0005 |
| | | scheduler: warmuplr |
| | | scheduler_conf: |
| | | warmup_steps: 30000 |
| | | |
| | | dataset: AudioDataset |
| | | dataset_conf: |
| | | index_ds: IndexDSJsonl |
| | | batch_sampler: EspnetStyleBatchSampler |
| | | batch_type: length # example or length |
    batch_size: 25000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
    max_token_length: 2048 # filter out samples if source_token_len+target_token_len > max_token_length
| | | buffer_size: 1024 |
| | | shuffle: True |
| | | num_workers: 4 |
| | | preprocessor_speech: SpeechPreprocessSpeedPerturb |
| | | preprocessor_speech_conf: |
| | | speed_perturb: [0.9, 1.0, 1.1] |
| | | |
| | | tokenizer: CharTokenizer |
| | | tokenizer_conf: |
| | | unk_symbol: <unk> |
| | | |
| | | ctc_conf: |
| | | dropout_rate: 0.0 |
| | | ctc_type: builtin |
| | | reduce: true |
| | | ignore_nan_grad: true |
| | | normalize: null |
| New file |
| | |
| | | #include <stdio.h> |
| | | #include <assert.h> |
| | | #include "ATen/ATen.h" |
| | | typedef at::BFloat16 bf16; |
| | | |
// wkv5 forward kernel.
//
// Grid layout: one block per (batch, head) pair (blockIdx.x = b*H + h) and
// one thread per channel within the head (threadIdx.x = i, 0 <= i < _N_,
// where _N_ is the compile-time head size and H*_N_ == C).
//
// Per time step t the recurrence is
//     y_t  = r_t . (u * k_t v_t + S)
//     S   <- w * S + k_t v_t
// with per-head, per-channel parameters w (decay) and u (current-token bonus).
// NOTE(review): w is consumed as-is here (no exp()); presumably the host has
// already applied the decay transform -- confirm against the Python wrapper.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    // Advance the parameter pointers to this head's slice.
    _w += h*_N_;
    _u += h*_N_;

    // Shared copies so every thread can read all _N_ channels of the step.
    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
    // Per-thread running state: one column of the head's _N_ x _N_ state.
    float state[_N_] = {0};

    __syncthreads();
    w[i] = _w[i];
    u[i] = float(_u[i]);
    __syncthreads();

    // Walk this (batch, head, channel) slot through all T time steps.
    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        // float4 vectorization over the head dimension (requires _N_ % 4 == 0).
        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            // x = k_t * v_t (outer-product column for this thread's channel).
            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            // y += r . (u * x + S)
            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            // S <- w * S + x
            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
}
| | | |
// wkv5 backward kernel: given upstream gradient gy, produces per-position
// gradients gr, gk, gv and per-(batch, head, channel) reductions gu, gw.
// Same grid layout as kernel_forward.
// __w is an auxiliary per-channel factor used to rescale the w-gradient
// (see `ww` below) -- NOTE(review): presumably d(decay)/d(raw parameter)
// supplied by the host; confirm against the Python autograd function.
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
    F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    // Per-head parameter slices.
    _w += h*_N_;
    _u += h*_N_;
    __w += h*_N_;

    __shared__ float w_[_N_], u_[_N_];
    __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
    __syncthreads();
    w_[i] = _w[i];
    u_[i] = float(_u[i]);
    __syncthreads();

    const float w = w_[i];
    const float ww = __w[i];
    const float u = u_[i];

    // Independent running states for each of the gradient passes below.
    float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};

    float gw = 0, gu = 0;
    const int t000 = b*T*C + h*_N_ + i;     // first time step for this slot
    const int t111 = (b+1)*T*C + h*_N_ + i; // one past the last time step
    const int t222 = t111 - 2*C;            // w-gradient stops two steps early

    // Pass 1 (forward in time): replay the forward state to accumulate
    // gr per position and gu reduced over time.
    for (int t = t000; t < t111; t += C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();

        const float k = float(_k[t]);
        float gr = 0, gu_ = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            float x = k * v[j];

            gr += (u * x + s) * gy[j];
            gu_ += x * gy[j];
            s = s * w + x;
        }
        _gr[t] = F(gr);
        gu += float(_r[t]) * gu_;
    }
    _gu[b*C + h*_N_ + i] = F(gu);

    // Pass 2 (forward in time, offset by two steps): accumulate the decay
    // gradient using a doubly-integrated state (saaaa feeds sbbbb).
    for (int t = t000; t < t222; t += C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t + 2*C]);
        __syncthreads();

        const float k = float(_k[t]);
        float gw_ = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = saaaa[j];
            float& s2 = sbbbb[j];
            float x = k * v[j];

            float tmp = w * (x + s);
            s = tmp;
            s2 = tmp + w * s2;
            gw_ += s2 * gy[j];
        }
        gw += float(_r[t + 2*C]) * gw_;
    }
    // Rescale by the host-provided factor ww before writing out.
    _gw[b*C + h*_N_ + i] = F(ww * gw);

    // Pass 3 (backward in time): gradient w.r.t. k via the reversed recurrence.
    for (int t = t111 - C; t >= t000; t -= C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();

        const float rr = float(_r[t]);
        float gk = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            float x = rr * gy[j];

            gk += (u * x + s) * v[j];
            s = x + s * w;
        }
        _gk[t] = F(gk);
    }

    // Pass 4 (backward in time): gradient w.r.t. v; here the per-channel
    // parameters are indexed by j (the other side of the outer product).
    for (int t = t111 - C; t >= t000; t -= C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float gyy = float(_gy[t]);
        float gv = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = sdddd[j];
            float x = gyy * r[j];

            gv += (u_[j] * x + s) * k[j];
            s = x + s * w_[j];
        }
        _gv[t] = F(gv);
    }
}
| | | |
// Host-side launcher for the wkv5 forward kernel:
// one block per (batch, head), _N_ threads (one per head channel).
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);   // channels must split exactly into H heads of size _N_
    assert(_N_%4 == 0);   // kernel reads head channels as float4
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
}
| | | |
// Host-side launcher for the wkv5 backward kernel; ww is the auxiliary
// per-channel rescale factor applied to the w-gradient inside the kernel.
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
{
    assert(H*_N_ == C);   // channels must split exactly into H heads of size _N_
    assert(_N_%4 == 0);   // forward kernel's float4 layout is assumed throughout
    kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
}
| New file |
| | |
| | | #include <torch/extension.h> |
| | | #include "ATen/ATen.h" |
| | | typedef at::BFloat16 bf16; |
| | | |
| | | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); |
| | | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); |
| | | |
| | | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { |
| | | cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>()); |
| | | } |
| | | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { |
| | | cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>()); |
| | | } |
// Python bindings: expose the launchers to the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv5 forward");
    m.def("backward", &backward, "wkv5 backward");
}

// Also register the same ops with the TorchScript operator registry.
TORCH_LIBRARY(wkv5, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
| New file |
| | |
| | | #include <stdio.h> |
| | | #include <assert.h> |
| | | #include "ATen/ATen.h" |
| | | typedef at::BFloat16 bf16; |
| | | // typedef float bf16; |
| | | |
// wkv6 forward kernel. Same grid layout as wkv5 (one block per
// (batch, head), one thread per head channel, _N_ = head size, H*_N_ == C).
//
// The wkv6 difference: the decay w is *time-varying* -- it is indexed per
// position t and exponentiated here (`exp(_w[t])`), instead of being a fixed
// per-channel parameter as in wkv5.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
                               const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
                               F *__restrict__ const _y)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    // Only u is per-head static; w is read per time step inside the loop.
    _u += h*_N_;

    __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
    // Per-thread running state: one column of the head's _N_ x _N_ state.
    float state[_N_] = {0};

    __syncthreads();
    u[i] = float(_u[i]);
    __syncthreads();

    for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
    {
        __syncthreads();
        w[i] = exp(_w[t]);   // time-varying decay for this step
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        __syncthreads();

        const float v = float(_v[t]);
        float y = 0;

        // float4 vectorization over the head dimension (requires _N_ % 4 == 0).
        #pragma unroll
        for (int j = 0; j < _N_; j+=4)
        {
            const float4& r_ = (float4&)(r[j]);
            const float4& k_ = (float4&)(k[j]);
            const float4& w_ = (float4&)(w[j]);
            const float4& u_ = (float4&)(u[j]);
            float4& s = (float4&)(state[j]);
            float4 x;

            // x = k_t * v_t
            x.x = k_.x * v;
            x.y = k_.y * v;
            x.z = k_.z * v;
            x.w = k_.w * v;

            // y += r . (u * x + S)
            y += r_.x * (u_.x * x.x + s.x);
            y += r_.y * (u_.y * x.y + s.y);
            y += r_.z * (u_.z * x.z + s.z);
            y += r_.w * (u_.w * x.w + s.w);

            // S <- w_t * S + x
            s.x = s.x * w_.x + x.x;
            s.y = s.y * w_.y + x.y;
            s.z = s.z * w_.z + x.z;
            s.w = s.w * w_.w + x.w;
        }
        _y[t] = F(y);
    }
}
| | | |
// wkv6 backward, part 1: per-position gradients gr, gk, gv and the
// time-reduced gu. The decay w is re-read and re-exponentiated per time
// step, matching the forward kernel. gw is handled by kernel_backward_222.
template <typename F>
__global__ void kernel_backward_111(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
    F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;
    _u += h*_N_;

    __shared__ float u_[_N_];
    __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
    __syncthreads();
    u_[i] = float(_u[i]);
    __syncthreads();

    const float u = u_[i];

    // Independent running states for the three passes below.
    float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};

    const int t_0 = b*T*C + h*_N_ + i;    // first time step for this slot
    const int t_T_1 = t_0 + (T-1)*C;      // last time step
    const int t_T = t_0 + T*C;            // one past the last

    // Pass 1 (forward in time): replay the state to get gr and accumulate gu.
    float gu = 0;
    for (int t = t_0; t < t_T; t += C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();

        const float k = float(_k[t]);
        const float w = exp(_w[t]);
        float gr = 0, gu_ = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = state[j];
            float x = k * v[j];

            gr += (u * x + s) * gy[j];
            gu_ += x * gy[j];
            s = s * w + x;
        }
        _gr[t] = F(gr);
        gu += float(_r[t]) * gu_;
    }
    _gu[b*C + h*_N_ + i] = F(gu);

    // Pass 2 (backward in time): gradient w.r.t. k.
    for (int t = t_T_1; t >= t_0; t -= C)
    {
        __syncthreads();
        v[i] = float(_v[t]);
        gy[i] = float(_gy[t]);
        __syncthreads();

        const float rr = float(_r[t]);
        const float w = exp(_w[t]);
        float gk = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            float x = rr * gy[j];

            gk += (u * x + s) * v[j];
            s = x + s * w;
        }
        _gk[t] = F(gk);
    }

    // Pass 3 (backward in time): gradient w.r.t. v; parameters are indexed
    // by j here (the other side of the outer product).
    for (int t = t_T_1; t >= t_0; t -= C)
    {
        __syncthreads();
        r[i] = float(_r[t]);
        k[i] = float(_k[t]);
        w_[i] = exp(_w[t]);
        __syncthreads();

        const float gyy = float(_gy[t]);
        float gv = 0;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = sdddd[j];
            float x = gyy * r[j];

            gv += (u_[j] * x + s) * k[j];
            s = x + s * w_[j];
        }
        _gv[t] = F(gv);
    }
}
| | | |
// wkv6 backward, part 2: per-position gradient of the (time-varying) decay w.
// Uses a per-thread local buffer sbbbb of compile-time size _T_-2 to stage
// partial sums from a reverse sweep before combining them in a forward sweep.
// The first and last positions get zero gradient by construction.
template <typename F>
__global__ void kernel_backward_222(const int B, const int T, const int C, const int H,
    const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
    F *__restrict__ const _gw)
{
    const int b = blockIdx.x / H;
    const int h = blockIdx.x % H;
    const int i = threadIdx.x;

    __shared__ float v[_N_], gy[_N_];
    float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};

    const int t_0 = b*T*C + h*_N_ + i;
    const int t_1 = t_0 + C;
    const int t_2 = t_0 + 2*C;
    const int t_T_1 = t_0 + (T-1)*C;

    // Reverse sweep: stage per-step partial sums (r/gy ahead by two steps
    // relative to k/v) into sbbbb.
    for (int t = t_T_1; t > t_1; t -= C)
    {
        __syncthreads();
        gy[i] = float(_gy[t]);
        v[i] = float(_v[t-2*C]);
        __syncthreads();

        const float r = float(_r[t]);
        const float w = exp(_w[t-C]);
        float sum = 0.0f;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = saaaa[j];
            float x = r * gy[j];
            s = (s + x) * w;
            sum += s * v[j];
        }
        sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
    }

    // Forward sweep: fold the staged sums into a running total; each output
    // is scaled by the raw (pre-exp) decay value _w[t].
    float sss = sbbbb[0];
    _gw[t_0] = 0;
    _gw[t_1] = F(sss * _w[t_1]);

    for (int t = t_2; t < t_T_1; t += C)
    {
        __syncthreads();
        gy[i] = float(_gy[t]);
        v[i] = float(_v[t-2*C]);
        __syncthreads();

        const float w = exp(_w[t-C]);
        const float k = float(_k[t-2*C]);
        float sum = 0.0f;

        #pragma unroll
        for (int j = 0; j < _N_; j++)
        {
            float& s = scccc[j];
            float x = k * v[j];
            s = (s + x) * w;
            sum += s * gy[j];
        }
        sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
        _gw[t] = F(sss * _w[t]);
    }
    _gw[t_T_1] = 0;
}
| | | |
// Host-side launcher for the wkv6 forward kernel:
// one block per (batch, head), _N_ threads (one per head channel).
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
{
    assert(H*_N_ == C);   // channels must split exactly into H heads of size _N_
    assert(_N_%4 == 0);   // kernel reads head channels as float4
    kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
}
| | | |
// Host-side launcher for the wkv6 backward pass: part 1 produces
// gr/gk/gv/gu, part 2 produces the per-position decay gradient gw.
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
{
    assert(H*_N_ == C);   // channels must split exactly into H heads of size _N_
    assert(_N_%4 == 0);   // forward kernel's float4 layout is assumed throughout
    kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu);
    kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gw);
}
| New file |
| | |
| | | #include <torch/extension.h> |
| | | #include "ATen/ATen.h" |
| | | typedef at::BFloat16 bf16; |
| | | //typedef float bf16; |
| | | |
| | | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); |
| | | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); |
| | | |
| | | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { |
| | | cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>()); |
| | | } |
| | | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { |
| | | cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>()); |
| | | } |
// Python bindings: expose the launchers to the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv6 forward");
    m.def("backward", &backward, "wkv6 backward");
}

// Also register the same ops with the TorchScript operator registry.
TORCH_LIBRARY(wkv6, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
| | |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.register import tables |
| | | import base64 |
| | | import gzip |
| | | from dataclasses import dataclass |
| | | from typing import Dict, Iterable, Optional |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from torch import Tensor, nn |
| | | |
| | | |
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes in float32, then casts back to the input dtype."""

    def forward(self, x: Tensor) -> Tensor:
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
| | | |
| | | |
class Linear(nn.Linear):
    """Linear layer that casts its parameters to the input's dtype at call time."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return F.linear(x, weight, bias)
| | | |
| | | |
| | | |
def sense_voice_decode_forward(
    self,
    x: Tensor,
    xa: Tensor,
    kv_cache: Optional[dict] = None,
    **kwargs,
):
    """Decode text tokens against encoded audio features.

    Intended to be bound onto a whisper-style decoder via types.MethodType;
    mirrors SenseVoiceDecoder.forward. (The previous version of this block
    was garbled by a bad merge: truncated signature, duplicated statements,
    and an undefined ``memory_mask`` — reconstructed here.)

    Args:
        x: text token ids, int64 (batch, <= n_ctx); -1 marks padding and is
            remapped to token 0 before embedding.
        xa: encoded audio features (batch, n_audio_ctx, n_audio_state).
        kv_cache: optional incremental-decoding cache; the second dimension
            of any cached entry gives the current positional offset.
        kwargs:
            hlens: (batch,) valid lengths of ``xa``; when given, a padding
                mask is built for the cross-attention.
    Returns:
        Logits over the vocabulary, float32 (batch, maxlen_out, n_vocab),
        tied to the token embedding weights.
    """
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    tgt, memory = x, xa
    # NOTE(review): in-place write mutates the caller's tensor; padding (-1)
    # is remapped to a valid token id before embedding.
    tgt[tgt == -1] = 0
    tgt = (
        self.token_embedding(tgt)
        + self.positional_embedding[offset : offset + tgt.size(1)]
    )
    # tgt = self.dropout(tgt)

    x = tgt.to(memory.dtype)

    hlens = kwargs.get("hlens", None)
    if hlens is not None:
        # True for valid encoder frames, False for padding: (batch, 1, maxlen_in)
        memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
    else:
        memory_mask = None

    for layer, block in enumerate(self.blocks):
        x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)

    x = self.ln(x)
    # Tied output projection: logits share weights with the input embedding.
    x = (
        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    ).float()

    return x
| | | |
| | | |
| | | |
class MultiHeadAttention(nn.Module):
    """Whisper-style multi-head attention with optional kv caching.

    Acts as self-attention when ``xa`` is None and as cross-attention
    otherwise. When a ``kv_cache`` dict (keyed by the ``key``/``value``
    submodules) is supplied for cross-attention, the cached projections
    are reused instead of being recomputed.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        # No bias on the key projection (whisper convention).
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        # is_pad_mask selects between an additive causal mask and a boolean
        # padding mask inside qkv_attention.
        is_pad_mask = kwargs.get("is_pad_mask", False)

        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
    ):
        """Scaled dot-product attention; returns (output, detached qk scores)."""
        is_pad_mask = kwargs.get("is_pad_mask", False)
        n_batch, n_ctx, n_state = q.shape
        # Split the scale between q and k (** -0.25 each) for numerical stability.
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            if not is_pad_mask:
                # Additive causal mask, cropped to the current context length.
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
                # Boolean padding mask: fill masked positions with dtype-min
                # before the softmax.
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                min_value = float(
                    np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
                )
                qk = qk.masked_fill(mask, min_value)

        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        if mask is not None and is_pad_mask:
            # Zero out masked positions after the softmax as well.
            w = w.masked_fill(mask, 0.0)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
| | | |
| | | |
| | | from funasr.models.sense_voice.rwkv_v6 import RWKVLayer |
| | | from omegaconf import OmegaConf |
| | | |
| | | |
class ResidualAttentionBlockRWKV(nn.Module):
    """Decoder block where self-attention is replaced by an RWKV (v6) layer.

    Layout mirrors whisper's ResidualAttentionBlock: RWKV time-mixing,
    optional cross-attention over the encoder output, then an MLP, each
    wrapped in a residual connection.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, layer_id=0, **kwargs):
        super().__init__()

        rwkv_cfg = kwargs.get("rwkv_cfg", {})
        args = OmegaConf.create(rwkv_cfg)
        self.attn = RWKVLayer(args=args, layer_id=layer_id)
        if args.get("datatype", "bf16") == "bf16":
            self.attn.to(torch.bfloat16)

        # NOTE(review): ln0/ln1 are created when the config flag is *false*
        # (e.g. `ln0: false` in the yaml enables this LayerNorm) — confirm
        # the intended polarity.
        self.ln0 = None
        if layer_id == 0 and not args.get("ln0", True):
            self.ln0 = LayerNorm(args.n_embd)
            if args.get("init_rwkv", True):
                print("init_rwkv")
                # Redundant: this branch only runs when layer_id == 0 already.
                layer_id = 0
                # Depth-dependent init scale for the norm weight.
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln0.weight, scale)
        self.layer_id = layer_id
        self.args = args

        self.ln1 = None
        if not args.get("ln1", True):
            self.ln1 = LayerNorm(args.n_embd)
            # init
            if args.get("init_rwkv", True):
                print("init_rwkv")
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln1.weight, scale)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)

        # Optional pre-norm applied only on the first layer.
        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)

        # RWKV time-mixing with residual; ln1 (if present) acts as a pre-norm.
        if self.ln1 is None:
            x = x + self.attn(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        else:
            x = x + self.attn(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]

        # Cross-attention over the encoder output, then the MLP, both residual.
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
        x = x + self.mlp(self.mlp_ln(x))

        return x
| | | |
@tables.register("decoder_classes", "SenseVoiceDecoder")
class SenseVoiceDecoder(nn.Module):
    """Whisper-style text decoder whose blocks use RWKV self-mixing.

    Embeds text tokens (learned positional embeddings, causal mask),
    cross-attends to the encoded audio, and projects to the vocabulary
    through the tied token-embedding weights.
    """

    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        # Learned positional embedding, uninitialized here — presumably
        # loaded from a checkpoint; TODO confirm.
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks = nn.ModuleList(
            [
                ResidualAttentionBlockRWKV(n_state, n_head, cross_attention=True, layer_id=i, **kwargs)
                for i in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        # Additive causal mask: -inf above the diagonal.
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

        self.use_padmask = kwargs.get("use_padmask", True)
    # def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
    #     """
    #     x : torch.LongTensor, shape = (batch_size, <= n_ctx)
    #         the text tokens
    #     xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
    #         the encoded audio features to be attended on
    #     """
    #     offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    #     x = (
    #         self.token_embedding(x)
    #         + self.positional_embedding[offset: offset + x.shape[-1]]
    #     )
    #     x = x.to(xa.dtype)
    #
    #     for block in self.blocks:
    #         x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
    #
    #     x = self.ln(x)
    #     logits = (
    #         x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    #     ).float()
    #
    #     return logits


    def forward(
        self,
        x: torch.Tensor,
        xa: torch.Tensor,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Forward decoder.

        Args:
            x: input token ids, int64 (batch, maxlen_out); -1 marks padding
                and is remapped to token 0 before embedding.
            xa: encoded audio memory (batch, maxlen_in, feat).
            kv_cache: optional incremental-decoding cache; the second dim of
                any cached entry gives the positional offset.
            kwargs:
                hlens: (batch,) valid lengths of ``xa`` for the pad mask.
                ys_in_lens: (batch,) — NOTE(review): read but unused here.
        Returns:
            Logits over the vocabulary, float32 (batch, maxlen_out, n_vocab).
        """
        # import pdb;pdb.set_trace()
        use_padmask = self.use_padmask
        hlens = kwargs.get("hlens", None)

        ys_in_lens = kwargs.get("ys_in_lens", None)

        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        tgt, memory = x, xa
        # NOTE(review): in-place write mutates the caller's tensor.
        tgt[tgt == -1] = 0
        tgt = (
            self.token_embedding(tgt)
            + self.positional_embedding[offset: offset + tgt.size(1)]
        )
        # tgt = self.dropout(tgt)

        x = tgt.to(memory.dtype)

        if use_padmask and hlens is not None:
            # True for valid encoder frames, False for padding: (batch, 1, maxlen_in)
            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
        else:
            memory_mask = None

        for layer, block in enumerate(self.blocks):
            x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)

        x = self.ln(x)
        # Tied output projection: logits share weights with the input embedding.
        x = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return x
| | |
| | | results.append(result_i) |
| | | |
| | | return results, meta_data |
| | | |
| | | |
| | | |
| | | @tables.register("model_classes", "SenseVoiceRWKV") |
| | | class SenseVoiceRWKV(nn.Module): |
    def __init__(self, *args, **kwargs):
        """Build a whisper backbone and replace its decoder with an RWKV one.

        kwargs:
            dims: fields for whisper's ModelDimensions.
            downsample_rate: encoder downsampling factor (default 4).
            use_padmask: whether the encoder applies a padding mask.
            decoder: registered decoder class name (default "SenseVoiceDecoder").
            decoder_conf: config dict passed to the decoder class.
            activation_checkpoint / ignore_id / vocab_size /
            length_normalized_loss / lsm_weight: training options.
            specaug / specaug_conf: optional spectral augmentation.
        """
        super().__init__()

        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)

        # encoder: keep whisper's encoder but patch its forward to the
        # SenseVoice variant (downsampling + pad-mask support).
        model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
        model.encoder.use_padmask = kwargs.get("use_padmask", True)
        from .encoder import sense_voice_encode_forward
        model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)

        # decoder: drop whisper's decoder and build the registered one with
        # the same text-side dimensions.
        del model.decoder
        decoder = kwargs.get("decoder", "SenseVoiceDecoder")
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(n_vocab=dims.n_vocab,
                                n_ctx=dims.n_text_ctx,
                                n_state=dims.n_text_state,
                                n_head=dims.n_text_head,
                                n_layer=dims.n_text_layer,
                                **kwargs.get("decoder_conf"))
        model.decoder = decoder

        self.model = model

        self.encoder_output_size = self.model.dims.n_audio_state

        self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
        self.ignore_id = kwargs.get("ignore_id", -1)
        self.vocab_size = kwargs.get("vocab_size", -1)
        self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
        # Label-smoothed cross-entropy over the decoder output.
        self.criterion_att = LabelSmoothingLoss(
            size=self.vocab_size,
            padding_idx=self.ignore_id,
            smoothing=kwargs.get("lsm_weight", 0.0),
            normalize_length=self.length_normalized_loss,
        )

        specaug = kwargs.get("specaug", None)
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**kwargs.get("specaug_conf", {}))
        self.specaug = specaug
| | | |
| | | def forward( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | **kwargs, |
| | | ): |
| | | target_mask = kwargs.get("target_mask", None) |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size = speech.shape[0] |
| | | |
| | | if self.activation_checkpoint: |
| | | from torch.utils.checkpoint import checkpoint |
| | | encoder_out, encoder_out_lens = checkpoint(self.encode, speech, speech_lengths, use_reentrant=False) |
| | | else: |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | | loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask |
| | | ) |
| | | loss = loss_att |
| | | stats = {} |
| | | stats["acc"] = acc_att |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | stats["batch_size"] = batch_size |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + 1).sum()) |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, |
| | | ): |
| | | """Encoder. Note that this method is used by asr_inference.py |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | ind: int |
| | | """ |
| | | with autocast(False): |
| | | # Data augmentation |
| | | if self.specaug is not None and self.training: |
| | | speech, speech_lengths = self.specaug(speech, speech_lengths) |
| | | |
| | | # Forward encoder |
| | | encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths) |
| | | |
| | | return encoder_out, encoder_out_lens |
| | | |
| | | def _calc_att_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | **kwargs, |
| | | ): |
| | | target_mask = kwargs.get("target_mask", None) |
| | | stats = {} |
| | | |
| | | # 1. Forward decoder |
| | | decoder_out = self.model.decoder( |
| | | x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens |
| | | ) |
| | | |
| | | # 2. Compute attention loss |
| | | mask = torch.ones_like(ys_pad) * (-1) |
| | | ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64) |
| | | ys_pad_mask[ys_pad_mask == 0] = -1 |
| | | loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:]) |
| | | |
| | | with torch.no_grad(): |
| | | preds = torch.argmax(decoder_out, -1) |
| | | acc_att = compute_accuracy(preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id) |
| | | |
| | | return loss_att, acc_att, None, None |
| | | |
| | | def inference(self, |
| | | data_in, |
| | | data_lengths=None, |
| | | key: list = None, |
| | | tokenizer=None, |
| | | frontend=None, |
| | | **kwargs, |
| | | ): |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | if frontend is None and not hasattr(self, "frontend"): |
| | | frontend_class = tables.frontend_classes.get("WhisperFrontend") |
| | | frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)) |
| | | self.frontend = frontend |
| | | else: |
| | | frontend = frontend if frontend is not None else self.frontend |
| | | |
| | | meta_data = {} |
| | | if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank |
| | | speech, speech_lengths = data_in, data_lengths |
| | | if len(speech.shape) < 3: |
| | | speech = speech[None, :, :] |
| | | if speech_lengths is None: |
| | | speech_lengths = speech.shape[1] |
| | | else: |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_text_image_video(data_in, |
| | | fs=frontend.fs if hasattr(frontend, "fs") else 16000, |
| | | audio_fs=kwargs.get("fs", 16000), |
| | | data_type=kwargs.get("data_type", "sound"), |
| | | tokenizer=tokenizer) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | | frontend=frontend) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10 |
| | | lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1 |
| | | meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000 |
| | | |
| | | speech = speech.to(device=kwargs["device"])[0, :, :] |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | |
| | | DecodingOptions = kwargs.get("DecodingOptions", {}) |
| | | task = DecodingOptions.get("task", "ASR") |
| | | if isinstance(task, str): |
| | | task = [task] |
| | | task = "".join([f"<|{x}|>" for x in task]) |
| | | initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}") |
| | | DecodingOptions["initial_prompt"] = initial_prompt |
| | | |
| | | language = DecodingOptions.get("language", None) |
| | | language = None if language == "auto" else language |
| | | DecodingOptions["language"] = language |
| | | |
| | | DecodingOptions["vocab_path"] = kwargs.get("vocab_path", None) |
| | | |
| | | if "without_timestamps" not in DecodingOptions: |
| | | DecodingOptions["without_timestamps"] = True |
| | | |
| | | options = whisper.DecodingOptions(**DecodingOptions) |
| | | |
| | | result = whisper.decode(self.model, speech, options) |
| | | text = f"{result.text}" |
| | | results = [] |
| | | result_i = {"key": key[0], "text": text} |
| | | |
| | | results.append(result_i) |
| | | |
| | | return results, meta_data |
| New file |
| | |
| | | ######################################################################################################## |
| | | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM |
| | | ######################################################################################################## |
| | | |
| | | import os, math, gc, importlib |
| | | import torch |
| | | # torch._C._jit_set_profiling_executor(True) |
| | | # torch._C._jit_set_profiling_mode(True) |
| | | import torch.nn as nn |
| | | from torch.nn import functional as F |
| | | |
| | | |
| | | |
def __nop(ob):
    """Return *ob* unchanged — stand-in for the TorchScript decorators when JIT is off."""
    return ob
| | | |
| | | |
# Plain nn.Module and a no-op decorator by default; switch to TorchScript
# equivalents only when the RWKV_JIT_ON environment variable is set to "1".
MyModule = nn.Module
MyFunction = __nop
if os.environ.get("RWKV_JIT_ON") == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method
| | | |
| | | ######################################################################################################## |
| | | # CUDA Kernel |
| | | ######################################################################################################## |
| | | |
wkv6_cuda = None  # lazily-compiled CUDA extension; populated by load_rwkv_kernel()

def load_rwkv_kernel(HEAD_SIZE: int=64, RWKV_CTXLEN: int=512,):
    """JIT-compile and cache the fused WKV-6 CUDA kernel (no-op once loaded)."""
    from torch.utils.cpp_extension import load
    global wkv6_cuda

    if wkv6_cuda is not None:
        return

    # Kernel sources live in a cuda/ directory next to this file.
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    sources = [f"{cur_dir}/cuda/wkv6_op.cpp", f"{cur_dir}/cuda/wkv6_cuda.cu"]
    # Head size and context length are baked into the kernel at compile time.
    cflags = [
        "-res-usage",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        "--extra-device-vectorization",
        f"-D_N_={HEAD_SIZE}",
        f"-D_T_={RWKV_CTXLEN}",
    ]
    wkv6_cuda = load(name="wkv6", sources=sources, verbose=True, extra_cuda_cflags=cflags)
| | | |
# dtype = torch.float
# Working precision of the fused kernel; the disabled asserts below document
# that all tensor arguments are expected in bfloat16.
dtype = torch.bfloat16
class WKV_6(torch.autograd.Function):
    """Autograd binding for the fused RWKV-6 WKV CUDA kernel (``wkv6_cuda``).

    r, k, v, w are (B, T, C); u is (H, C // H).  ``load_rwkv_kernel()`` must
    have populated the module-level ``wkv6_cuda`` extension before first use.
    """
    @staticmethod
    def forward(ctx, B, T, C, H, r, k, v, w, u):
        with torch.no_grad():
            # assert r.dtype == torch.bfloat16
            # assert k.dtype == torch.bfloat16
            # assert v.dtype == torch.bfloat16
            # assert w.dtype == torch.bfloat16
            # assert u.dtype == torch.bfloat16
            # assert HEAD_SIZE == C // H
            # Stash problem sizes for backward().
            ctx.B = B
            ctx.T = T
            ctx.C = C
            ctx.H = H
            # The kernel indexes raw memory, so contiguity is mandatory.
            assert r.is_contiguous()
            assert k.is_contiguous()
            assert v.is_contiguous()
            assert w.is_contiguous()
            assert u.is_contiguous()
            # The kernel consumes the decay as -exp(w), computed in float32.
            ew = (-torch.exp(w.float())).contiguous()
            ctx.save_for_backward(r, k, v, ew, u)
            # Output buffer is allocated uninitialized; the kernel fills it.
            y = torch.empty((B, T, C), device=r.device, dtype=dtype,
                            memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
            wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y)
            return y

    @staticmethod
    def backward(ctx, gy):
        with torch.no_grad():
            # assert gy.dtype == torch.bfloat16
            B = ctx.B
            T = ctx.T
            C = ctx.C
            H = ctx.H
            assert gy.is_contiguous()
            r, k, v, ew, u = ctx.saved_tensors
            # Gradient buffers are allocated uninitialized; the kernel fills them.
            gr = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
                             memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
            gk = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
                             memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
            gv = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
                             memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
            gw = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=dtype,
                             memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
            gu = torch.empty((B, C), device=gy.device, requires_grad=False, dtype=dtype,
                             memory_format=torch.contiguous_format)  # .uniform_(-100, 100)
            wkv6_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu)
            # Sum the per-batch u-gradient and reshape to (H, head_size).
            gu = torch.sum(gu, 0).view(H, C // H)
            # One None per non-tensor forward arg (B, T, C, H); w's grad is gw.
            return (None, None, None, None, gr, gk, gv, gw, gu)
| | | |
| | | |
def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
    """Dispatch one autograd-aware WKV-6 pass through the fused CUDA kernel."""
    return WKV_6.apply(B, T, C, H, r, k, v, w, u)
| | | |
| | | |
class RWKV_Tmix_x060(MyModule):
    """RWKV-6 time-mix block: data-dependent token-shift mixing, low-rank
    (LoRA-style) dynamic decay, and the fused WKV CUDA kernel.

    ``args`` must provide head_size_a, ctx_len, dim_att, n_embd, n_layer and
    head_size_divisor.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args

        # Compile/load the fused WKV-6 kernel for this head size & ctx length.
        load_rwkv_kernel(args.head_size_a, args.ctx_len)

        self.layer_id = layer_id

        self.head_size = args.head_size_a
        self.n_head = args.dim_att // self.head_size
        assert args.dim_att % self.n_head == 0

        with torch.no_grad():
            # Depth-dependent init schedules (reference RWKV-LM initialization).
            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            # Per-channel ramp i / n_embd, shape (1, 1, n_embd).
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            # fancy time_mix: static token-shift mixing coefficients per stream.
            self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_v = nn.Parameter(1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1))
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
            self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))

            D_MIX_LORA = 32  # generate TIME_MIX for w,k,v,r,g
            # Low-rank projections producing the 5 dynamic mixing deltas.
            self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
            self.time_maa_w2 = nn.Parameter(torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01))

            # fancy time_decay: per-channel base decay speed.
            decay_speed = torch.ones(args.dim_att)
            for n in range(args.dim_att):
                decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))

            D_DECAY_LORA = 64
            # Low-rank projection for the data-dependent decay correction.
            self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
            self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01))

            # Per-head "bonus" (u) term, with a small zigzag perturbation.
            tmp = torch.zeros(args.dim_att)
            for n in range(args.dim_att):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag

            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))

        # Shift the sequence one step back in time (pad first row, drop last).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)

        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
        self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
        # Per-head GroupNorm applied to the kernel output.
        self.ln_x = nn.GroupNorm(self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor ** 2))

    @MyFunction
    def jit_func(self, x):
        """Compute the r/k/v/gate streams and dynamic decay w from x (B, T, C)."""
        B, T, C = x.size()

        # Delta between the previous token and the current one.
        xx = self.time_shift(x) - x

        # Data-dependent mixing: base maa_x plus 5 low-rank deltas (w,k,v,r,g).
        xxx = x + xx * self.time_maa_x
        xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        xw = x + xx * (self.time_maa_w + mw)
        xk = x + xx * (self.time_maa_k + mk)
        xv = x + xx * (self.time_maa_v + mv)
        xr = x + xx * (self.time_maa_r + mr)
        xg = x + xx * (self.time_maa_g + mg)

        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))

        # Decay = static per-channel base + low-rank data-dependent correction.
        ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
        w = self.time_decay + ww

        return r, k, v, g, w

    @MyFunction
    def jit_func_2(self, x, g):
        """Per-head GroupNorm, gating by g, and output projection."""
        B, T, C = x.size()
        x = x.view(B * T, C)

        x = self.ln_x(x).view(B, T, C)
        x = self.output(x * g)
        return x

    def forward(self, x):
        B, T, C = x.size()
        H = self.n_head

        r, k, v, g, w = self.jit_func(x)
        # Fused CUDA WKV pass; u is the per-head bonus for the current token.
        x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)

        return self.jit_func_2(x, g)
| | | |
class RWKV_CMix_x060(MyModule):
    """RWKV-6 channel-mix block: a token-shifted, gated, squared-ReLU FFN."""

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Shift the sequence one step back in time (pad first row, drop last).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            # Per-channel ramp i / n_embd, shaped (1, 1, n_embd).
            ddd = torch.tensor([i / args.n_embd for i in range(args.n_embd)]).view(1, 1, args.n_embd)
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        # Mix each token with its predecessor, separately for the k and r paths.
        delta = self.time_shift(x) - x
        k_input = x + delta * self.time_maa_k
        r_input = x + delta * self.time_maa_r

        # Squared-ReLU FFN, gated by a sigmoid "receptance".
        hidden = torch.relu(self.key(k_input)) ** 2
        return torch.sigmoid(self.receptance(r_input)) * self.value(hidden)
| | | |
| | | |
class Block(nn.Module):
    """One RWKV-6 layer: time-mix (attention analogue) + channel-mix (FFN),
    each behind a pre-LayerNorm residual; ln0 is applied once on layer 0.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)

        # FIX: forward() uses self.ffnPre on layer 0 when pre_ffn > 0, but it
        # was never created, so that path always raised AttributeError.
        # Mirror upstream RWKV-LM and build it when configured.
        if self.layer_id == 0 and getattr(args, "pre_ffn", 0) > 0:
            self.ffnPre = RWKV_CMix_x060(args, 0)

        self.att = RWKV_Tmix_x060(args, layer_id)

        self.ffn = RWKV_CMix_x060(args, layer_id)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

    def forward(self, x, x_emb=None):
        """Apply the layer to x (B, T, C); x_emb is accepted but unused here."""
        args = self.args
        B, T, C = x.size()
        if self.layer_id == 0:
            x = self.ln0(x)

        # FIX: tolerate configs without a pre_ffn field (default 0 = disabled).
        use_pre_ffn = self.layer_id == 0 and getattr(args, "pre_ffn", 0) > 0

        if self.args.dropout == 0:
            if use_pre_ffn:
                x = x + self.ffnPre(self.ln1(x))
            else:
                x = x + self.att(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
        else:
            if use_pre_ffn:
                x = self.drop0(x + self.ffnPre(self.ln1(x)))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))

        return x
| | | |
| | | |
class RWKVLayer(nn.Module):
    """A single RWKV-6 block (time-mix + channel-mix) usable as a drop-in layer.

    NOTE(review): ``args`` is accessed both attribute-style (args.n_embd) and
    via ``args.get(...)``, so it is presumably an OmegaConf/dict-like config —
    a plain argparse.Namespace would fail on the .get calls; confirm.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Default FFN width: 3.5x embedding, rounded down to a multiple of 32.
        if args.dim_ffn is None:
            args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
        # Optional pre-LayerNorm, applied once on the first layer only.
        self.ln0 = None
        if self.layer_id == 0 and args.get("ln0", True):
            self.ln0 = nn.LayerNorm(args.n_embd)

        # ln1 (before time-mix) can be disabled via config; ln2 always exists.
        self.ln1 = None
        if args.get("ln1", True):
            self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_Tmix_x060(args, layer_id)

        self.ffn = RWKV_CMix_x060(args, layer_id)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

        # init: RWKV-style orthogonal/zero init, overriding PyTorch defaults.
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
            nn.init.orthogonal_(self.att.value.weight, gain=1)
            nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
            # Zeroed output/value/receptance make the block start near-identity.
            nn.init.zeros_(self.att.output.weight)

            nn.init.orthogonal_(self.ffn.key.weight, gain=1)
            nn.init.zeros_(self.ffn.value.weight)
            nn.init.zeros_(self.ffn.receptance.weight)
            # Deeper layers get a larger LayerNorm gain.
            # NOTE(review): uses args.get("n_layer") with no default, unlike the
            # attribute access args.n_layer used elsewhere — confirm intended.
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            nn.init.constant_(self.ln2.weight, scale)
            if self.ln0 is not None:
                nn.init.constant_(self.ln0.weight, scale)
            if self.ln1 is not None:
                nn.init.constant_(self.ln1.weight, scale)

    def forward(self, x, x_emb=None, mask=None, **kwargs):
        """Apply the block to x (B, T, C); x_emb and mask are accepted but unused."""
        args = self.args
        # The fused kernel runs in bf16: cast in, and back to fp32 on the way out.
        if args.get("datatype", "bf16") == "bf16":
            x = x.bfloat16()
        B, T, C = x.size()
        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)

        if self.args.dropout == 0:
            if self.ln1 is None:
                x = x + self.att(x)
            else:
                x = x + self.att(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
        else:
            if self.ln1 is None:
                x = self.drop0(x + self.att(x))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))

        if args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)
        return x
| | | |
| | | |
class RWKV(nn.Module):
    """Stand-alone RWKV language model: token embedding -> Block stack -> LM head.

    NOTE(review): forward() references ``deepspeed`` (grad_cp path) and
    ``self.head_q`` / ``self.head_k`` / ``self.copy_mask`` (head_qk path),
    none of which are defined in this class — those paths presumably rely on
    setup done elsewhere; confirm before enabling grad_cp or head_qk.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Fill in derived defaults when not provided by the config.
        if not hasattr(args, 'dim_att'):
            args.dim_att = args.n_embd
        if not hasattr(args, 'dim_ffn'):
            if '-f4' in os.environ["RWKV_MY_TESTING"]:
                args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
            else:
                args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size
        if not hasattr(args, 'tiny_att_layer'):
            args.tiny_att_layer = -1
        if not hasattr(args, 'tiny_att_dim'):
            args.tiny_att_dim = -1
        # All widths must be divisible by 32 (kernel/GroupNorm constraints).
        assert args.n_embd % 32 == 0
        assert args.dim_att % 32 == 0
        assert args.dim_ffn % 32 == 0

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)

    def forward(self, idx):
        """Map token ids idx (B, T) to logits (B, T, vocab_size)."""
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        x_emb = x

        if args.dropout > 0:
            x = self.drop0(x)
        # tiny-attention variant also passes the raw embedding to every block
        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    # NOTE(review): requires deepspeed to be importable here.
                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)

        x = self.ln_out(x)

        if args.head_qk > 0:
            # Extra "copy head": q/k attention over input tokens, added to logits.
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            # One-hot matmul dtype must match the RWKV_FLOAT_MODE env setting.
            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x
| | |
| | | self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool |
| | | ) |
| | | all_heads[self.dims.n_text_layer // 2 :] = True |
| | | self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) |
| | | # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) |
| | | # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense() |
| | | # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False) |
| | | |
| | | def set_alignment_heads(self, dump: bytes): |
| | | array = np.frombuffer( |