From 73613cefc97bd43699d10b8d162c69b2c4544ad5 Mon Sep 17 00:00:00 2001
From: 夜雨飘零 <yeyupiaoling@foxmail.com>
Date: 星期一, 04 十二月 2023 21:41:07 +0800
Subject: [PATCH] 增加分角色语音识别对ERes2Net模型的支持。

---
 funasr/utils/speaker_utils.py      |  300 +++++++++++++++++++++++++++++++++++++++++++++++++
 funasr/bin/asr_inference_launch.py |   30 +++-
 2 files changed, 320 insertions(+), 10 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index f61c085..402a911 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -48,13 +48,13 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils.vad_utils import slice_padding_fbank
-from funasr.utils.speaker_utils import (check_audio_list, 
-                                        sv_preprocess, 
-                                        sv_chunk, 
-                                        CAMPPlus, 
-                                        extract_feature, 
+from funasr.utils.speaker_utils import (check_audio_list,
+                                        sv_preprocess,
+                                        sv_chunk,
+                                        CAMPPlus,
+                                        extract_feature,
                                         postprocess,
-                                        distribute_spk)
+                                        distribute_spk, ERes2Net)
 from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.utils.cluster_backend import ClusterBackend
 from funasr.utils.modelscope_utils import get_cache_dir
@@ -819,6 +819,10 @@
     )
 
     sv_model_file = asr_model_file.replace("model.pb", "campplus_cn_common.bin")
+    if not os.path.exists(sv_model_file):
+        sv_model_file = asr_model_file.replace("model.pb", "pretrained_eres2net_aug.ckpt")
+        if not os.path.exists(sv_model_file):
+            raise FileNotFoundError("sv_model_file not found: {}".format(sv_model_file))
 
     if param_dict is not None:
         hotword_list_or_file = param_dict.get('hotword')
@@ -944,8 +948,14 @@
             #####  speaker_verification  #####
             ##################################
             # load sv model
-            sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
-            sv_model = CAMPPlus()
+            sv_model_dict = torch.load(sv_model_file)
+            print(f'load sv model params: {sv_model_file}')
+            if os.path.basename(sv_model_file) == "campplus_cn_common.bin":
+                sv_model = CAMPPlus()
+            else:
+                sv_model = ERes2Net()
+            if ngpu > 0:
+                sv_model.cuda()
             sv_model.load_state_dict(sv_model_dict)
             sv_model.eval()
             cb_model = ClusterBackend()
@@ -969,9 +979,11 @@
                     embs = []
                     for x in wavs:
                         x = extract_feature([x])
+                        if ngpu > 0:
+                            x = x.cuda()
                         embs.append(sv_model(x))
                     embs = torch.cat(embs)
-                    embeddings.append(embs.detach().numpy())
+                    embeddings.append(embs.cpu().detach().numpy())
                 embeddings = np.concatenate(embeddings)
                 labels = cb_model(embeddings)
                 sv_output = postprocess(segments, vad_segments, labels, embeddings)
diff --git a/funasr/utils/speaker_utils.py b/funasr/utils/speaker_utils.py
index edaf58b..df3eca7 100644
--- a/funasr/utils/speaker_utils.py
+++ b/funasr/utils/speaker_utils.py
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 """ Some implementations are adapted from https://github.com/yuyq96/D-TDNN
 """
+import math
 
 import torch
 import torch.nn.functional as F
@@ -590,4 +591,301 @@
                 sentence_spk = spk
         d['spk'] = sentence_spk
         sd_sentence_list.append(d)
