From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat decoder.py and add DecoderLayerExport
---
funasr/models/transformer/decoder.py | 318 +++++++++++++++++++++++++++--------------------------
 1 file changed, 162 insertions(+), 156 deletions(-)
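
Note (this sits after the "---" separator, so git am ignores it): besides the pure
reformatting, the patch adds a DecoderLayerExport wrapper that strips dropout and the
concat_after branch for inference/export. A minimal usage sketch follows; the helper
name wrap_decoder_layers_for_export and the assumption that decoder.decoders is an
indexable nn.Sequential-style container are illustrative only, not part of this patch.

    # Hypothetical sketch: swap each trained DecoderLayer for its export
    # wrapper before tracing the decoder (e.g. for ONNX export).
    import torch
    from funasr.models.transformer.decoder import DecoderLayerExport

    def wrap_decoder_layers_for_export(decoder: torch.nn.Module) -> torch.nn.Module:
        # Replace every layer in place; DecoderLayerExport reuses the trained
        # attention/feed-forward/norm sub-modules, so no weights are copied.
        for i, layer in enumerate(decoder.decoders):
            decoder.decoders[i] = DecoderLayerExport(layer)
        decoder.eval()  # the export layers have no dropout; inference mode is assumed
        return decoder
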
diff --git a/funasr/models/transformer/decoder.py b/funasr/models/transformer/decoder.py
index 820de4a..7d849f3 100644
--- a/funasr/models/transformer/decoder.py
+++ b/funasr/models/transformer/decoder.py
@@ -28,6 +28,7 @@
from funasr.register import tables
+
class DecoderLayer(nn.Module):
"""Single decoder layer module.
@@ -51,14 +52,14 @@
"""
def __init__(
- self,
- size,
- self_attn,
- src_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
+ self,
+ size,
+ self_attn,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
@@ -115,9 +116,7 @@
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
- tgt_concat = torch.cat(
- (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
- )
+ tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
@@ -128,9 +127,7 @@
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
- x_concat = torch.cat(
- (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
- )
+ x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
@@ -146,6 +143,35 @@
if cache is not None:
x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, memory, memory_mask
+
+
+class DecoderLayerExport(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.self_attn = model.self_attn
+ self.src_attn = model.src_attn
+ self.feed_forward = model.feed_forward
+ self.norm1 = model.norm1
+ self.norm2 = model.norm2
+ self.norm3 = model.norm3
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+ residual = tgt
+ tgt = self.norm1(tgt)
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
+
+ residual = x
+ x = self.norm2(x)
+
+ x = residual + self.src_attn(x, memory, memory, memory_mask)
+
+ residual = x
+ x = self.norm3(x)
+ x = residual + self.feed_forward(x)
return x, tgt_mask, memory, memory_mask
@@ -173,15 +199,15 @@
"""
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
):
super().__init__()
attention_dim = encoder_output_size
@@ -214,11 +240,11 @@
self.decoders = None
def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -246,20 +272,14 @@
tgt_mask = tgt_mask & m
memory = hs_pad
- memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
- memory.device
- )
+ memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
# Padding for Longformer
if memory_mask.shape[-1] != memory.shape[1]:
padlen = memory.shape[1] - memory_mask.shape[-1]
- memory_mask = torch.nn.functional.pad(
- memory_mask, (0, padlen), "constant", False
- )
+ memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
x = self.embed(tgt)
- x, tgt_mask, memory, memory_mask = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
@@ -269,11 +289,11 @@
return x, olens
def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- memory: torch.Tensor,
- cache: List[torch.Tensor] = None,
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
@@ -293,9 +313,7 @@
cache = [None] * len(self.decoders)
new_cache = []
for c, decoder in zip(cache, self.decoders):
- x, tgt_mask, memory, memory_mask = decoder(
- x, tgt_mask, memory, None, cache=c
- )
+ x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, None, cache=c)
new_cache.append(x)
if self.normalize_before:
@@ -310,13 +328,11 @@
def score(self, ys, state, x):
"""Score."""
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
- logp, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
- )
+ logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
return logp.squeeze(0), state
def batch_score(
- self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+ self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
@@ -340,8 +356,7 @@
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
- torch.stack([states[b][i] for b in range(n_batch)])
- for i in range(n_layers)
+ torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)
]
# batch decoding
@@ -352,24 +367,25 @@
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
+
@tables.register("decoder_classes", "TransformerDecoder")
class TransformerDecoder(BaseTransformerDecoder):
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
):
super().__init__(
vocab_size=vocab_size,
@@ -387,12 +403,8 @@
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
- MultiHeadedAttention(
- attention_heads, attention_dim, self_attention_dropout_rate
- ),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
+ MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
+ MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
@@ -404,24 +416,24 @@
@tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder")
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- conv_wshare: int = 4,
- conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
- conv_usebias: int = False,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ conv_wshare: int = 4,
+ conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+ conv_usebias: int = False,
):
if len(conv_kernel_length) != num_blocks:
raise ValueError(
@@ -452,9 +464,7 @@
use_kernel_mask=True,
use_bias=conv_usebias,
),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
+ MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
@@ -462,27 +472,28 @@
),
)
+
@tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder")
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- conv_wshare: int = 4,
- conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
- conv_usebias: int = False,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ conv_wshare: int = 4,
+ conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+ conv_usebias: int = False,
):
if len(conv_kernel_length) != num_blocks:
raise ValueError(
@@ -513,9 +524,7 @@
use_kernel_mask=True,
use_bias=conv_usebias,
),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
+ MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
@@ -527,24 +536,24 @@
@tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder")
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- conv_wshare: int = 4,
- conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
- conv_usebias: int = False,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ conv_wshare: int = 4,
+ conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+ conv_usebias: int = False,
):
if len(conv_kernel_length) != num_blocks:
raise ValueError(
@@ -575,9 +584,7 @@
use_kernel_mask=True,
use_bias=conv_usebias,
),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
+ MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
@@ -585,27 +592,28 @@
),
)
+
@tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder")
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- conv_wshare: int = 4,
- conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
- conv_usebias: int = False,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ conv_wshare: int = 4,
+ conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+ conv_usebias: int = False,
):
if len(conv_kernel_length) != num_blocks:
raise ValueError(
@@ -636,9 +644,7 @@
use_kernel_mask=True,
use_bias=conv_usebias,
),
- MultiHeadedAttention(
- attention_heads, attention_dim, src_attention_dropout_rate
- ),
+ MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
--
Gitblit v1.9.1