liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/sond/encoder/fsmn_encoder.py
@@ -18,16 +18,17 @@
 class FsmnBlock(torch.nn.Module):
     def __init__(
-            self,
-            n_feat,
-            dropout_rate,
-            kernel_size,
-            fsmn_shift=0,
+        self,
+        n_feat,
+        dropout_rate,
+        kernel_size,
+        fsmn_shift=0,
     ):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout_rate)
-        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1,
-                                    padding=0, groups=n_feat, bias=False)
+        self.fsmn_block = nn.Conv1d(
+            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
+        )
         # padding
         left_padding = (kernel_size - 1) // 2
         if fsmn_shift > 0:
@@ -53,14 +54,7 @@
 class EncoderLayer(torch.nn.Module):
-    def __init__(
-            self,
-            in_size,
-            size,
-            feed_forward,
-            fsmn_block,
-            dropout_rate=0.0
-    ):
+    def __init__(self, in_size, size, feed_forward, fsmn_block, dropout_rate=0.0):
         super().__init__()
         self.in_size = in_size
         self.size = size
@@ -69,9 +63,7 @@
         self.dropout = nn.Dropout(dropout_rate)

     def forward(
-            self,
-            xs_pad: torch.Tensor,
-            mask: torch.Tensor
+        self, xs_pad: torch.Tensor, mask: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # xs_pad in Batch, Time, Dim
@@ -86,24 +78,24 @@
 class FsmnEncoder(AbsEncoder):
-    """Encoder using Fsmn
-      """
+    """Encoder using Fsmn"""

-    def __init__(self,
-                 in_units,
-                 filter_size,
-                 fsmn_num_layers,
-                 dnn_num_layers,
-                 num_memory_units=512,
-                 ffn_inner_dim=2048,
-                 dropout_rate=0.0,
-                 shift=0,
-                 position_encoder=None,
-                 sample_rate=1,
-                 out_units=None,
-                 tf2torch_tensor_name_prefix_torch="post_net",
-                 tf2torch_tensor_name_prefix_tf="EAND/post_net"
-                 ):
+    def __init__(
+        self,
+        in_units,
+        filter_size,
+        fsmn_num_layers,
+        dnn_num_layers,
+        num_memory_units=512,
+        ffn_inner_dim=2048,
+        dropout_rate=0.0,
+        shift=0,
+        position_encoder=None,
+        sample_rate=1,
+        out_units=None,
+        tf2torch_tensor_name_prefix_torch="post_net",
+        tf2torch_tensor_name_prefix_tf="EAND/post_net",
+    ):
         """Initializes the parameters of the encoder.

         Args:
@@ -148,14 +140,9 @@
                     ffn_inner_dim,
                     num_memory_units,
                     1,
-                    dropout_rate
-                ),
-                FsmnBlock(
-                    num_memory_units,
-                    dropout_rate,
-                    filter_size,
-                    self.shift[lnum]
-                )
+                    dropout_rate,
+                ),
+                FsmnBlock(num_memory_units, dropout_rate, filter_size, self.shift[lnum]),
             ),
         )
@@ -167,7 +154,7 @@
                 num_memory_units,
                 1,
                 dropout_rate,
-            )
+            ),
         )

         if out_units is not None:
             self.conv1d = nn.Conv1d(num_memory_units, out_units, 1, 1)
@@ -176,10 +163,7 @@
         return self.num_memory_units

     def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None
+        self, xs_pad: torch.Tensor, ilens: torch.Tensor, prev_states: torch.Tensor = None
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         inputs = xs_pad
         if self.position_encoder is not None:
@@ -194,4 +178,3 @@
             inputs = self.conv1d(inputs.transpose(1, 2)).transpose(1, 2)

         return inputs, ilens, None
-
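
Note: this commit only reformats the code (black-style signatures and call sites); the public API is unchanged. For context, a minimal usage sketch of the encoder, assuming the import path from the file header above; the constructor values and tensor shapes are illustrative, not taken from the commit:

import torch
from funasr.models.sond.encoder.fsmn_encoder import FsmnEncoder

# Hypothetical example values; in_units must match the input feature dimension.
encoder = FsmnEncoder(
    in_units=80,
    filter_size=11,
    fsmn_num_layers=4,
    dnn_num_layers=2,
    num_memory_units=512,
    out_units=256,
)

xs_pad = torch.randn(2, 100, 80)        # (batch, time, feature)
ilens = torch.tensor([100, 80])         # valid frames per utterance
out, olens, _ = encoder(xs_pad, ilens)  # forward returns (output, lengths, None)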