From 6f88a689af05dee063d52fc65806cb0addbb764b Mon Sep 17 00:00:00 2001
From: 夜雨飘零 <yeyupiaoling@foxmail.com>
Date: 星期三, 06 十二月 2023 17:02:51 +0800
Subject: [PATCH] remove never use code (#1151)

---
 funasr/utils/speaker_utils.py |   48 ------------------------------------------------
 1 files changed, 0 insertions(+), 48 deletions(-)

diff --git a/funasr/utils/speaker_utils.py b/funasr/utils/speaker_utils.py
index a3eebf9..b769b85 100644
--- a/funasr/utils/speaker_utils.py
+++ b/funasr/utils/speaker_utils.py
@@ -108,54 +108,6 @@
     return features
 
 
-class CAMLayer(nn.Module):
-
-    def __init__(self,
-                 bn_channels,
-                 out_channels,
-                 kernel_size,
-                 stride,
-                 padding,
-                 dilation,
-                 bias,
-                 reduction=2):
-        super(CAMLayer, self).__init__()
-        self.linear_local = nn.Conv1d(
-            bn_channels,
-            out_channels,
-            kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias)
-        self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
-        self.relu = nn.ReLU(inplace=True)
-        self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, x):
-        y = self.linear_local(x)
-        context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
-        context = self.relu(self.linear1(context))
-        m = self.sigmoid(self.linear2(context))
-        return y * m
-
-    def seg_pooling(self, x, seg_len=100, stype='avg'):
-        if stype == 'avg':
-            seg = F.avg_pool1d(
-                x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
-        elif stype == 'max':
-            seg = F.max_pool1d(
-                x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
-        else:
-            raise ValueError('Wrong segment pooling type.')
-        shape = seg.shape
-        seg = seg.unsqueeze(-1).expand(*shape,
-                                       seg_len).reshape(*shape[:-1], -1)
-        seg = seg[..., :x.shape[-1]]
-        return seg
-
-
 def postprocess(segments: list, vad_segments: list,
                 labels: np.ndarray, embeddings: np.ndarray) -> list:
     assert len(segments) == len(labels)

--
Gitblit v1.9.1