| | |
| | | """DNN beamformer module.""" |
| | | |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | |
| | | beamformer_type="mvdr", |
| | | ): |
| | | super().__init__() |
| | | self.mask = MaskEstimator( |
| | | btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask |
| | | ) |
| | | self.mask = MaskEstimator(btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask) |
| | | self.ref = AttentionReference(bidim, badim) |
| | | self.ref_channel = ref_channel |
| | | |
| | | self.nmask = bnmask |
| | | |
| | | if beamformer_type != "mvdr": |
| | | raise ValueError( |
| | | "Not supporting beamformer_type={}".format(beamformer_type) |
| | | ) |
| | | raise ValueError("Not supporting beamformer_type={}".format(beamformer_type)) |
| | | self.beamformer_type = beamformer_type |
| | | |
| | | def forward( |
| | |
| | | u, _ = self.ref(psd_speech, ilens) |
| | | else: |
| | | # (optional) Create onehot vector for fixed reference microphone |
| | | u = torch.zeros( |
| | | *(data.size()[:-3] + (data.size(-2),)), device=data.device |
| | | ) |
| | | u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)), device=data.device) |
| | | u[..., self.ref_channel].fill_(1) |
| | | |
| | | ws = get_mvdr_vector(psd_speech, psd_noise, u) |
| | |
| | | mask_speech = list(masks[:-1]) |
| | | mask_noise = masks[-1] |
| | | |
| | | psd_speeches = [ |
| | | get_power_spectral_density_matrix(data, mask) for mask in mask_speech |
| | | ] |
| | | psd_speeches = [get_power_spectral_density_matrix(data, mask) for mask in mask_speech] |
| | | psd_noise = get_power_spectral_density_matrix(data, mask_noise) |
| | | |
| | | enhanced = [] |
| | |
| | | for i in range(self.nmask - 1): |
| | | psd_speech = psd_speeches.pop(i) |
| | | # treat all other speakers' psd_speech as noises |
| | | enh, w = apply_beamforming( |
| | | data, ilens, psd_speech, sum(psd_speeches) + psd_noise |
| | | ) |
| | | enh, w = apply_beamforming(data, ilens, psd_speech, sum(psd_speeches) + psd_noise) |
| | | psd_speeches.insert(i, psd_speech) |
| | | |
| | | # (..., F, T) -> (..., T, F) |
| | |
| | | B, _, C = psd_in.size()[:3] |
| | | assert psd_in.size(2) == psd_in.size(3), psd_in.size() |
| | | # psd_in: (B, F, C, C) |
| | | psd = psd_in.masked_fill( |
| | | torch.eye(C, dtype=torch.bool, device=psd_in.device), 0 |
| | | ) |
| | | psd = psd_in.masked_fill(torch.eye(C, dtype=torch.bool, device=psd_in.device), 0) |
| | | # psd: (B, F, C, C) -> (B, C, F) |
| | | psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2) |
| | | |