From 33a9e08dc9b65abc3f3e18d14253f95c79e0f749 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 08 May 2024 19:20:43 +0800
Subject: [PATCH] dynamic batch
---
funasr/models/campplus/model.py | 137 ++++++++++++++++++++++++++-------------------
 1 file changed, 78 insertions(+), 59 deletions(-)
diff --git a/funasr/models/campplus/model.py b/funasr/models/campplus/model.py
index 6706c84..e3a829b 100644
--- a/funasr/models/campplus/model.py
+++ b/funasr/models/campplus/model.py
@@ -14,8 +14,15 @@
from funasr.register import tables
from funasr.models.campplus.utils import extract_feature
from funasr.utils.load_utils import load_audio_text_image_video
-from funasr.models.campplus.components import DenseLayer, StatsPool, \
- TDNNLayer, CAMDenseTDNNBlock, TransitLayer, get_nonlinear, FCM
+from funasr.models.campplus.components import (
+ DenseLayer,
+ StatsPool,
+ TDNNLayer,
+ CAMDenseTDNNBlock,
+ TransitLayer,
+ get_nonlinear,
+ FCM,
+)
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -29,16 +36,18 @@
@tables.register("model_classes", "CAMPPlus")
class CAMPPlus(torch.nn.Module):
- def __init__(self,
- feat_dim=80,
- embedding_size=192,
- growth_rate=32,
- bn_size=4,
- init_channels=128,
- config_str='batchnorm-relu',
- memory_efficient=True,
- output_level='segment',
- **kwargs,):
+ def __init__(
+ self,
+ feat_dim=80,
+ embedding_size=192,
+ growth_rate=32,
+ bn_size=4,
+ init_channels=128,
+ config_str="batchnorm-relu",
+ memory_efficient=True,
+ output_level="segment",
+ **kwargs,
+ ):
super().__init__()
self.head = FCM(feat_dim=feat_dim)
@@ -46,49 +55,56 @@
self.output_level = output_level
self.xvector = torch.nn.Sequential(
- OrderedDict([
-
- ('tdnn',
- TDNNLayer(channels,
- init_channels,
- 5,
- stride=2,
- dilation=1,
- padding=-1,
- config_str=config_str)),
- ]))
+ OrderedDict(
+ [
+ (
+ "tdnn",
+ TDNNLayer(
+ channels,
+ init_channels,
+ 5,
+ stride=2,
+ dilation=1,
+ padding=-1,
+ config_str=config_str,
+ ),
+ ),
+ ]
+ )
+ )
channels = init_channels
- for i, (num_layers, kernel_size,
- dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
- block = CAMDenseTDNNBlock(num_layers=num_layers,
- in_channels=channels,
- out_channels=growth_rate,
- bn_channels=bn_size * growth_rate,
- kernel_size=kernel_size,
- dilation=dilation,
- config_str=config_str,
- memory_efficient=memory_efficient)
- self.xvector.add_module('block%d' % (i + 1), block)
+ for i, (num_layers, kernel_size, dilation) in enumerate(
+ zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
+ ):
+ block = CAMDenseTDNNBlock(
+ num_layers=num_layers,
+ in_channels=channels,
+ out_channels=growth_rate,
+ bn_channels=bn_size * growth_rate,
+ kernel_size=kernel_size,
+ dilation=dilation,
+ config_str=config_str,
+ memory_efficient=memory_efficient,
+ )
+ self.xvector.add_module("block%d" % (i + 1), block)
channels = channels + num_layers * growth_rate
self.xvector.add_module(
- 'transit%d' % (i + 1),
- TransitLayer(channels,
- channels // 2,
- bias=False,
- config_str=config_str))
+ "transit%d" % (i + 1),
+ TransitLayer(channels, channels // 2, bias=False, config_str=config_str),
+ )
channels //= 2
- self.xvector.add_module(
- 'out_nonlinear', get_nonlinear(config_str, channels))
+ self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))
- if self.output_level == 'segment':
- self.xvector.add_module('stats', StatsPool())
+ if self.output_level == "segment":
+ self.xvector.add_module("stats", StatsPool())
self.xvector.add_module(
- 'dense',
- DenseLayer(
- channels * 2, embedding_size, config_str='batchnorm_'))
+ "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_")
+ )
else:
- assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '
+ assert (
+ self.output_level == "frame"
+ ), "`output_level` should be set to 'segment' or 'frame'. "
for m in self.modules():
if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
@@ -100,22 +116,25 @@
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = self.head(x)
x = self.xvector(x)
- if self.output_level == 'frame':
+ if self.output_level == "frame":
x = x.transpose(1, 2)
return x
- def inference(self,
- data_in,
- data_lengths=None,
- key: list=None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
# extract fbank feats
meta_data = {}
time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
+ audio_sample_list = load_audio_text_image_video(
+ data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
+ )
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
@@ -123,5 +142,5 @@
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0
- results = [{"spk_embedding": self.forward(speech)}]
- return results, meta_data
\ No newline at end of file
+ results = [{"spk_embedding": self.forward(speech.to(torch.float32))}]
+ return results, meta_data
--
Gitblit v1.9.1