-    return sd_sentence_list
\ No newline at end of file
+    return sd_sentence_list
+
+
+class AFF(nn.Module):
+
+    def __init__(self, channels=64, r=4):
+        super(AFF, self).__init__()
+        inter_channels = int(channels // r)
+
+        self.local_att = nn.Sequential(
+            nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(inter_channels),
+            nn.SiLU(inplace=True),
+            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+            nn.BatchNorm2d(channels),
+        )
+
+    def forward(self, x, ds_y):
+        xa = torch.cat((x, ds_y), dim=1)
+        x_att = self.local_att(xa)
+        x_att = 1.0 + torch.tanh(x_att)
+        xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
+
+        return xo
+
+
+class TSTP(nn.Module):
+    """
+    Temporal statistics pooling, concatenate mean and std, which is used in
+    x-vector
+    Comment: simple concatenation can not make full use of both statistics
+    """
+
+    def __init__(self, **kwargs):
+        super(TSTP, self).__init__()
+
+    def forward(self, x):
+        # The last dimension is the temporal axis
+        pooling_mean = x.mean(dim=-1)
+        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
+        pooling_mean = pooling_mean.flatten(start_dim=1)
+        pooling_std = pooling_std.flatten(start_dim=1)
+
+        stats = torch.cat((pooling_mean, pooling_std), 1)
+        return stats
+
+
+class ReLU(nn.Hardtanh):
+
+    def __init__(self, inplace=False):
+        super(ReLU, self).__init__(0, 20, inplace)
+
+    def __repr__(self):
+        inplace_str = 'inplace' if self.inplace else ''
+        return self.__class__.__name__ + ' (' \
+            + inplace_str + ')'
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    "1x1 convolution without padding"
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
+                     padding=0, bias=False)
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    "3x3 convolution with padding"
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class BasicBlockERes2Net(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
+        super(BasicBlockERes2Net, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = conv1x1(in_planes, width * scale, stride)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+        self.nums = scale
+
+        convs = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(conv3x3(width, width))
+            bns.append(nn.BatchNorm2d(width))
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = conv1x1(width * scale, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes,
+                          self.expansion * planes,
+                          kernel_size=1,
+                          stride=stride,
+                          bias=False),
+                nn.BatchNorm2d(self.expansion * planes))
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = sp + spx[i]
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class BasicBlockERes2Net_diff_AFF(nn.Module):
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
+        super(BasicBlockERes2Net_diff_AFF, self).__init__()
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = conv1x1(in_planes, width * scale, stride)
+        self.bn1 = nn.BatchNorm2d(width * scale)
+
+        self.nums = scale
+
+        convs = []
+        fuse_models = []
+        bns = []
+        for i in range(self.nums):
+            convs.append(conv3x3(width, width))
+            bns.append(nn.BatchNorm2d(width))
+        for j in range(self.nums - 1):
+            fuse_models.append(AFF(channels=width))
+
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        self.fuse_models = nn.ModuleList(fuse_models)
+        self.relu = ReLU(inplace=True)
+
+        self.conv3 = conv1x1(width * scale, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.shortcut = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes,
+                          self.expansion * planes,
+                          kernel_size=1,
+                          stride=stride,
+                          bias=False),
+                nn.BatchNorm2d(self.expansion * planes))
+        self.stride = stride
+        self.width = width
+        self.scale = scale
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        spx = torch.split(out, self.width, 1)
+        for i in range(self.nums):
+            if i == 0:
+                sp = spx[i]
+            else:
+                sp = self.fuse_models[i - 1](sp, spx[i])
+
+            sp = self.convs[i](sp)
+            sp = self.relu(self.bns[i](sp))
+            if i == 0:
+                out = sp
+            else:
+                out = torch.cat((out, sp), 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        residual = self.shortcut(x)
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class ERes2Net(nn.Module):
+    def __init__(self,
+                 block=BasicBlockERes2Net,
+                 block_fuse=BasicBlockERes2Net_diff_AFF,
+                 num_blocks=[3, 4, 6, 3],
+                 m_channels=64,
+                 feat_dim=80,
+                 embedding_size=192,
+                 pooling_func='TSTP',
+                 two_emb_layer=False):
+        super(ERes2Net, self).__init__()
+        self.in_planes = m_channels
+        self.feat_dim = feat_dim
+        self.embedding_size = embedding_size
+        self.stats_dim = int(feat_dim / 8) * m_channels * 8
+        self.two_emb_layer = two_emb_layer
+
+        self.conv1 = nn.Conv2d(1,
+                               m_channels,
+                               kernel_size=3,
+                               stride=1,
+                               padding=1,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(m_channels)
+        self.layer1 = self._make_layer(block,
+                                       m_channels,
+                                       num_blocks[0],
+                                       stride=1)
+        self.layer2 = self._make_layer(block,
+                                       m_channels * 2,
+                                       num_blocks[1],
+                                       stride=2)
+        self.layer3 = self._make_layer(block_fuse,
+                                       m_channels * 4,
+                                       num_blocks[2],
+                                       stride=2)
+        self.layer4 = self._make_layer(block_fuse,
+                                       m_channels * 8,
+                                       num_blocks[3],
+                                       stride=2)
+
+        self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2,
+                                           bias=False)
+        self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2,
+                                           bias=False)
+        self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2,
+                                           bias=False)
+        self.fuse_mode12 = AFF(channels=m_channels * 8)
+        self.fuse_mode123 = AFF(channels=m_channels * 16)
+        self.fuse_mode1234 = AFF(channels=m_channels * 32)
+
+        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
+        self.pool = TSTP(in_dim=self.stats_dim * block.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
+                               embedding_size)
+        if self.two_emb_layer:
+            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
+            self.seg_2 = nn.Linear(embedding_size, embedding_size)
+        else:
+            self.seg_bn_1 = nn.Identity()
+            self.seg_2 = nn.Identity()
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block(self.in_planes, planes, stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
+
+        x = x.unsqueeze_(1)
+        out = F.relu(self.bn1(self.conv1(x)))
+        out1 = self.layer1(out)
+        out2 = self.layer2(out1)
+        out1_downsample = self.layer1_downsample(out1)
+        fuse_out12 = self.fuse_mode12(out2, out1_downsample)
+        out3 = self.layer3(out2)
+        fuse_out12_downsample = self.layer2_downsample(fuse_out12)
+        fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
+        out4 = self.layer4(out3)
+        fuse_out123_downsample = self.layer3_downsample(fuse_out123)
+        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
+        stats = self.pool(fuse_out1234)
+
+        embed_a = self.seg_1(stats)
+        if self.two_emb_layer:
+            out = F.relu(embed_a)
+            out = self.seg_bn_1(out)
+            embed_b = self.seg_2(out)
+            return embed_b
+        else:
+            return embed_a

--
Gitblit v1.9.1