From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat funasr/models/lcbnet/encoder.py (code style cleanup)
---
funasr/models/lcbnet/encoder.py | 120 +++++++++++++++++++++++++++++++----------------------------
 1 file changed, 63 insertions(+), 57 deletions(-)
diff --git a/funasr/models/lcbnet/encoder.py b/funasr/models/lcbnet/encoder.py
index c65823c..f5f2497 100644
--- a/funasr/models/lcbnet/encoder.py
+++ b/funasr/models/lcbnet/encoder.py
@@ -21,6 +21,7 @@
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables
+
class EncoderLayer(nn.Module):
"""Encoder layer module.
@@ -44,14 +45,14 @@
"""
def __init__(
- self,
- size,
- self_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- stochastic_depth_rate=0.0,
+ self,
+ size,
+ self_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
@@ -109,9 +110,7 @@
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
- x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x_q, x, x, mask)
- )
+ x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask))
if not self.normalize_before:
x = self.norm1(x)
@@ -126,6 +125,7 @@
x = torch.cat([cache, x], dim=1)
return x, mask
+
@tables.register("encoder_classes", "TransformerTextEncoder")
class TransformerTextEncoder(nn.Module):
@@ -154,18 +154,18 @@
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
):
super().__init__()
self._output_size = output_size
@@ -187,9 +187,7 @@
num_blocks,
lambda lnum: EncoderLayer(
output_size,
- MultiHeadedAttention(
- attention_heads, output_size, attention_dropout_rate
- ),
+ MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
@@ -203,9 +201,9 @@
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
@@ -225,8 +223,6 @@
olens = masks.squeeze(1).sum(1)
return xs_pad, olens, None
-
-
@tables.register("encoder_classes", "FusionSANEncoder")
@@ -251,25 +247,32 @@
"""
+
def __init__(
- self,
- size,
- attention_heads,
- attention_dim,
- linear_units,
- self_attention_dropout_rate,
- src_attention_dropout_rate,
- positional_dropout_rate,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
+ self,
+ size,
+ attention_heads,
+ attention_dim,
+ linear_units,
+ self_attention_dropout_rate,
+ src_attention_dropout_rate,
+ positional_dropout_rate,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
):
"""Construct an SelfSrcAttention object."""
super(SelfSrcAttention, self).__init__()
self.size = size
- self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
- self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
- self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
+ self.self_attn = MultiHeadedAttention(
+ attention_heads, attention_dim, self_attention_dropout_rate
+ )
+ self.src_attn = MultiHeadedAttentionReturnWeight(
+ attention_heads, attention_dim, src_attention_dropout_rate
+ )
+ self.feed_forward = PositionwiseFeedForward(
+ attention_dim, linear_units, positional_dropout_rate
+ )
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.norm3 = LayerNorm(size)
@@ -319,9 +322,7 @@
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
- tgt_concat = torch.cat(
- (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
- )
+ tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
@@ -332,9 +333,7 @@
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
- x_concat = torch.cat(
- (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
- )
+ x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
x = residual + self.concat_linear2(x_concat)
else:
x, score = self.src_attn(x, memory, memory, memory_mask)
@@ -357,7 +356,15 @@
@tables.register("encoder_classes", "ConvBiasPredictor")
class ConvPredictor(nn.Module):
- def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048):
+ def __init__(
+ self,
+ size=256,
+ l_order=3,
+ r_order=3,
+ attention_heads=4,
+ attention_dropout_rate=0.1,
+ linear_units=2048,
+ ):
super().__init__()
self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
self.norm1 = LayerNorm(size)
@@ -367,17 +374,16 @@
self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
self.output_linear = nn.Linear(size, 1)
-
def forward(self, text_enc, asr_enc):
# stage1 cross-attention
residual = text_enc
text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)
-
+
# stage2 FFN
residual = text_enc
text_enc = self.norm1(text_enc)
text_enc = residual + self.feed_forward(text_enc)
-
+
# stage Conv predictor
text_enc = self.norm2(text_enc)
context = text_enc.transpose(1, 2)
@@ -387,6 +393,6 @@
output = output.transpose(1, 2)
output = torch.relu(output)
output = self.output_linear(output)
- if output.dim()==3:
- output = output.squeeze(2)
+ if output.dim() == 3:
+ output = output.squeeze(2)
return output
--
Gitblit v1.9.1