游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
funasr/models/campplus/model.py
@@ -1,25 +1,34 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
import os
import time
import torch
import logging
import numpy as np
import torch.nn as nn
from collections import OrderedDict
from typing import Union, Dict, List, Tuple, Optional
from contextlib import contextmanager
from distutils.version import LooseVersion
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
from funasr.models.campplus.components import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
    BasicResBlock, get_nonlinear, FCM
from funasr.models.campplus.utils import extract_feature
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.models.campplus.components import DenseLayer, StatsPool, \
    TDNNLayer, CAMDenseTDNNBlock, TransitLayer, get_nonlinear, FCM
# Use PyTorch's native mixed-precision context manager when this torch build
# ships it (the original gate was torch>=1.6.0). Probing the import directly
# avoids parsing ``torch.__version__`` with distutils' ``LooseVersion``, which
# is deprecated (PEP 632) and removed from the standard library in Python 3.12.
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0: fall back to a no-op context manager so
    # call sites can use ``with autocast(...):`` unconditionally.
    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in for ``torch.cuda.amp.autocast``; ``enabled`` is ignored."""
        yield
@tables.register("model_classes", "CAMPPlus")
class CAMPPlus(nn.Module):
class CAMPPlus(torch.nn.Module):
    def __init__(self,
                 feat_dim=80,
                 embedding_size=192,
@@ -36,7 +45,7 @@
        channels = self.head.out_channels
        self.output_level = output_level
        self.xvector = nn.Sequential(
        self.xvector = torch.nn.Sequential(
            OrderedDict([
                ('tdnn',
@@ -82,10 +91,10 @@
            assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
                torch.nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
                    torch.nn.init.zeros_(m.bias)
    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
@@ -95,7 +104,7 @@
            x = x.transpose(1, 2)
        return x
    def generate(self,
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list=None,
@@ -114,5 +123,5 @@
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0
        results = [{"spk_embedding": self.forward(speech)}]
        results = [{"spk_embedding": self.forward(speech.to(torch.float32))}]
        return results, meta_data