liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/frontends/utils/dnn_beamformer.py
@@ -1,4 +1,5 @@
"""DNN beamformer module."""
from typing import Tuple
import torch
@@ -36,18 +37,14 @@
        beamformer_type="mvdr",
    ):
        super().__init__()
        self.mask = MaskEstimator(
            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
        )
        self.mask = MaskEstimator(btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask)
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel
        self.nmask = bnmask
        if beamformer_type != "mvdr":
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
            raise ValueError("Not supporting beamformer_type={}".format(beamformer_type))
        self.beamformer_type = beamformer_type
    def forward(
@@ -76,9 +73,7 @@
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
                )
                u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)), device=data.device)
                u[..., self.ref_channel].fill_(1)
            ws = get_mvdr_vector(psd_speech, psd_noise, u)
@@ -108,9 +103,7 @@
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]
            psd_speeches = [
                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
            ]
            psd_speeches = [get_power_spectral_density_matrix(data, mask) for mask in mask_speech]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)
            enhanced = []
@@ -118,9 +111,7 @@
            for i in range(self.nmask - 1):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
                )
                enh, w = apply_beamforming(data, ilens, psd_speech, sum(psd_speeches) + psd_noise)
                psd_speeches.insert(i, psd_speech)
                # (..., F, T) -> (..., T, F)
@@ -155,9 +146,7 @@
        B, _, C = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
        # psd_in: (B, F, C, C)
        psd = psd_in.masked_fill(
            torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
        )
        psd = psd_in.masked_fill(torch.eye(C, dtype=torch.bool, device=psd_in.device), 0)
        # psd: (B, F, C, C) -> (B, C, F)
        psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)