From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Remove pdb.set_trace() debug traps and reformat conformer encoder
---
funasr/models/conformer/encoder.py | 173 +++++++++++++++++++++++++--------------------------------
 1 file changed, 77 insertions(+), 96 deletions(-)
diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py
index 443d309..7c939b4 100644
--- a/funasr/models/conformer/encoder.py
+++ b/funasr/models/conformer/encoder.py
@@ -49,6 +49,7 @@
from funasr.register import tables
import pdb
+
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
@@ -146,16 +147,16 @@
"""
def __init__(
- self,
- size,
- self_attn,
- feed_forward,
- feed_forward_macaron,
- conv_module,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- stochastic_depth_rate=0.0,
+ self,
+ size,
+ self_attn,
+ feed_forward,
+ feed_forward_macaron,
+ conv_module,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
@@ -266,9 +267,7 @@
residual = x
if self.normalize_before:
x = self.norm_ff(x)
- x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
- self.feed_forward(x)
- )
+ x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
@@ -321,32 +320,32 @@
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: str = "conv2d",
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 3,
- macaron_style: bool = False,
- rel_pos_type: str = "legacy",
- pos_enc_layer_type: str = "rel_pos",
- selfattention_layer_type: str = "rel_selfattn",
- activation_type: str = "swish",
- use_cnn_module: bool = True,
- zero_triu: bool = False,
- cnn_module_kernel: int = 31,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- stochastic_depth_rate: Union[float, List[float]] = 0.0,
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ cnn_module_kernel: int = 31,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ stochastic_depth_rate: Union[float, List[float]] = 0.0,
):
super().__init__()
self._output_size = output_size
@@ -373,9 +372,7 @@
elif pos_enc_layer_type == "legacy_rel_pos":
assert selfattention_layer_type == "legacy_rel_selfattn"
pos_enc_class = LegacyRelPositionalEncoding
- logging.warning(
- "Using legacy_rel_pos and it will be deprecated in the future."
- )
+ logging.warning("Using legacy_rel_pos and it will be deprecated in the future.")
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@@ -432,9 +429,7 @@
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
- self.embed = torch.nn.Sequential(
- pos_enc_class(output_size, positional_dropout_rate)
- )
+ self.embed = torch.nn.Sequential(pos_enc_class(output_size, positional_dropout_rate))
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
@@ -480,9 +475,7 @@
output_size,
attention_dropout_rate,
)
- logging.warning(
- "Using legacy_rel_selfattn and it will be deprecated in the future."
- )
+ logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
elif selfattention_layer_type == "rel_selfattn":
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -534,11 +527,11 @@
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
@@ -556,11 +549,11 @@
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
- or isinstance(self.embed, Conv2dSubsamplingPad)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
+ or isinstance(self.embed, Conv2dSubsamplingPad)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -573,7 +566,7 @@
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
- pdb.set_trace()
+
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
xs_pad, masks = self.encoders(xs_pad, masks)
@@ -601,17 +594,17 @@
xs_pad = (x, pos_emb)
else:
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
- pdb.set_trace()
+
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
- pdb.set_trace()
+
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
-
+
class CausalConvolution(torch.nn.Module):
"""ConformerConvolution module definition.
@@ -708,6 +701,7 @@
return x, cache
+
class ChunkEncoderLayer(torch.nn.Module):
"""Chunk Conformer module definition.
Args:
@@ -797,9 +791,7 @@
residual = x
x = self.norm_macaron(x)
- x = residual + self.feed_forward_scale * self.dropout(
- self.feed_forward_macaron(x)
- )
+ x = residual + self.feed_forward_scale * self.dropout(self.feed_forward_macaron(x))
residual = x
x = self.norm_self_att(x)
@@ -876,9 +868,7 @@
residual = x
x = self.norm_conv(x)
- x, conv_cache = self.conv_mod(
- x, cache=self.cache[1], right_context=right_context
- )
+ x, conv_cache = self.conv_mod(x, cache=self.cache[1], right_context=right_context)
x = residual + x
residual = x
@@ -889,6 +879,7 @@
self.cache = [att_cache, conv_cache]
return x, pos_enc
+
@tables.register("encoder_classes", "ChunkConformerEncoder")
class ConformerChunkEncoder(torch.nn.Module):
@@ -940,7 +931,6 @@
"""Construct an Encoder object."""
super().__init__()
-
self.embed = StreamingConvInput(
input_size=input_size,
conv_size=output_size,
@@ -954,9 +944,7 @@
positional_dropout_rate,
)
- activation = get_activation(
- activation_type
- )
+ activation = get_activation(activation_type)
pos_wise_args = (
output_size,
@@ -985,7 +973,6 @@
simplified_att_score,
)
-
fn_modules = []
for _ in range(num_blocks):
module = lambda: ChunkEncoderLayer(
@@ -996,7 +983,7 @@
CausalConvolution(*conv_mod_args),
dropout_rate=dropout_rate,
)
- fn_modules.append(module)
+ fn_modules.append(module)
self.encoders = MultiBlocks(
[fn() for fn in fn_modules],
@@ -1040,7 +1027,6 @@
"""
return self.embed.get_size_before_subsampling(size)
-
def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
"""Initialize/Reset encoder streaming cache.
Args:
@@ -1062,9 +1048,7 @@
x: Encoder outputs. (B, T_out, D_enc)
x_len: Encoder outputs lenghts. (B,)
"""
- short_status, limit_size = check_short_utt(
- self.embed.subsampling_factor, x.size(1)
- )
+ short_status, limit_size = check_short_utt(self.embed.subsampling_factor, x.size(1))
if short_status:
raise TooShortUttError(
@@ -1078,7 +1062,10 @@
if self.unified_model_training:
if self.training:
- chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ chunk_size = (
+ self.default_chunk_size
+ + torch.randint(-self.jitter_range, self.jitter_range + 1, (1,)).item()
+ )
else:
chunk_size = self.default_chunk_size
x, mask = self.embed(x, mask, chunk_size)
@@ -1104,9 +1091,9 @@
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
- x_utt = x_utt[:,::self.time_reduction_factor,:]
- x_chunk = x_chunk[:,::self.time_reduction_factor,:]
- olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+ x_utt = x_utt[:, :: self.time_reduction_factor, :]
+ x_chunk = x_chunk[:, :: self.time_reduction_factor, :]
+ olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
return x_utt, x_chunk, olens
@@ -1144,8 +1131,8 @@
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
- x = x[:,::self.time_reduction_factor,:]
- olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+ x = x[:, :: self.time_reduction_factor, :]
+ olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
return x, olens, None
@@ -1162,9 +1149,7 @@
x: Encoder outputs. (B, T_out, D_enc)
x_len: Encoder outputs lenghts. (B,)
"""
- short_status, limit_size = check_short_utt(
- self.embed.subsampling_factor, x.size(1)
- )
+ short_status, limit_size = check_short_utt(self.embed.subsampling_factor, x.size(1))
if short_status:
raise TooShortUttError(
@@ -1185,7 +1170,7 @@
)
if self.time_reduction_factor > 1:
- x_utt = x_utt[:,::self.time_reduction_factor,:]
+ x_utt = x_utt[:, :: self.time_reduction_factor, :]
return x_utt
def simu_chunk_forward(
@@ -1196,9 +1181,7 @@
left_context: int = 32,
right_context: int = 0,
) -> torch.Tensor:
- short_status, limit_size = check_short_utt(
- self.embed.subsampling_factor, x.size(1)
- )
+ short_status, limit_size = check_short_utt(self.embed.subsampling_factor, x.size(1))
if short_status:
raise TooShortUttError(
@@ -1227,7 +1210,7 @@
)
olens = mask.eq(0).sum(1)
if self.time_reduction_factor > 1:
- x = x[:,::self.time_reduction_factor,:]
+ x = x[:, :: self.time_reduction_factor, :]
return x
@@ -1255,9 +1238,7 @@
if left_context > 0:
processed_mask = (
- torch.arange(left_context, device=x.device)
- .view(1, left_context)
- .flip(1)
+ torch.arange(left_context, device=x.device).view(1, left_context).flip(1)
)
processed_mask = processed_mask >= processed_frames
mask = torch.cat([processed_mask, mask], dim=1)
@@ -1275,5 +1256,5 @@
x = x[:, 0:-right_context, :]
if self.time_reduction_factor > 1:
- x = x[:,::self.time_reduction_factor,:]
+ x = x[:, :: self.time_reduction_factor, :]
return x
--
Gitblit v1.9.1