From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/eres2net/eres2net.py |  167 +++++++++++++++++++++----------------------------------
 1 file changed, 65 insertions(+), 102 deletions(-)

diff --git a/funasr/models/eres2net/eres2net.py b/funasr/models/eres2net/eres2net.py
index 3ea9fdf..de6c2bb 100644
--- a/funasr/models/eres2net/eres2net.py
+++ b/funasr/models/eres2net/eres2net.py
@@ -26,21 +26,18 @@
         super(ReLU, self).__init__(0, 20, inplace)
 
     def __repr__(self):
-        inplace_str = 'inplace' if self.inplace else ''
-        return self.__class__.__name__ + ' (' \
-            + inplace_str + ')'
+        inplace_str = "inplace" if self.inplace else ""
+        return self.__class__.__name__ + " (" + inplace_str + ")"
 
 
 def conv1x1(in_planes, out_planes, stride=1):
     "1x1 convolution without padding"
-    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
-                     padding=0, bias=False)
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
 
 
 def conv3x3(in_planes, out_planes, stride=1):
     "3x3 convolution with padding"
-    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
-                     padding=1, bias=False)
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
 
 
 class BasicBlockERes2Net(nn.Module):
@@ -67,12 +64,11 @@
         self.shortcut = nn.Sequential()
         if stride != 1 or in_planes != self.expansion * planes:
             self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes,
-                          self.expansion * planes,
-                          kernel_size=1,
-                          stride=stride,
-                          bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
+                nn.Conv2d(
+                    in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False
+                ),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
         self.stride = stride
         self.width = width
         self.scale = scale
@@ -135,12 +131,11 @@
         self.shortcut = nn.Sequential()
         if stride != 1 or in_planes != self.expansion * planes:
             self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes,
-                          self.expansion * planes,
-                          kernel_size=1,
-                          stride=stride,
-                          bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
+                nn.Conv2d(
+                    in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False
+                ),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
         self.stride = stride
         self.width = width
         self.scale = scale
@@ -176,15 +171,17 @@
 
 
 class ERes2Net(nn.Module):
-    def __init__(self,
-                 block=BasicBlockERes2Net,
-                 block_fuse=BasicBlockERes2Net_diff_AFF,
-                 num_blocks=[3, 4, 6, 3],
-                 m_channels=32,
-                 feat_dim=80,
-                 embedding_size=192,
-                 pooling_func='TSTP',
-                 two_emb_layer=False):
+    def __init__(
+        self,
+        block=BasicBlockERes2Net,
+        block_fuse=BasicBlockERes2Net_diff_AFF,
+        num_blocks=[3, 4, 6, 3],
+        m_channels=32,
+        feat_dim=80,
+        embedding_size=192,
+        pooling_func="TSTP",
+        two_emb_layer=False,
+    ):
         super(ERes2Net, self).__init__()
         self.in_planes = m_channels
         self.feat_dim = feat_dim
@@ -192,48 +189,32 @@
         self.stats_dim = int(feat_dim / 8) * m_channels * 8
         self.two_emb_layer = two_emb_layer
 
-        self.conv1 = nn.Conv2d(1,
-                               m_channels,
-                               kernel_size=3,
-                               stride=1,
-                               padding=1,
-                               bias=False)
+        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn1 = nn.BatchNorm2d(m_channels)
-        self.layer1 = self._make_layer(block,
-                                       m_channels,
-                                       num_blocks[0],
-                                       stride=1)
-        self.layer2 = self._make_layer(block,
-                                       m_channels * 2,
-                                       num_blocks[1],
-                                       stride=2)
-        self.layer3 = self._make_layer(block_fuse,
-                                       m_channels * 4,
-                                       num_blocks[2],
-                                       stride=2)
-        self.layer4 = self._make_layer(block_fuse,
-                                       m_channels * 8,
-                                       num_blocks[3],
-                                       stride=2)
+        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
 
         # Downsampling module for each layer
-        self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1,
-                                           bias=False)
-        self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
-                                           bias=False)
-        self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
-                                           bias=False)
+        self.layer1_downsample = nn.Conv2d(
+            m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
+        )
+        self.layer2_downsample = nn.Conv2d(
+            m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
+        )
+        self.layer3_downsample = nn.Conv2d(
+            m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
+        )
 
         # Bottom-up fusion module
         self.fuse_mode12 = AFF(channels=m_channels * 4)
         self.fuse_mode123 = AFF(channels=m_channels * 8)
         self.fuse_mode1234 = AFF(channels=m_channels * 16)
 
-        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
-        self.pool = getattr(pooling_layers, pooling_func)(
-            in_dim=self.stats_dim * block.expansion)
-        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
-                               embedding_size)
+        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
+        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
         if self.two_emb_layer:
             self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
             self.seg_2 = nn.Linear(embedding_size, embedding_size)
@@ -298,12 +279,11 @@
         self.shortcut = nn.Sequential()
         if stride != 1 or in_planes != self.expansion * planes:
             self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes,
-                          self.expansion * planes,
-                          kernel_size=1,
-                          stride=stride,
-                          bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
+                nn.Conv2d(
+                    in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False
+                ),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
         self.stride = stride
         self.width = width
         self.scale = scale
@@ -340,14 +320,16 @@
 
 
 class Res2Net(nn.Module):
-    def __init__(self,
-                 block=BasicBlockRes2Net,
-                 num_blocks=[3, 4, 6, 3],
-                 m_channels=32,
-                 feat_dim=80,
-                 embedding_size=192,
-                 pooling_func='TSTP',
-                 two_emb_layer=False):
+    def __init__(
+        self,
+        block=BasicBlockRes2Net,
+        num_blocks=[3, 4, 6, 3],
+        m_channels=32,
+        feat_dim=80,
+        embedding_size=192,
+        pooling_func="TSTP",
+        two_emb_layer=False,
+    ):
         super(Res2Net, self).__init__()
         self.in_planes = m_channels
         self.feat_dim = feat_dim
@@ -355,35 +337,16 @@
         self.stats_dim = int(feat_dim / 8) * m_channels * 8
         self.two_emb_layer = two_emb_layer
 
-        self.conv1 = nn.Conv2d(1,
-                               m_channels,
-                               kernel_size=3,
-                               stride=1,
-                               padding=1,
-                               bias=False)
+        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn1 = nn.BatchNorm2d(m_channels)
-        self.layer1 = self._make_layer(block,
-                                       m_channels,
-                                       num_blocks[0],
-                                       stride=1)
-        self.layer2 = self._make_layer(block,
-                                       m_channels * 2,
-                                       num_blocks[1],
-                                       stride=2)
-        self.layer3 = self._make_layer(block,
-                                       m_channels * 4,
-                                       num_blocks[2],
-                                       stride=2)
-        self.layer4 = self._make_layer(block,
-                                       m_channels * 8,
-                                       num_blocks[3],
-                                       stride=2)
+        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, m_channels * 4, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, m_channels * 8, num_blocks[3], stride=2)
 
-        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
-        self.pool = getattr(pooling_layers, pooling_func)(
-            in_dim=self.stats_dim * block.expansion)
-        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
-                               embedding_size)
+        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
+        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
         if self.two_emb_layer:
             self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
             self.seg_2 = nn.Linear(embedding_size, embedding_size)

--
Gitblit v1.9.1