From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/sond/pooling/statistic_pooling.py |   44 ++++++++++++++++++++++----------------------
 1 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/funasr/models/sond/pooling/statistic_pooling.py b/funasr/models/sond/pooling/statistic_pooling.py
index 392e333..91db239 100644
--- a/funasr/models/sond/pooling/statistic_pooling.py
+++ b/funasr/models/sond/pooling/statistic_pooling.py
@@ -7,11 +7,12 @@
 
 VAR2STD_EPSILON = 1e-12
 
+
 class StatisticPooling(torch.nn.Module):
     def __init__(self, pooling_dim: Union[int, Tuple] = 2, eps=1e-12):
         super(StatisticPooling, self).__init__()
         if isinstance(pooling_dim, int):
-            pooling_dim = (pooling_dim, )
+            pooling_dim = (pooling_dim,)
         self.pooling_dim = pooling_dim
         self.eps = eps
 
@@ -22,11 +23,13 @@
             masks = torch.ones_like(xs_pad).to(xs_pad)
         else:
             masks = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
-        mean = (torch.sum(xs_pad, dim=self.pooling_dim, keepdim=True) /
-                torch.sum(masks, dim=self.pooling_dim, keepdim=True))
+        mean = torch.sum(xs_pad, dim=self.pooling_dim, keepdim=True) / torch.sum(
+            masks, dim=self.pooling_dim, keepdim=True
+        )
         squared_difference = torch.pow(xs_pad - mean, 2.0)
-        variance = (torch.sum(squared_difference, dim=self.pooling_dim, keepdim=True) /
-                    torch.sum(masks, dim=self.pooling_dim, keepdim=True))
+        variance = torch.sum(squared_difference, dim=self.pooling_dim, keepdim=True) / torch.sum(
+            masks, dim=self.pooling_dim, keepdim=True
+        )
         for i in reversed(self.pooling_dim):
             mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
 
@@ -38,14 +41,9 @@
 
         return stat_pooling
 
-    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
-        return {}
-
 
 def statistic_pooling(
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor = None,
-        pooling_dim: Tuple = (2, 3)
+    xs_pad: torch.Tensor, ilens: torch.Tensor = None, pooling_dim: Tuple = (2, 3)
 ) -> torch.Tensor:
     # xs_pad in (Batch, Channel, Time, Frequency)
 
@@ -53,11 +51,13 @@
         seq_mask = torch.ones_like(xs_pad).to(xs_pad)
     else:
         seq_mask = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
-    mean = (torch.sum(xs_pad, dim=pooling_dim, keepdim=True) /
-            torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
+    mean = torch.sum(xs_pad, dim=pooling_dim, keepdim=True) / torch.sum(
+        seq_mask, dim=pooling_dim, keepdim=True
+    )
     squared_difference = torch.pow(xs_pad - mean, 2.0)
-    variance = (torch.sum(squared_difference, dim=pooling_dim, keepdim=True) /
-                torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
+    variance = torch.sum(squared_difference, dim=pooling_dim, keepdim=True) / torch.sum(
+        seq_mask, dim=pooling_dim, keepdim=True
+    )
     for i in reversed(pooling_dim):
         mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
 
@@ -71,11 +71,11 @@
 
 
 def windowed_statistic_pooling(
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor = None,
-        pooling_dim: Tuple = (2, 3),
-        pooling_size: int = 20,
-        pooling_stride: int = 1
+    xs_pad: torch.Tensor,
+    ilens: torch.Tensor = None,
+    pooling_dim: Tuple = (2, 3),
+    pooling_size: int = 20,
+    pooling_stride: int = 1,
 ) -> Tuple[torch.Tensor, int]:
     # xs_pad in (Batch, Channel, Time, Frequency)
 
@@ -90,8 +90,8 @@
 
     for i in range(num_chunk):
         # B x C
-        st, ed = i*pooling_stride, i*pooling_stride+pooling_size
-        stat = statistic_pooling(features[:, :, st: ed], pooling_dim=pooling_dim)
+        st, ed = i * pooling_stride, i * pooling_stride + pooling_size
+        stat = statistic_pooling(features[:, :, st:ed], pooling_dim=pooling_dim)
         stat_list.append(stat.unsqueeze(2))
 
     # B x C x T

--
Gitblit v1.9.1