From 668b830cb2a8f69c1cfb131ec9542d27f91b7283 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期三, 10 一月 2024 19:10:26 +0800
Subject: [PATCH] update cam++ for embed extract
---
funasr/models/campplus/model.py | 88 +++++++++++++++++++++-----------------------
1 files changed, 42 insertions(+), 46 deletions(-)
diff --git a/funasr/models/campplus/campplus.py b/funasr/models/campplus/model.py
similarity index 64%
rename from funasr/models/campplus/campplus.py
rename to funasr/models/campplus/model.py
index 88113ec..84938cc 100644
--- a/funasr/models/campplus/campplus.py
+++ b/funasr/models/campplus/model.py
@@ -1,54 +1,24 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import os
+import time
+import torch
+import logging
+import numpy as np
+import torch.nn as nn
from collections import OrderedDict
+from typing import Union, Dict, List, Tuple, Optional
-import torch.nn.functional as F
-from torch import nn
+from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+from funasr.models.campplus.components import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
+ BasicResBlock, get_nonlinear, FCM
+from funasr.models.campplus.utils import extract_feature
-from funasr.models.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
- BasicResBlock, get_nonlinear
-
-
-class FCM(nn.Module):
- def __init__(self,
- block=BasicResBlock,
- num_blocks=[2, 2],
- m_channels=32,
- feat_dim=80):
- super(FCM, self).__init__()
- self.in_planes = m_channels
- self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
- self.bn1 = nn.BatchNorm2d(m_channels)
-
- self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
- self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
-
- self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
- self.bn2 = nn.BatchNorm2d(m_channels)
- self.out_channels = m_channels * (feat_dim // 8)
-
- def _make_layer(self, block, planes, num_blocks, stride):
- strides = [stride] + [1] * (num_blocks - 1)
- layers = []
- for stride in strides:
- layers.append(block(self.in_planes, planes, stride))
- self.in_planes = planes * block.expansion
- return nn.Sequential(*layers)
-
- def forward(self, x):
- x = x.unsqueeze(1)
- out = F.relu(self.bn1(self.conv1(x)))
- out = self.layer1(out)
- out = self.layer2(out)
- out = F.relu(self.bn2(self.conv2(out)))
-
- shape = out.shape
- out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
- return out
-
-
+@tables.register("model_classes", "CAMPPlus")
class CAMPPlus(nn.Module):
def __init__(self,
feat_dim=80,
@@ -58,8 +28,9 @@
init_channels=128,
config_str='batchnorm-relu',
memory_efficient=True,
- output_level='segment'):
- super(CAMPPlus, self).__init__()
+ output_level='segment',
+ **kwargs,):
+ super().__init__()
self.head = FCM(feat_dim=feat_dim)
channels = self.head.out_channels
@@ -123,3 +94,28 @@
if self.output_level == 'frame':
x = x.transpose(1, 2)
return x
+
+ def generate(self,
+ data_in,
+ data_lengths=None,
+ key: list=None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+ # extract fbank feats
+ meta_data = {}
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_feature(audio_sample_list)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = np.array(speech_lengths).sum().item() / 16000.0
+ # import pdb; pdb.set_trace()
+ results = []
+ embeddings = self.forward(speech)
+ for embedding in embeddings:
+ results.append({"spk_embedding":embedding})
+ return results, meta_data
\ No newline at end of file
--
Gitblit v1.9.1