wuhongsheng
2024-07-05 3a4281f4959534b1bf5d01acf0085f4f8e6f2ec8
funasr/models/campplus/model.py
@@ -14,8 +14,15 @@
from funasr.register import tables
from funasr.models.campplus.utils import extract_feature
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.models.campplus.components import DenseLayer, StatsPool, \
    TDNNLayer, CAMDenseTDNNBlock, TransitLayer, get_nonlinear, FCM
from funasr.models.campplus.components import (
    DenseLayer,
    StatsPool,
    TDNNLayer,
    CAMDenseTDNNBlock,
    TransitLayer,
    get_nonlinear,
    FCM,
)
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -29,16 +36,18 @@
@tables.register("model_classes", "CAMPPlus")
class CAMPPlus(torch.nn.Module):
    def __init__(self,
                 feat_dim=80,
                 embedding_size=192,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True,
                 output_level='segment',
                 **kwargs,):
    def __init__(
        self,
        feat_dim=80,
        embedding_size=192,
        growth_rate=32,
        bn_size=4,
        init_channels=128,
        config_str="batchnorm-relu",
        memory_efficient=True,
        output_level="segment",
        **kwargs,
    ):
        super().__init__()
        self.head = FCM(feat_dim=feat_dim)
@@ -46,49 +55,56 @@
        self.output_level = output_level
        self.xvector = torch.nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(channels,
                           init_channels,
                           5,
                           stride=2,
                           dilation=1,
                           padding=-1,
                           config_str=config_str)),
            ]))
            OrderedDict(
                [
                    (
                        "tdnn",
                        TDNNLayer(
                            channels,
                            init_channels,
                            5,
                            stride=2,
                            dilation=1,
                            padding=-1,
                            config_str=config_str,
                        ),
                    ),
                ]
            )
        )
        channels = init_channels
        for i, (num_layers, kernel_size,
                dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
            block = CAMDenseTDNNBlock(num_layers=num_layers,
                                      in_channels=channels,
                                      out_channels=growth_rate,
                                      bn_channels=bn_size * growth_rate,
                                      kernel_size=kernel_size,
                                      dilation=dilation,
                                      config_str=config_str,
                                      memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
        for i, (num_layers, kernel_size, dilation) in enumerate(
            zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
        ):
            block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient,
            )
            self.xvector.add_module("block%d" % (i + 1), block)
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(channels,
                             channels // 2,
                             bias=False,
                             config_str=config_str))
                "transit%d" % (i + 1),
                TransitLayer(channels, channels // 2, bias=False, config_str=config_str),
            )
            channels //= 2
        self.xvector.add_module(
            'out_nonlinear', get_nonlinear(config_str, channels))
        self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))
        if self.output_level == 'segment':
            self.xvector.add_module('stats', StatsPool())
        if self.output_level == "segment":
            self.xvector.add_module("stats", StatsPool())
            self.xvector.add_module(
                'dense',
                DenseLayer(
                    channels * 2, embedding_size, config_str='batchnorm_'))
                "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_")
            )
        else:
            assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '
            assert (
                self.output_level == "frame"
            ), "`output_level` should be set to 'segment' or 'frame'. "
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
@@ -100,22 +116,25 @@
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == 'frame':
        if self.output_level == "frame":
            x = x.transpose(1, 2)
        return x
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list=None,
                 tokenizer=None,
                 frontend=None,
                 **kwargs,
                 ):
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # extract fbank feats
        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
@@ -123,5 +142,5 @@
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0
        results = [{"spk_embedding": self.forward(speech)}]
        return results, meta_data
        results = [{"spk_embedding": self.forward(speech.to(torch.float32))}]
        return results, meta_data