import torch
import torch.nn as nn

from omegaconf import OmegaConf

from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables


class LayerNorm(nn.LayerNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)
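
# A minimal sketch (illustrative, not part of the original file) of the dtype
# round-trip performed by the LayerNorm subclass above: statistics are computed
# in float32 even when the activations are bf16, and the result is cast back.
#
#     ln = LayerNorm(16)
#     x = torch.randn(2, 8, 16, dtype=torch.bfloat16)
#     y = ln(x)  # normalized in fp32 internally
#     assert y.dtype == torch.bfloat16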


class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (nn.Module): Self-attention module; replaced below by an RWKV
            time-mix block built from ``args``.
        src_attn (nn.Module): Cross-attention module over the encoder output.
        feed_forward (nn.Module): Feed-forward module.
        dropout_rate (float): Dropout rate.
        concat_after (bool): Whether to concatenate attention input and output.
        layer_id (int): Index of this layer in the stack.
        args (dict): RWKV configuration (version, datatype, n_layer, ...).
    """
    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        concat_after=False,
        layer_id=None,
        args={},
        **kwargs,
    ):
| | | """Construct an DecoderLayer object.""" |
| | | super(DecoderLayer, self).__init__() |
| | | self.size = size |
| | | self.self_attn = self_attn.to(torch.bfloat16) |
| | | # self.self_attn = self_attn.to(torch.bfloat16) |
| | | self.src_attn = src_attn |
| | | self.feed_forward = feed_forward |
| | | self.norm1 = LayerNorm(size) |
| | |
| | | self.concat_linear1 = nn.Linear(size + size, size) |
| | | self.concat_linear2 = nn.Linear(size + size, size) |
| | | self.layer_id = layer_id |

        version = args.get("version", "v4")
        if version == "v4":
            from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
        elif version == "v5":
            from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
        else:
            from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
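
        # Illustrative config (key names taken from the .get() calls in this
        # file; the values are assumptions, not defaults shipped with FunASR):
        #
        #     rwkv_cfg = {
        #         "version": "v6",     # selects rwkv_v4 / rwkv_v5 / rwkv_v6
        #         "datatype": "bf16",  # cast the RWKV block to bfloat16
        #         "n_embd": 512,       # embedding width
        #         "n_layer": 6,        # stack depth, used by the init scale below
        #         "init_rwkv": True,   # apply the RWKV-style initialization
        #         "ln0": True,         # True means: no extra embedding LayerNorm here
        #     }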
        self.self_attn = RWKV_Tmix(args, layer_id=layer_id)
        if args.get("datatype", "bf16") == "bf16":
            self.self_attn.to(torch.bfloat16)
        self.args = args
        self.ln0 = None
        if self.layer_id == 0 and not args.get("ln0", True):
            self.ln0 = LayerNorm(args.n_embd)
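
        # Note (an assumption about intent, following the RWKV-LM reference
        # implementation): "ln0" is the extra LayerNorm RWKV applies to the
        # embeddings before the first block. It is created here only when the
        # config explicitly sets ln0=False.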

        if args.get("init_rwkv", True):
            print("init_rwkv")
            # Depth-dependent LayerNorm gain, as in RWKV-style initialization.
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            nn.init.constant_(self.norm1.weight, scale)
            nn.init.constant_(self.self_attn.ln2.weight, scale)
            nn.init.orthogonal_(self.self_attn.receptance.weight, gain=1)
            nn.init.orthogonal_(self.self_attn.key.weight, gain=0.1)
            nn.init.orthogonal_(self.self_attn.value.weight, gain=1)
            nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1)
            nn.init.zeros_(self.self_attn.output.weight)
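
        # For example, with n_layer = 6 (an assumed value): layer 0 starts with
        # scale (1/6) ** 0.7 ≈ 0.29 and the last layer with (6/6) ** 0.7 = 1.0,
        # so LayerNorm gains grow smoothly with depth.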

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for the input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Mask for the encoded memory (#batch, maxlen_in).
            cache (torch.Tensor): Cached output of the previous step, used during
                incremental decoding.
        """
        if self.layer_id == 0 and self.ln0 is not None:
            tgt = self.ln0(tgt)

        if self.args.get("datatype", "bf16") == "bf16":
            tgt = tgt.bfloat16()
        residual = tgt

        tgt = self.norm1(tgt)

        # The RWKV time-mix block runs over the full sequence; the target mask
        # serves as the query mask.
        x = residual + self.dropout(self.self_attn(tgt, mask=tgt_mask))
        # Keep only the last frame of the sequence.
        x = x[:, -1, :]

        if self.args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)

        residual = x

        # ...

            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        from funasr.models.sense_voice.rwkv_v6 import RWKVLayer

        rwkv_cfg = kwargs.get("rwkv_cfg", {})
        args = OmegaConf.create(rwkv_cfg)

        attention_dim = encoder_output_size
        self.decoders = repeat(
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                RWKVLayer(args=args, layer_id=lnum),
                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                layer_id=lnum,
                args=args,
            ),
        )
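
        # Sketch of the resulting stack, assuming `repeat` is the FunASR/ESPnet
        # helper that calls the lambda with lnum = 0 .. num_blocks - 1 and wraps
        # the layers in a MultiSequential: every DecoderLayer pairs an RWKV
        # block (linear-time "self-attention") with standard cross-attention
        # over the encoder output and a position-wise feed-forward module.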