From 1cdb3cc28d4d89a576cc06e5cd8eb80da1f3a3aa Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 26 四月 2024 11:27:39 +0800
Subject: [PATCH] Dev gzf exp (#1665)

---
 funasr/models/transformer/encoder.py |   81 +++++++++++++++++++---------------------
 1 files changed, 39 insertions(+), 42 deletions(-)

diff --git a/funasr/models/transformer/encoder.py b/funasr/models/transformer/encoder.py
index 1f14867..a6a85ae 100644
--- a/funasr/models/transformer/encoder.py
+++ b/funasr/models/transformer/encoder.py
@@ -30,6 +30,7 @@
 
 from funasr.register import tables
 
+
 class EncoderLayer(nn.Module):
     """Encoder layer module.
 
@@ -53,14 +54,14 @@
     """
 
     def __init__(
-            self,
-            size,
-            self_attn,
-            feed_forward,
-            dropout_rate,
-            normalize_before=True,
-            concat_after=False,
-            stochastic_depth_rate=0.0,
+        self,
+        size,
+        self_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
     ):
         """Construct an EncoderLayer object."""
         super(EncoderLayer, self).__init__()
@@ -118,9 +119,7 @@
             x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
             x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
         else:
-            x = residual + stoch_layer_coeff * self.dropout(
-                self.self_attn(x_q, x, x, mask)
-            )
+            x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask))
         if not self.normalize_before:
             x = self.norm1(x)
 
@@ -135,6 +134,7 @@
             x = torch.cat([cache, x], dim=1)
 
         return x, mask
+
 
 @tables.register("encoder_classes", "TransformerEncoder")
 class TransformerEncoder(nn.Module):
@@ -163,24 +163,24 @@
     """
 
     def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            input_layer: Optional[str] = "conv2d",
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            positionwise_layer_type: str = "linear",
-            positionwise_conv_kernel_size: int = 1,
-            padding_idx: int = -1,
-            interctc_layer_idx: List[int] = [],
-            interctc_use_conditioning: bool = False,
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
     ):
         super().__init__()
         self._output_size = output_size
@@ -243,9 +243,7 @@
             num_blocks,
             lambda lnum: EncoderLayer(
                 output_size,
-                MultiHeadedAttention(
-                    attention_heads, output_size, attention_dropout_rate
-                ),
+                MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
                 positionwise_layer(*positionwise_layer_args),
                 dropout_rate,
                 normalize_before,
@@ -265,11 +263,11 @@
         return self._output_size
 
     def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None,
-            ctc: CTC = None,
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Embed positions in tensor.
 
@@ -285,10 +283,10 @@
         if self.embed is None:
             xs_pad = xs_pad
         elif (
-                isinstance(self.embed, Conv2dSubsampling)
-                or isinstance(self.embed, Conv2dSubsampling2)
-                or isinstance(self.embed, Conv2dSubsampling6)
-                or isinstance(self.embed, Conv2dSubsampling8)
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling2)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
         ):
             short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
             if short_status:
@@ -329,4 +327,3 @@
         if len(intermediate_outs) > 0:
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
-

--
Gitblit v1.9.1