From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' on line 514 to avoid a naming conflict with line 365
---
funasr/frontends/utils/dnn_beamformer.py | 25 +++++++------------------
 1 file changed, 7 insertions(+), 18 deletions(-)
diff --git a/funasr/frontends/utils/dnn_beamformer.py b/funasr/frontends/utils/dnn_beamformer.py
index 75637d2..73b6bea 100644
--- a/funasr/frontends/utils/dnn_beamformer.py
+++ b/funasr/frontends/utils/dnn_beamformer.py
@@ -1,4 +1,5 @@
"""DNN beamformer module."""
+
from typing import Tuple
import torch
@@ -36,18 +37,14 @@
beamformer_type="mvdr",
):
super().__init__()
- self.mask = MaskEstimator(
- btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
- )
+ self.mask = MaskEstimator(btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask)
self.ref = AttentionReference(bidim, badim)
self.ref_channel = ref_channel
self.nmask = bnmask
if beamformer_type != "mvdr":
- raise ValueError(
- "Not supporting beamformer_type={}".format(beamformer_type)
- )
+ raise ValueError("Not supporting beamformer_type={}".format(beamformer_type))
self.beamformer_type = beamformer_type
def forward(
@@ -76,9 +73,7 @@
u, _ = self.ref(psd_speech, ilens)
else:
# (optional) Create onehot vector for fixed reference microphone
- u = torch.zeros(
- *(data.size()[:-3] + (data.size(-2),)), device=data.device
- )
+ u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)), device=data.device)
u[..., self.ref_channel].fill_(1)
ws = get_mvdr_vector(psd_speech, psd_noise, u)
@@ -108,9 +103,7 @@
mask_speech = list(masks[:-1])
mask_noise = masks[-1]
- psd_speeches = [
- get_power_spectral_density_matrix(data, mask) for mask in mask_speech
- ]
+ psd_speeches = [get_power_spectral_density_matrix(data, mask) for mask in mask_speech]
psd_noise = get_power_spectral_density_matrix(data, mask_noise)
enhanced = []
@@ -118,9 +111,7 @@
for i in range(self.nmask - 1):
psd_speech = psd_speeches.pop(i)
# treat all other speakers' psd_speech as noises
- enh, w = apply_beamforming(
- data, ilens, psd_speech, sum(psd_speeches) + psd_noise
- )
+ enh, w = apply_beamforming(data, ilens, psd_speech, sum(psd_speeches) + psd_noise)
psd_speeches.insert(i, psd_speech)
# (..., F, T) -> (..., T, F)
@@ -155,9 +146,7 @@
B, _, C = psd_in.size()[:3]
assert psd_in.size(2) == psd_in.size(3), psd_in.size()
# psd_in: (B, F, C, C)
- psd = psd_in.masked_fill(
- torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
- )
+ psd = psd_in.masked_fill(torch.eye(C, dtype=torch.bool, device=psd_in.device), 0)
# psd: (B, F, C, C) -> (B, C, F)
psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
--
Gitblit v1.9.1