From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/sond/encoder/fsmn_encoder.py |   79 +++++++++++++++------------------------
 1 files changed, 31 insertions(+), 48 deletions(-)

diff --git a/funasr/models/sond/encoder/fsmn_encoder.py b/funasr/models/sond/encoder/fsmn_encoder.py
index fb87ee8..9ec9912 100644
--- a/funasr/models/sond/encoder/fsmn_encoder.py
+++ b/funasr/models/sond/encoder/fsmn_encoder.py
@@ -18,16 +18,17 @@
 
 class FsmnBlock(torch.nn.Module):
     def __init__(
-            self,
-            n_feat,
-            dropout_rate,
-            kernel_size,
-            fsmn_shift=0,
+        self,
+        n_feat,
+        dropout_rate,
+        kernel_size,
+        fsmn_shift=0,
     ):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout_rate)
-        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1,
-                                    padding=0, groups=n_feat, bias=False)
+        self.fsmn_block = nn.Conv1d(
+            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
+        )
         # padding
         left_padding = (kernel_size - 1) // 2
         if fsmn_shift > 0:
@@ -53,14 +54,7 @@
 
 
 class EncoderLayer(torch.nn.Module):
-    def __init__(
-            self,
-            in_size,
-            size,
-            feed_forward,
-            fsmn_block,
-            dropout_rate=0.0
-    ):
+    def __init__(self, in_size, size, feed_forward, fsmn_block, dropout_rate=0.0):
         super().__init__()
         self.in_size = in_size
         self.size = size
@@ -69,9 +63,7 @@
         self.dropout = nn.Dropout(dropout_rate)
 
     def forward(
-            self,
-            xs_pad: torch.Tensor,
-            mask: torch.Tensor
+        self, xs_pad: torch.Tensor, mask: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # xs_pad in Batch, Time, Dim
 
@@ -86,24 +78,24 @@
 
 
 class FsmnEncoder(AbsEncoder):
-    """Encoder using Fsmn
-      """
+    """Encoder using Fsmn"""
 
-    def __init__(self,
-                 in_units,
-                 filter_size,
-                 fsmn_num_layers,
-                 dnn_num_layers,
-                 num_memory_units=512,
-                 ffn_inner_dim=2048,
-                 dropout_rate=0.0,
-                 shift=0,
-                 position_encoder=None,
-                 sample_rate=1,
-                 out_units=None,
-                 tf2torch_tensor_name_prefix_torch="post_net",
-                 tf2torch_tensor_name_prefix_tf="EAND/post_net"
-                 ):
+    def __init__(
+        self,
+        in_units,
+        filter_size,
+        fsmn_num_layers,
+        dnn_num_layers,
+        num_memory_units=512,
+        ffn_inner_dim=2048,
+        dropout_rate=0.0,
+        shift=0,
+        position_encoder=None,
+        sample_rate=1,
+        out_units=None,
+        tf2torch_tensor_name_prefix_torch="post_net",
+        tf2torch_tensor_name_prefix_tf="EAND/post_net",
+    ):
         """Initializes the parameters of the encoder.
 
         Args:
@@ -148,14 +140,9 @@
                     ffn_inner_dim,
                     num_memory_units,
                     1,
-                    dropout_rate
-                ),
-                FsmnBlock(
-                    num_memory_units,
                     dropout_rate,
-                    filter_size,
-                    self.shift[lnum]
-                )
+                ),
+                FsmnBlock(num_memory_units, dropout_rate, filter_size, self.shift[lnum]),
             ),
         )
 
@@ -167,7 +154,7 @@
                 num_memory_units,
                 1,
                 dropout_rate,
-            )
+            ),
         )
         if out_units is not None:
             self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1)
@@ -176,10 +163,7 @@
         return self.num_memory_units
 
     def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None
+        self, xs_pad: torch.Tensor, ilens: torch.Tensor, prev_states: torch.Tensor = None
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         inputs = xs_pad
         if self.position_encoder is not None:
@@ -194,4 +178,3 @@
             inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2)
 
         return inputs, ilens, None
-

--
Gitblit v1.9.1