From 73613cefc97bd43699d10b8d162c69b2c4544ad5 Mon Sep 17 00:00:00 2001
From: 夜雨飘零 <yeyupiaoling@foxmail.com>
Date: 星期一, 04 十二月 2023 21:41:07 +0800
Subject: [PATCH] 增加分角色语音识别对ERes2Net模型的支持。
---
funasr/utils/speaker_utils.py | 300 +++++++++++++++++++++++++++++++++++++++++++++++++
funasr/bin/asr_inference_launch.py | 30 +++-
2 files changed, 320 insertions(+), 10 deletions(-)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index f61c085..402a911 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -48,13 +48,13 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.vad_utils import slice_padding_fbank
-from funasr.utils.speaker_utils import (check_audio_list,
- sv_preprocess,
- sv_chunk,
- CAMPPlus,
- extract_feature,
+from funasr.utils.speaker_utils import (check_audio_list,
+ sv_preprocess,
+ sv_chunk,
+ CAMPPlus,
+ extract_feature,
postprocess,
- distribute_spk)
+ distribute_spk, ERes2Net)
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.utils.cluster_backend import ClusterBackend
from funasr.utils.modelscope_utils import get_cache_dir
@@ -819,6 +819,10 @@
)
sv_model_file = asr_model_file.replace("model.pb", "campplus_cn_common.bin")
+ if not os.path.exists(sv_model_file):
+ sv_model_file = asr_model_file.replace("model.pb", "pretrained_eres2net_aug.ckpt")
+ if not os.path.exists(sv_model_file):
+ raise FileNotFoundError("sv_model_file not found: {}".format(sv_model_file))
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
@@ -944,8 +948,14 @@
##### speaker_verification #####
##################################
# load sv model
- sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
- sv_model = CAMPPlus()
+ sv_model_dict = torch.load(sv_model_file)
+ print(f'load sv model params: {sv_model_file}')
+ if os.path.basename(sv_model_file) == "campplus_cn_common.bin":
+ sv_model = CAMPPlus()
+ else:
+ sv_model = ERes2Net()
+ if ngpu > 0:
+ sv_model.cuda()
sv_model.load_state_dict(sv_model_dict)
sv_model.eval()
cb_model = ClusterBackend()
@@ -969,9 +979,11 @@
embs = []
for x in wavs:
x = extract_feature([x])
+ if ngpu > 0:
+ x = x.cuda()
embs.append(sv_model(x))
embs = torch.cat(embs)
- embeddings.append(embs.detach().numpy())
+ embeddings.append(embs.cpu().detach().numpy())
embeddings = np.concatenate(embeddings)
labels = cb_model(embeddings)
sv_output = postprocess(segments, vad_segments, labels, embeddings)
diff --git a/funasr/utils/speaker_utils.py b/funasr/utils/speaker_utils.py
index edaf58b..df3eca7 100644
--- a/funasr/utils/speaker_utils.py
+++ b/funasr/utils/speaker_utils.py
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
""" Some implementations are adapted from https://github.com/yuyq96/D-TDNN
"""
+import math
import torch
import torch.nn.functional as F
@@ -590,4 +591,301 @@
sentence_spk = spk
d['spk'] = sentence_spk
sd_sentence_list.append(d)
- return sd_sentence_list
\ No newline at end of file
+ return sd_sentence_list
+
+
+class AFF(nn.Module):
+
+ def __init__(self, channels=64, r=4):
+ super(AFF, self).__init__()
+ inter_channels = int(channels // r)
+
+ self.local_att = nn.Sequential(
+ nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.SiLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+
+ def forward(self, x, ds_y):
+ xa = torch.cat((x, ds_y), dim=1)
+ x_att = self.local_att(xa)
+ x_att = 1.0 + torch.tanh(x_att)
+ xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
+
+ return xo
+
+
+class TSTP(nn.Module):
+ """
+ Temporal statistics pooling, concatenate mean and std, which is used in
+ x-vector
+ Comment: simple concatenation can not make full use of both statistics
+ """
+
+ def __init__(self, **kwargs):
+ super(TSTP, self).__init__()
+
+ def forward(self, x):
+ # The last dimension is the temporal axis
+ pooling_mean = x.mean(dim=-1)
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
+ pooling_mean = pooling_mean.flatten(start_dim=1)
+ pooling_std = pooling_std.flatten(start_dim=1)
+
+ stats = torch.cat((pooling_mean, pooling_std), 1)
+ return stats
+
+
+class ReLU(nn.Hardtanh):
+
+ def __init__(self, inplace=False):
+ super(ReLU, self).__init__(0, 20, inplace)
+
+ def __repr__(self):
+ inplace_str = 'inplace' if self.inplace else ''
+ return self.__class__.__name__ + ' (' \
+ + inplace_str + ')'
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ "1x1 convolution without padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
+ padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlockERes2Net(nn.Module):
+ expansion = 4
+
+ def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
+ super(BasicBlockERes2Net, self).__init__()
+ width = int(math.floor(planes * (baseWidth / 64.0)))
+ self.conv1 = conv1x1(in_planes, width * scale, stride)
+ self.bn1 = nn.BatchNorm2d(width * scale)
+ self.nums = scale
+
+ convs = []
+ bns = []
+ for i in range(self.nums):
+ convs.append(conv3x3(width, width))
+ bns.append(nn.BatchNorm2d(width))
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+ self.relu = ReLU(inplace=True)
+
+ self.conv3 = conv1x1(width * scale, planes * self.expansion)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(self.expansion * planes))
+ self.stride = stride
+ self.width = width
+ self.scale = scale
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ spx = torch.split(out, self.width, 1)
+ for i in range(self.nums):
+ if i == 0:
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ sp = self.convs[i](sp)
+ sp = self.relu(self.bns[i](sp))
+ if i == 0:
+ out = sp
+ else:
+ out = torch.cat((out, sp), 1)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ residual = self.shortcut(x)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class BasicBlockERes2Net_diff_AFF(nn.Module):
+ expansion = 4
+
+ def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
+ super(BasicBlockERes2Net_diff_AFF, self).__init__()
+ width = int(math.floor(planes * (baseWidth / 64.0)))
+ self.conv1 = conv1x1(in_planes, width * scale, stride)
+ self.bn1 = nn.BatchNorm2d(width * scale)
+
+ self.nums = scale
+
+ convs = []
+ fuse_models = []
+ bns = []
+ for i in range(self.nums):
+ convs.append(conv3x3(width, width))
+ bns.append(nn.BatchNorm2d(width))
+ for j in range(self.nums - 1):
+ fuse_models.append(AFF(channels=width))
+
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+ self.fuse_models = nn.ModuleList(fuse_models)
+ self.relu = ReLU(inplace=True)
+
+ self.conv3 = conv1x1(width * scale, planes * self.expansion)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion * planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes,
+ self.expansion * planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ nn.BatchNorm2d(self.expansion * planes))
+ self.stride = stride
+ self.width = width
+ self.scale = scale
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ spx = torch.split(out, self.width, 1)
+ for i in range(self.nums):
+ if i == 0:
+ sp = spx[i]
+ else:
+ sp = self.fuse_models[i - 1](sp, spx[i])
+
+ sp = self.convs[i](sp)
+ sp = self.relu(self.bns[i](sp))
+ if i == 0:
+ out = sp
+ else:
+ out = torch.cat((out, sp), 1)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ residual = self.shortcut(x)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ERes2Net(nn.Module):
+ def __init__(self,
+ block=BasicBlockERes2Net,
+ block_fuse=BasicBlockERes2Net_diff_AFF,
+ num_blocks=[3, 4, 6, 3],
+ m_channels=64,
+ feat_dim=80,
+ embedding_size=192,
+ pooling_func='TSTP',
+ two_emb_layer=False):
+ super(ERes2Net, self).__init__()
+ self.in_planes = m_channels
+ self.feat_dim = feat_dim
+ self.embedding_size = embedding_size
+ self.stats_dim = int(feat_dim / 8) * m_channels * 8
+ self.two_emb_layer = two_emb_layer
+
+ self.conv1 = nn.Conv2d(1,
+ m_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(m_channels)
+ self.layer1 = self._make_layer(block,
+ m_channels,
+ num_blocks[0],
+ stride=1)
+ self.layer2 = self._make_layer(block,
+ m_channels * 2,
+ num_blocks[1],
+ stride=2)
+ self.layer3 = self._make_layer(block_fuse,
+ m_channels * 4,
+ num_blocks[2],
+ stride=2)
+ self.layer4 = self._make_layer(block_fuse,
+ m_channels * 8,
+ num_blocks[3],
+ stride=2)
+
+ self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
+ bias=False)
+ self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
+ bias=False)
+ self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2,
+ bias=False)
+ self.fuse_mode12 = AFF(channels=m_channels * 8)
+ self.fuse_mode123 = AFF(channels=m_channels * 16)
+ self.fuse_mode1234 = AFF(channels=m_channels * 32)
+
+ self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
+ self.pool = TSTP(in_dim=self.stats_dim * block.expansion)
+ self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
+ embedding_size)
+ if self.two_emb_layer:
+ self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
+ self.seg_2 = nn.Linear(embedding_size, embedding_size)
+ else:
+ self.seg_bn_1 = nn.Identity()
+ self.seg_2 = nn.Identity()
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1] * (num_blocks - 1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
+
+ x = x.unsqueeze_(1)
+ out = F.relu(self.bn1(self.conv1(x)))
+ out1 = self.layer1(out)
+ out2 = self.layer2(out1)
+ out1_downsample = self.layer1_downsample(out1)
+ fuse_out12 = self.fuse_mode12(out2, out1_downsample)
+ out3 = self.layer3(out2)
+ fuse_out12_downsample = self.layer2_downsample(fuse_out12)
+ fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
+ out4 = self.layer4(out3)
+ fuse_out123_downsample = self.layer3_downsample(fuse_out123)
+ fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
+ stats = self.pool(fuse_out1234)
+
+ embed_a = self.seg_1(stats)
+ if self.two_emb_layer:
+ out = F.relu(embed_a)
+ out = self.seg_bn_1(out)
+ embed_b = self.seg_2(out)
+ return embed_b
+ else:
+ return embed_a
--
Gitblit v1.9.1