From 668b830cb2a8f69c1cfb131ec9542d27f91b7283 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期三, 10 一月 2024 19:10:26 +0800
Subject: [PATCH] update cam++ for embed extract

---
 funasr/models/campplus/model.py |   88 +++++++++++++++++++++-----------------------
 1 files changed, 42 insertions(+), 46 deletions(-)

diff --git a/funasr/models/campplus/campplus.py b/funasr/models/campplus/model.py
similarity index 64%
rename from funasr/models/campplus/campplus.py
rename to funasr/models/campplus/model.py
index 88113ec..84938cc 100644
--- a/funasr/models/campplus/campplus.py
+++ b/funasr/models/campplus/model.py
@@ -1,54 +1,24 @@
 # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
 # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 
+import os
+import time
+import torch
+import logging
+import numpy as np
+import torch.nn as nn
 from collections import OrderedDict
+from typing import Union, Dict, List, Tuple, Optional
 
-import torch.nn.functional as F
-from torch import nn
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+from funasr.models.campplus.components import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
+    BasicResBlock, get_nonlinear, FCM
+from funasr.models.campplus.utils import extract_feature
 
 
-from funasr.models.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
-    BasicResBlock, get_nonlinear
-
-
-class FCM(nn.Module):
-    def __init__(self,
-                 block=BasicResBlock,
-                 num_blocks=[2, 2],
-                 m_channels=32,
-                 feat_dim=80):
-        super(FCM, self).__init__()
-        self.in_planes = m_channels
-        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(m_channels)
-
-        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
-        self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
-
-        self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(m_channels)
-        self.out_channels = m_channels * (feat_dim // 8)
-
-    def _make_layer(self, block, planes, num_blocks, stride):
-        strides = [stride] + [1] * (num_blocks - 1)
-        layers = []
-        for stride in strides:
-            layers.append(block(self.in_planes, planes, stride))
-            self.in_planes = planes * block.expansion
-        return nn.Sequential(*layers)
-
-    def forward(self, x):
-        x = x.unsqueeze(1)
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = self.layer1(out)
-        out = self.layer2(out)
-        out = F.relu(self.bn2(self.conv2(out)))
-
-        shape = out.shape
-        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
-        return out
-
-
+@tables.register("model_classes", "CAMPPlus")
 class CAMPPlus(nn.Module):
     def __init__(self,
                  feat_dim=80,
@@ -58,8 +28,9 @@
                  init_channels=128,
                  config_str='batchnorm-relu',
                  memory_efficient=True,
-                 output_level='segment'):
-        super(CAMPPlus, self).__init__()
+                 output_level='segment',
+                 **kwargs,):
+        super().__init__()
 
         self.head = FCM(feat_dim=feat_dim)
         channels = self.head.out_channels
@@ -123,3 +94,28 @@
         if self.output_level == 'frame':
             x = x.transpose(1, 2)
         return x
+
+    def generate(self,
+                 data_in,
+                 data_lengths=None,
+                 key: list=None,
+                 tokenizer=None,
+                 frontend=None,
+                 **kwargs,
+                 ):
+        # extract fbank feats
+        meta_data = {}
+        time1 = time.perf_counter()
+        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        speech, speech_lengths = extract_feature(audio_sample_list)
+        time3 = time.perf_counter()
+        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+        meta_data["batch_data_time"] = np.array(speech_lengths).sum().item() / 16000.0
+        # import pdb; pdb.set_trace()
+        results = []
+        embeddings = self.forward(speech)
+        for embedding in embeddings:
+            results.append({"spk_embedding":embedding})
+        return results, meta_data
\ No newline at end of file

--
Gitblit v1.9.1