From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/frontends/utils/dnn_beamformer.py |   25 +++++++------------------
 1 file changed, 7 insertions(+), 18 deletions(-)

diff --git a/funasr/frontends/utils/dnn_beamformer.py b/funasr/frontends/utils/dnn_beamformer.py
index 75637d2..73b6bea 100644
--- a/funasr/frontends/utils/dnn_beamformer.py
+++ b/funasr/frontends/utils/dnn_beamformer.py
@@ -1,4 +1,5 @@
 """DNN beamformer module."""
+
 from typing import Tuple
 
 import torch
@@ -36,18 +37,14 @@
         beamformer_type="mvdr",
     ):
         super().__init__()
-        self.mask = MaskEstimator(
-            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
-        )
+        self.mask = MaskEstimator(btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask)
         self.ref = AttentionReference(bidim, badim)
         self.ref_channel = ref_channel
 
         self.nmask = bnmask
 
         if beamformer_type != "mvdr":
-            raise ValueError(
-                "Not supporting beamformer_type={}".format(beamformer_type)
-            )
+            raise ValueError("Not supporting beamformer_type={}".format(beamformer_type))
         self.beamformer_type = beamformer_type
 
     def forward(
@@ -76,9 +73,7 @@
                 u, _ = self.ref(psd_speech, ilens)
             else:
                 # (optional) Create onehot vector for fixed reference microphone
-                u = torch.zeros(
-                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
-                )
+                u = torch.zeros(*(data.size()[:-3] + (data.size(-2),)), device=data.device)
                 u[..., self.ref_channel].fill_(1)
 
             ws = get_mvdr_vector(psd_speech, psd_noise, u)
@@ -108,9 +103,7 @@
             mask_speech = list(masks[:-1])
             mask_noise = masks[-1]
 
-            psd_speeches = [
-                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
-            ]
+            psd_speeches = [get_power_spectral_density_matrix(data, mask) for mask in mask_speech]
             psd_noise = get_power_spectral_density_matrix(data, mask_noise)
 
             enhanced = []
@@ -118,9 +111,7 @@
             for i in range(self.nmask - 1):
                 psd_speech = psd_speeches.pop(i)
                 # treat all other speakers' psd_speech as noises
-                enh, w = apply_beamforming(
-                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
-                )
+                enh, w = apply_beamforming(data, ilens, psd_speech, sum(psd_speeches) + psd_noise)
                 psd_speeches.insert(i, psd_speech)
 
                 # (..., F, T) -> (..., T, F)
@@ -155,9 +146,7 @@
         B, _, C = psd_in.size()[:3]
         assert psd_in.size(2) == psd_in.size(3), psd_in.size()
         # psd_in: (B, F, C, C)
-        psd = psd_in.masked_fill(
-            torch.eye(C, dtype=torch.bool, device=psd_in.device), 0
-        )
+        psd = psd_in.masked_fill(torch.eye(C, dtype=torch.bool, device=psd_in.device), 0)
         # psd: (B, F, C, C) -> (B, C, F)
         psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
 

--
Gitblit v1.9.1