From 668b830cb2a8f69c1cfb131ec9542d27f91b7283 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Wed, 10 Jan 2024 19:10:26 +0800
Subject: [PATCH] update cam++ for embed extract

---
 funasr/models/campplus/components.py |  112 ++++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 76 insertions(+), 36 deletions(-)

diff --git a/funasr/models/campplus/layers.py b/funasr/models/campplus/components.py
similarity index 86%
rename from funasr/models/campplus/layers.py
rename to funasr/models/campplus/components.py
index 0475612..43d366e 100644
--- a/funasr/models/campplus/layers.py
+++ b/funasr/models/campplus/components.py
@@ -7,6 +7,82 @@
 from torch import nn
 
 
class BasicResBlock(nn.Module):
    """2D residual block that downsamples the frequency axis only.

    Both 3x3 convolutions keep the time axis intact: the first conv uses a
    (stride, 1) stride, so striding applies to the frequency dimension
    alone.  When the identity path's shape would not match the residual
    path (stride != 1 or a channel change), a 1x1 conv + BatchNorm
    projection shortcut is used instead of the identity.
    """

    # Channel expansion factor of the block (1 for this basic variant).
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicResBlock, self).__init__()
        # Residual path: conv-bn-relu-conv-bn.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=(stride, 1),
            padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut by default; projection when shapes differ.
        self.shortcut = nn.Sequential()
        needs_projection = stride != 1 or in_planes != self.expansion * planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=(stride, 1), bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        """Return relu(residual(x) + shortcut(x))."""
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return F.relu(residual + self.shortcut(x))
+
+
class FCM(nn.Module):
    """Front-end Convolution Module: a small 2D-ResNet stem for CAM++.

    Treats the (feature, time) input as a single-channel image, applies two
    residual stages plus a final strided conv that together downsample the
    feature axis by 8x while preserving the time axis, then folds channels
    and the reduced feature axis into one dimension so the output is again
    a 1D feature sequence.
    """

    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=[2, 2],
                 m_channels=32,
                 feat_dim=80):
        """
        Args:
            block: residual block class used for both stages.
            num_blocks: blocks per stage — num_blocks[0] for layer1,
                num_blocks[1] for layer2 (read-only, never mutated).
            m_channels: channel width used throughout the module.
            feat_dim: input feature dimension; should be divisible by 8
                for out_channels to match the reshape in forward().
        """
        super(FCM, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
        # Fix: stage 2 previously reused num_blocks[0], silently ignoring
        # the second entry of num_blocks (identical for the default [2, 2]).
        self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)

        # Final conv halves the feature axis once more (stride (2, 1)),
        # for a total 8x feature downsampling with the two stages above.
        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage downsamples; the rest keep
        # the resolution (stride 1).
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map (batch, feat_dim, time) to (batch, out_channels, time)."""
        x = x.unsqueeze(1)  # add channel axis: (B, 1, F, T)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))

        # Fold channels and the reduced feature axis into one dimension.
        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out
+
+
 def get_nonlinear(config_str, channels):
     nonlinear = nn.Sequential()
     for name in config_str.split('-'):
@@ -216,39 +292,3 @@
         return x
 
 
-class BasicResBlock(nn.Module):
-    expansion = 1
-
-    def __init__(self, in_planes, planes, stride=1):
-        super(BasicResBlock, self).__init__()
-        self.conv1 = nn.Conv2d(in_planes,
-                               planes,
-                               kernel_size=3,
-                               stride=(stride, 1),
-                               padding=1,
-                               bias=False)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = nn.Conv2d(planes,
-                               planes,
-                               kernel_size=3,
-                               stride=1,
-                               padding=1,
-                               bias=False)
-        self.bn2 = nn.BatchNorm2d(planes)
-
-        self.shortcut = nn.Sequential()
-        if stride != 1 or in_planes != self.expansion * planes:
-            self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes,
-                          self.expansion * planes,
-                          kernel_size=1,
-                          stride=(stride, 1),
-                          bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
-
-    def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = self.bn2(self.conv2(out))
-        out += self.shortcut(x)
-        out = F.relu(out)
-        return out

--
Gitblit v1.9.1