From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/conformer/encoder.py |  173 +++++++++++++++++++++++++--------------------------------
 1 file changed, 77 insertions(+), 96 deletions(-)

diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py
index 443d309..7c939b4 100644
--- a/funasr/models/conformer/encoder.py
+++ b/funasr/models/conformer/encoder.py
@@ -49,6 +49,7 @@
 from funasr.register import tables
 import pdb
 
+
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
 
@@ -146,16 +147,16 @@
     """
 
     def __init__(
-            self,
-            size,
-            self_attn,
-            feed_forward,
-            feed_forward_macaron,
-            conv_module,
-            dropout_rate,
-            normalize_before=True,
-            concat_after=False,
-            stochastic_depth_rate=0.0,
+        self,
+        size,
+        self_attn,
+        feed_forward,
+        feed_forward_macaron,
+        conv_module,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
     ):
         """Construct an EncoderLayer object."""
         super(EncoderLayer, self).__init__()
@@ -266,9 +267,7 @@
         residual = x
         if self.normalize_before:
             x = self.norm_ff(x)
-        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
-            self.feed_forward(x)
-        )
+        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(self.feed_forward(x))
         if not self.normalize_before:
             x = self.norm_ff(x)
 
@@ -321,32 +320,32 @@
     """
 
     def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            input_layer: str = "conv2d",
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            positionwise_layer_type: str = "linear",
-            positionwise_conv_kernel_size: int = 3,
-            macaron_style: bool = False,
-            rel_pos_type: str = "legacy",
-            pos_enc_layer_type: str = "rel_pos",
-            selfattention_layer_type: str = "rel_selfattn",
-            activation_type: str = "swish",
-            use_cnn_module: bool = True,
-            zero_triu: bool = False,
-            cnn_module_kernel: int = 31,
-            padding_idx: int = -1,
-            interctc_layer_idx: List[int] = [],
-            interctc_use_conditioning: bool = False,
-            stochastic_depth_rate: Union[float, List[float]] = 0.0,
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: str = "conv2d",
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 3,
+        macaron_style: bool = False,
+        rel_pos_type: str = "legacy",
+        pos_enc_layer_type: str = "rel_pos",
+        selfattention_layer_type: str = "rel_selfattn",
+        activation_type: str = "swish",
+        use_cnn_module: bool = True,
+        zero_triu: bool = False,
+        cnn_module_kernel: int = 31,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        stochastic_depth_rate: Union[float, List[float]] = 0.0,
     ):
         super().__init__()
         self._output_size = output_size
@@ -373,9 +372,7 @@
         elif pos_enc_layer_type == "legacy_rel_pos":
             assert selfattention_layer_type == "legacy_rel_selfattn"
             pos_enc_class = LegacyRelPositionalEncoding
-            logging.warning(
-                "Using legacy_rel_pos and it will be deprecated in the future."
-            )
+            logging.warning("Using legacy_rel_pos and it will be deprecated in the future.")
         else:
             raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
 
@@ -432,9 +429,7 @@
                 pos_enc_class(output_size, positional_dropout_rate),
             )
         elif input_layer is None:
-            self.embed = torch.nn.Sequential(
-                pos_enc_class(output_size, positional_dropout_rate)
-            )
+            self.embed = torch.nn.Sequential(pos_enc_class(output_size, positional_dropout_rate))
         else:
             raise ValueError("unknown input_layer: " + input_layer)
         self.normalize_before = normalize_before
@@ -480,9 +475,7 @@
                 output_size,
                 attention_dropout_rate,
             )
-            logging.warning(
-                "Using legacy_rel_selfattn and it will be deprecated in the future."
-            )
+            logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
         elif selfattention_layer_type == "rel_selfattn":
             assert pos_enc_layer_type == "rel_pos"
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -534,11 +527,11 @@
         return self._output_size
 
     def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None,
-            ctc: CTC = None,
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Calculate forward propagation.
 
@@ -556,11 +549,11 @@
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
 
         if (
-                isinstance(self.embed, Conv2dSubsampling)
-                or isinstance(self.embed, Conv2dSubsampling2)
-                or isinstance(self.embed, Conv2dSubsampling6)
-                or isinstance(self.embed, Conv2dSubsampling8)
-                or isinstance(self.embed, Conv2dSubsamplingPad)
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling2)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+            or isinstance(self.embed, Conv2dSubsamplingPad)
         ):
             short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
             if short_status:
@@ -573,7 +566,7 @@
             xs_pad, masks = self.embed(xs_pad, masks)
         else:
             xs_pad = self.embed(xs_pad)
-        pdb.set_trace()
+
         intermediate_outs = []
         if len(self.interctc_layer_idx) == 0:
             xs_pad, masks = self.encoders(xs_pad, masks)
@@ -601,17 +594,17 @@
                             xs_pad = (x, pos_emb)
                         else:
                             xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-        pdb.set_trace()
+
         if isinstance(xs_pad, tuple):
             xs_pad = xs_pad[0]
         if self.normalize_before:
             xs_pad = self.after_norm(xs_pad)
-        pdb.set_trace()
+
         olens = masks.squeeze(1).sum(1)
         if len(intermediate_outs) > 0:
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
-    
+
 
 class CausalConvolution(torch.nn.Module):
     """ConformerConvolution module definition.
@@ -708,6 +701,7 @@
 
         return x, cache
 
+
 class ChunkEncoderLayer(torch.nn.Module):
     """Chunk Conformer module definition.
     Args:
@@ -797,9 +791,7 @@
         residual = x
 
         x = self.norm_macaron(x)
