From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat self_attention_encoder.py
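
Formatting-only cleanup of funasr/models/sond/encoder/self_attention_encoder.py,
with no behavior change: wrap the long torch.cat(...) call in
EncoderLayer.forward, restructure the conditional expression inside the
repeat(...) lambda into Black-style parenthesized form, add PEP 8 spacing
around the ** operator, and drop a trailing blank line.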
---
funasr/models/sond/encoder/self_attention_encoder.py | 57 +++++++++++++++++++++++++++++++--------------------------
 1 file changed, 31 insertions(+), 26 deletions(-)
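Note (ignored by git apply, which starts at the first "diff --git" line): the
reformatted hunk below builds the encoder stack with a conditional inside the
lambda passed to repeat(), so layer 0 maps input_size to output_size while
every later layer stays at output_size. A minimal runnable sketch of that
pattern, with torch.nn.Linear standing in for EncoderLayer, a simplified
repeat() (FunASR's real helper returns a MultiSequential), and example sizes
chosen only for illustration:

    import torch

    num_blocks, input_size, output_size = 4, 80, 256

    def repeat(n, fn):
        # Simplified stand-in for funasr's repeat(): build one module per
        # layer index by calling fn(lnum) for lnum in 0..n-1.
        return torch.nn.ModuleList([fn(i) for i in range(n)])

    # Layer 0 maps input_size -> output_size; every subsequent layer is
    # output_size -> output_size, mirroring the "if lnum > 0" branch in
    # the hunk below. torch.nn.Linear stands in for EncoderLayer here.
    encoders = repeat(
        num_blocks,
        lambda lnum: (
            torch.nn.Linear(output_size, output_size)
            if lnum > 0
            else torch.nn.Linear(input_size, output_size)
        ),
    )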
diff --git a/funasr/models/sond/encoder/self_attention_encoder.py b/funasr/models/sond/encoder/self_attention_encoder.py
index f3c4736..2e979b1 100644
--- a/funasr/models/sond/encoder/self_attention_encoder.py
+++ b/funasr/models/sond/encoder/self_attention_encoder.py
@@ -87,7 +87,9 @@
x = self.norm1(x)
if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+ x_concat = torch.cat(
+ (x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1
+ )
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
@@ -207,32 +209,36 @@
self.encoders = repeat(
num_blocks,
- lambda lnum: EncoderLayer(
- output_size,
- output_size,
- MultiHeadSelfAttention(
- attention_heads,
+ lambda lnum: (
+ EncoderLayer(
output_size,
output_size,
- attention_dropout_rate,
- ),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ) if lnum > 0 else EncoderLayer(
- input_size,
- output_size,
- MultiHeadSelfAttention(
- attention_heads,
- input_size if input_layer == "pe" or input_layer == "null" else output_size,
+ MultiHeadSelfAttention(
+ attention_heads,
+ output_size,
+ output_size,
+ attention_dropout_rate,
+ ),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ )
+ if lnum > 0
+ else EncoderLayer(
+ input_size,
output_size,
- attention_dropout_rate,
- ),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
+ MultiHeadSelfAttention(
+ attention_heads,
+ input_size if input_layer == "pe" or input_layer == "null" else output_size,
+ output_size,
+ attention_dropout_rate,
+ ),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ )
),
)
if self.normalize_before:
@@ -270,7 +276,7 @@
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- xs_pad = xs_pad * self.output_size()**0.5
+ xs_pad = xs_pad * self.output_size() ** 0.5
if self.embed is None:
xs_pad = xs_pad
elif (
@@ -325,4 +331,3 @@
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
-
--
Gitblit v1.9.1