From 7fe447185c80ca1290aa434c4dcaf0c8f0e1fa7b Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期五, 10 二月 2023 19:01:52 +0800
Subject: [PATCH] add sond model
---
funasr/modules/multi_layer_conv.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 52 insertions(+), 0 deletions(-)
diff --git a/funasr/modules/multi_layer_conv.py b/funasr/modules/multi_layer_conv.py
index 5fb0717..9d269ab 100644
--- a/funasr/modules/multi_layer_conv.py
+++ b/funasr/modules/multi_layer_conv.py
@@ -63,6 +63,58 @@
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
+class FsmnFeedForward(torch.nn.Module):
+ """Position-wise feed forward for FSMN blocks.
+
+ This is a module of multi-leyered conv1d designed
+ to replace position-wise feed-forward network
+ in FSMN block.
+ """
+
+ def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
+ """Initialize FsmnFeedForward module.
+
+ Args:
+ in_chans (int): Number of input channels.
+ hidden_chans (int): Number of hidden channels.
+ out_chans (int): Number of output channels.
+ kernel_size (int): Kernel size of conv1d.
+ dropout_rate (float): Dropout rate.
+
+ """
+ super(FsmnFeedForward, self).__init__()
+ self.w_1 = torch.nn.Conv1d(
+ in_chans,
+ hidden_chans,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.w_2 = torch.nn.Conv1d(
+ hidden_chans,
+ out_chans,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ bias=False
+ )
+ self.norm = torch.nn.LayerNorm(hidden_chans)
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ def forward(self, x, ilens=None):
+ """Calculate forward propagation.
+
+ Args:
+ x (torch.Tensor): Batch of input tensors (B, T, in_chans).
+
+ Returns:
+ torch.Tensor: Batch of output tensors (B, T, out_chans).
+
+ """
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+ return self.w_2(self.norm(self.dropout(x)).transpose(-1, 1)).transpose(-1, 1), ilens
+
+
class Conv1dLinear(torch.nn.Module):
"""Conv1D + Linear for Transformer block.
--
Gitblit v1.9.1