-        x = residual + self.feed_forward_scale * self.dropout(
-            self.feed_forward_macaron(x)
-        )
+        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward_macaron(x))
 
         residual = x
         x = self.norm_self_att(x)
@@ -876,9 +868,7 @@
 
         residual = x
         x = self.norm_conv(x)
-        x, conv_cache = self.conv_mod(
-            x, cache=self.cache[1], right_context=right_context
-        )
+        x, conv_cache = self.conv_mod(x, cache=self.cache[1], right_context=right_context)
         x = residual + x
         residual = x
 
@@ -889,6 +879,7 @@
         self.cache = [att_cache, conv_cache]
 
         return x, pos_enc
+
 
 @tables.register("encoder_classes", "ChunkConformerEncoder")
 class ConformerChunkEncoder(torch.nn.Module):
@@ -940,7 +931,6 @@
         """Construct an Encoder object."""
         super().__init__()
 
-
         self.embed = StreamingConvInput(
             input_size=input_size,
             conv_size=output_size,
@@ -954,9 +944,7 @@
             positional_dropout_rate,
         )
 
-        activation = get_activation(
-            activation_type
-       )        
+        activation = get_activation(activation_type)
 
         pos_wise_args = (
             output_size,
@@ -985,7 +973,6 @@
             simplified_att_score,
         )
 
-
         fn_modules = []
         for _ in range(num_blocks):
             module = lambda: ChunkEncoderLayer(
@@ -996,7 +983,7 @@
                 CausalConvolution(*conv_mod_args),
                 dropout_rate=dropout_rate,
             )
-            fn_modules.append(module)        
+            fn_modules.append(module)
 
         self.encoders = MultiBlocks(
             [fn() for fn in fn_modules],
@@ -1040,7 +1027,6 @@
         """
         return self.embed.get_size_before_subsampling(size)
 
-
     def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
         """Initialize/Reset encoder streaming cache.
         Args:
@@ -1062,9 +1048,7 @@
            x: Encoder outputs. (B, T_out, D_enc)
            x_len: Encoder outputs lenghts. (B,)
         """
-        short_status, limit_size = check_short_utt(
-            self.embed.subsampling_factor, x.size(1)
-        )
+        short_status, limit_size = check_short_utt(self.embed.subsampling_factor, x.size(1))
 
         if short_status:
             raise TooShortUttError(
@@ -1078,7 +1062,10 @@
 
         if self.unified_model_training:
             if self.training:
-                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+                chunk_size = (
+                    self.default_chunk_size
+                    + torch.randint(-self.jitter_range, self.jitter_range + 1, (1,)).item()
+                )
             else:
                 chunk_size = self.default_chunk_size
             x, mask = self.embed(x, mask, chunk_size)
@@ -1104,9 +1091,9 @@
 
             olens = mask.eq(0).sum(1)
             if self.time_reduction_factor > 1:
-                x_utt = x_utt[:,::self.time_reduction_factor,:]
-                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
-                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+                x_utt = x_utt[:, :: self.time_reduction_factor, :]
+                x_chunk = x_chunk[:, :: self.time_reduction_factor, :]
+                olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
 
             return x_utt, x_chunk, olens
 
@@ -1144,8 +1131,8 @@
 
         olens = mask.eq(0).sum(1)
         if self.time_reduction_factor > 1:
-            x = x[:,::self.time_reduction_factor,:]
-            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+            x = x[:, :: self.time_reduction_factor, :]
+            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
 
         return x, olens, None
 
@@ -1162,9 +1149,7 @@
            x: Encoder outputs. (B, T_out, D_enc)
            x_len: Encoder outputs lenghts. (B,)
         """
-        short_status, limit_size = check_short_utt(
-            self.embed.subsampling_factor, x.size(1)
-        )
+        short_status, limit_size = check_short_utt(self.embed.subsampling_factor, x.size(1))
 
         if short_status:
             raise TooShortUttError(
@@ -1185,7 +1170,7 @@
         )
 
         if self.time_reduction_factor > 1:
-            x_utt = x_utt[:,::self.time_reduction_factor,:]
+            x_utt = x_utt[:, :: self.time_reduction_factor, :]
         return x_utt
 
     def simu_chunk_forward(
@@ -1196,9 +1181,7 @@
         left_context: int = 32,
         right_context: int = 0,
     ) -> torch.Tensor:
-        short_status, limit_size = check_short_utt(
-            self.embed.subsampling_factor, x.size(1)
-        )
+        short_status, limit_size = check_short_utt(self.embed.subsampling_factor, x.size(1))
 
         if short_status:
             raise TooShortUttError(
@@ -1227,7 +1210,7 @@
         )
         olens = mask.eq(0).sum(1)
         if self.time_reduction_factor > 1:
-            x = x[:,::self.time_reduction_factor,:]
+            x = x[:, :: self.time_reduction_factor, :]
 
         return x
 
@@ -1255,9 +1238,7 @@
 
         if left_context > 0:
             processed_mask = (
-                torch.arange(left_context, device=x.device)
-                .view(1, left_context)
-                .flip(1)
+                torch.arange(left_context, device=x.device).view(1, left_context).flip(1)
             )
             processed_mask = processed_mask >= processed_frames
             mask = torch.cat([processed_mask, mask], dim=1)
@@ -1275,5 +1256,5 @@
             x = x[:, 0:-right_context, :]
 
         if self.time_reduction_factor > 1:
-            x = x[:,::self.time_reduction_factor,:]
+            x = x[:, :: self.time_reduction_factor, :]
         return x

--
Gitblit v1.9.1