From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' at line 514 to avoid a naming conflict with line 365
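
Besides the rename, this diff carries the surrounding RWKV decoder
changes: a LayerNorm wrapper that normalizes in float32, RWKV
time-mixing (v4/v5/v6, selected via args) in place of the Transformer
self-attention in DecoderLayer, RWKV-style weight initialization,
optional bf16 casting, and black-style line reflow.

The rename itself guards against variable shadowing. A minimal sketch
of that bug class (hypothetical names, not the actual decoder code):

    def forward(x):
        res = x              # first binding: residual kept for later
        ...
        res = project(x)     # rebinding silently clobbers the residual
        return x + res       # the original residual is lost

Giving the second binding its own name keeps both values live.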
---
funasr/models/conformer_rwkv/decoder.py | 226 +++++++++++++++++++++++++++++---------------------------
1 file changed, 118 insertions(+), 108 deletions(-)
diff --git a/funasr/models/conformer_rwkv/decoder.py b/funasr/models/conformer_rwkv/decoder.py
index d7f113d..4c41049 100644
--- a/funasr/models/conformer_rwkv/decoder.py
+++ b/funasr/models/conformer_rwkv/decoder.py
@@ -28,6 +28,12 @@
from omegaconf import OmegaConf
from funasr.register import tables
+
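+# LayerNorm variant that normalizes in float32 and casts the result back to
+# the input dtype, keeping normalization numerically stable under bf16.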
+class LayerNorm(nn.LayerNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
class DecoderLayer(nn.Module):
"""Single decoder layer module.
@@ -51,21 +57,22 @@
"""
def __init__(
- self,
- size,
- self_attn,
- src_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- layer_id=None,
- args={},
+ self,
+ size,
+ # self_attn,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ layer_id=None,
+ args={},
+ **kwargs,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
- self.self_attn = self_attn.to(torch.bfloat16)
+ # self.self_attn = self_attn.to(torch.bfloat16)
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
@@ -78,6 +85,20 @@
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
self.layer_id = layer_id
+
+ if args.get("version", "v4") == "v4":
+ from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
+ elif args.get("version", "v5") == "v5":
+ from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
+ else:
+ from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
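+        # The Transformer self-attention is replaced by the RWKV time-mixing
+        # block of the selected version.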
+ # self.attn = RWKVLayer(args=args, layer_id=layer_id)
+ self.self_attn = RWKV_Tmix(args, layer_id=layer_id)
+
+ self.args = args
self.ln0 = None
if self.layer_id == 0 and not args.get("ln0", True):
self.ln0 = LayerNorm(args.n_embd)
@@ -86,13 +107,25 @@
layer_id = 0
scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
nn.init.constant_(self.ln0.weight, scale)
-
+
# init
if args.get("init_rwkv", True):
print("init_rwkv")
scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
nn.init.constant_(self.norm1.weight, scale)
- nn.init.constant_(self.self_attn.ln2.weight, scale)
+ # nn.init.constant_(self.self_attn.ln2.weight, scale)
+
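+        # RWKV-style projection init: orthogonal receptance/key/value/gate,
+        # zero-initialized output projection.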
+ if args.get("init_rwkv", True):
+ print("init_rwkv")
+ nn.init.orthogonal_(self.self_attn.receptance.weight, gain=1)
+ nn.init.orthogonal_(self.self_attn.key.weight, gain=0.1)
+ nn.init.orthogonal_(self.self_attn.value.weight, gain=1)
+ nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1)
+ nn.init.zeros_(self.self_attn.output.weight)
+
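+        # Optionally keep the RWKV block in bf16; forward() casts its input to match.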
+ if args.get("datatype", "bf16") == "bf16":
+ self.self_attn.to(torch.bfloat16)
+ # self.norm1.to(torch.bfloat16)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
@@ -115,28 +148,27 @@
if self.layer_id == 0 and self.ln0 is not None:
tgt = self.ln0(tgt)
-
+
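+        # Match the bf16 RWKV block's dtype before the self-attention path.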
+ if self.args.get("datatype", "bf16") == "bf16":
+ tgt = tgt.bfloat16()
residual = tgt
-
-
+
tgt = self.norm1(tgt)
if cache is None:
-
+
x = residual + self.dropout(self.self_attn(tgt, mask=tgt_mask))
else:
-
+
# tgt_q = tgt[:, -1:, :]
# residual_q = residual[:, -1:, :]
tgt_q_mask = None
-
+
x = residual + self.dropout(self.self_attn(tgt, mask=tgt_q_mask))
x = x[:, -1, :]
-
-
-
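+        # Cast back to fp32 before the cross-attention and feed-forward
+        # sublayers, which stay in full precision.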
+ if self.args.get("datatype", "bf16") == "bf16":
+ x = x.to(torch.float32)
# x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
-
residual = x
x = self.norm2(x)
@@ -145,12 +177,10 @@
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
-
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
-
class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
@@ -176,15 +206,15 @@
"""
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
):
super().__init__()
attention_dim = encoder_output_size
@@ -217,11 +247,11 @@
self.decoders = None
def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -249,20 +279,14 @@
tgt_mask = tgt_mask & m
memory = hs_pad
- memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
- memory.device
- )
+ memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
- memory_mask = torch.nn.functional.pad(
- memory_mask, (0, padlen), "constant", False
- )
+ memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
x = self.embed(tgt)
- x, tgt_mask, memory, memory_mask = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
@@ -272,11 +296,11 @@
return x, olens
def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- memory: torch.Tensor,
- cache: List[torch.Tensor] = None,
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
@@ -296,9 +320,7 @@
cache = [None] * len(self.decoders)
new_cache = []
for c, decoder in zip(cache, self.decoders):
- x, tgt_mask, memory, memory_mask = decoder(
- x, tgt_mask, memory, None, cache=c
- )
+ x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, None, cache=c)
new_cache.append(x)
if self.normalize_before:
@@ -313,13 +335,11 @@
def score(self, ys, state, x):
"""Score."""
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
- logp, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
- )
+ logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
return logp.squeeze(0), state
def batch_score(
- self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
@@ -343,8 +363,7 @@
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
- torch.stack([states[b][i] for b in range(n_batch)])
- for i in range(n_layers)
+ torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)
]
# batch decoding
@@ -355,25 +374,26 @@
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
+
@tables.register("decoder_classes", "TransformerRWKVDecoder")
class TransformerRWKVDecoder(BaseTransformerDecoder):
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- **kwargs,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ **kwargs,
):
super().__init__(
vocab_size=vocab_size,
@@ -385,19 +405,17 @@
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
- from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+ # from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+
rwkv_cfg = kwargs.get("rwkv_cfg", {})
args = OmegaConf.create(rwkv_cfg)
- # self.attn = RWKVLayer(args=args, layer_id=layer_id)
+
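+        # DecoderLayer now builds its own RWKV self-attention from these args;
+        # only the cross-attention module is passed to it below.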
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
- RWKVLayer(args=args, layer_id=lnum),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
+ MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
@@ -406,18 +424,18 @@
args=args,
),
)
-
+
# init
if args.get("init_rwkv", True):
print("init_rwkv")
nn.init.uniform_(self.embed[0].weight, a=-1e-4, b=1e-4)
def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -445,20 +463,14 @@
tgt_mask = tgt_mask & m
memory = hs_pad
- memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
- memory.device
- )
+ memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
- memory_mask = torch.nn.functional.pad(
- memory_mask, (0, padlen), "constant", False
- )
+ memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
x = self.embed(tgt)
- x, tgt_mask, memory, memory_mask = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
@@ -468,11 +480,11 @@
return x, olens
def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- memory: torch.Tensor,
- cache: List[torch.Tensor] = None,
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
@@ -492,9 +504,7 @@
cache = [None] * len(self.decoders)
new_cache = []
for c, decoder in zip(cache, self.decoders):
- x, tgt_mask, memory, memory_mask = decoder(
- x, tgt_mask, memory, None, cache=c
- )
+ x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, None, cache=c)
new_cache.append(x)
if self.normalize_before:
@@ -504,4 +514,4 @@
if self.output_layer is not None:
y = torch.log_softmax(self.output_layer(y), dim=-1)
- return y, new_cache
\ No newline at end of file
+ return y, new_cache
--
Gitblit v1.9.1