From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/models/sond/pooling/statistic_pooling.py | 44 ++++++++++++++++++++++----------------------
 1 file changed, 22 insertions(+), 22 deletions(-)
diff --git a/funasr/models/sond/pooling/statistic_pooling.py b/funasr/models/sond/pooling/statistic_pooling.py
index 392e333..91db239 100644
--- a/funasr/models/sond/pooling/statistic_pooling.py
+++ b/funasr/models/sond/pooling/statistic_pooling.py
@@ -7,11 +7,12 @@
VAR2STD_EPSILON = 1e-12
+
class StatisticPooling(torch.nn.Module):
def __init__(self, pooling_dim: Union[int, Tuple] = 2, eps=1e-12):
super(StatisticPooling, self).__init__()
if isinstance(pooling_dim, int):
- pooling_dim = (pooling_dim, )
+ pooling_dim = (pooling_dim,)
self.pooling_dim = pooling_dim
self.eps = eps
@@ -22,11 +23,13 @@
masks = torch.ones_like(xs_pad).to(xs_pad)
else:
masks = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
- mean = (torch.sum(xs_pad, dim=self.pooling_dim, keepdim=True) /
- torch.sum(masks, dim=self.pooling_dim, keepdim=True))
+ mean = torch.sum(xs_pad, dim=self.pooling_dim, keepdim=True) / torch.sum(
+ masks, dim=self.pooling_dim, keepdim=True
+ )
squared_difference = torch.pow(xs_pad - mean, 2.0)
- variance = (torch.sum(squared_difference, dim=self.pooling_dim, keepdim=True) /
- torch.sum(masks, dim=self.pooling_dim, keepdim=True))
+ variance = torch.sum(squared_difference, dim=self.pooling_dim, keepdim=True) / torch.sum(
+ masks, dim=self.pooling_dim, keepdim=True
+ )
for i in reversed(self.pooling_dim):
mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
@@ -38,14 +41,9 @@
return stat_pooling
- def convert_tf2torch(self, var_dict_tf, var_dict_torch):
- return {}
-
def statistic_pooling(
- xs_pad: torch.Tensor,
- ilens: torch.Tensor = None,
- pooling_dim: Tuple = (2, 3)
+ xs_pad: torch.Tensor, ilens: torch.Tensor = None, pooling_dim: Tuple = (2, 3)
) -> torch.Tensor:
# xs_pad in (Batch, Channel, Time, Frequency)
@@ -53,11 +51,13 @@
seq_mask = torch.ones_like(xs_pad).to(xs_pad)
else:
seq_mask = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
- mean = (torch.sum(xs_pad, dim=pooling_dim, keepdim=True) /
- torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
+ mean = torch.sum(xs_pad, dim=pooling_dim, keepdim=True) / torch.sum(
+ seq_mask, dim=pooling_dim, keepdim=True
+ )
squared_difference = torch.pow(xs_pad - mean, 2.0)
- variance = (torch.sum(squared_difference, dim=pooling_dim, keepdim=True) /
- torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
+ variance = torch.sum(squared_difference, dim=pooling_dim, keepdim=True) / torch.sum(
+ seq_mask, dim=pooling_dim, keepdim=True
+ )
for i in reversed(pooling_dim):
mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
@@ -71,11 +71,11 @@
def windowed_statistic_pooling(
- xs_pad: torch.Tensor,
- ilens: torch.Tensor = None,
- pooling_dim: Tuple = (2, 3),
- pooling_size: int = 20,
- pooling_stride: int = 1
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor = None,
+ pooling_dim: Tuple = (2, 3),
+ pooling_size: int = 20,
+ pooling_stride: int = 1,
) -> Tuple[torch.Tensor, int]:
# xs_pad in (Batch, Channel, Time, Frequency)
@@ -90,8 +90,8 @@
for i in range(num_chunk):
# B x C
- st, ed = i*pooling_stride, i*pooling_stride+pooling_size
- stat = statistic_pooling(features[:, :, st: ed], pooling_dim=pooling_dim)
+ st, ed = i * pooling_stride, i * pooling_stride + pooling_size
+ stat = statistic_pooling(features[:, :, st:ed], pooling_dim=pooling_dim)
stat_list.append(stat.unsqueeze(2))
# B x C x T
--
Gitblit v1.9.1