王梦迪
2025-05-20 fe588bc508c0076bb007d6ed36c18ac8ecb341ac
funasr/models/sond/pooling/statistic_pooling.py
@@ -7,11 +7,12 @@
VAR2STD_EPSILON = 1e-12
class StatisticPooling(torch.nn.Module):
    def __init__(self, pooling_dim: Union[int, Tuple] = 2, eps=1e-12):
        super(StatisticPooling, self).__init__()
        if isinstance(pooling_dim, int):
            pooling_dim = (pooling_dim, )
            pooling_dim = (pooling_dim,)
        self.pooling_dim = pooling_dim
        self.eps = eps
@@ -22,11 +23,13 @@
            masks = torch.ones_like(xs_pad).to(xs_pad)
        else:
            masks = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
        mean = (torch.sum(xs_pad, dim=self.pooling_dim, keepdim=True) /
                torch.sum(masks, dim=self.pooling_dim, keepdim=True))
        mean = torch.sum(xs_pad, dim=self.pooling_dim, keepdim=True) / torch.sum(
            masks, dim=self.pooling_dim, keepdim=True
        )
        squared_difference = torch.pow(xs_pad - mean, 2.0)
        variance = (torch.sum(squared_difference, dim=self.pooling_dim, keepdim=True) /
                    torch.sum(masks, dim=self.pooling_dim, keepdim=True))
        variance = torch.sum(squared_difference, dim=self.pooling_dim, keepdim=True) / torch.sum(
            masks, dim=self.pooling_dim, keepdim=True
        )
        for i in reversed(self.pooling_dim):
            mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
@@ -40,9 +43,7 @@
def statistic_pooling(
        xs_pad: torch.Tensor,
        ilens: torch.Tensor = None,
        pooling_dim: Tuple = (2, 3)
    xs_pad: torch.Tensor, ilens: torch.Tensor = None, pooling_dim: Tuple = (2, 3)
) -> torch.Tensor:
    # xs_pad in (Batch, Channel, Time, Frequency)
@@ -50,11 +51,13 @@
        seq_mask = torch.ones_like(xs_pad).to(xs_pad)
    else:
        seq_mask = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
    mean = (torch.sum(xs_pad, dim=pooling_dim, keepdim=True) /
            torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
    mean = torch.sum(xs_pad, dim=pooling_dim, keepdim=True) / torch.sum(
        seq_mask, dim=pooling_dim, keepdim=True
    )
    squared_difference = torch.pow(xs_pad - mean, 2.0)
    variance = (torch.sum(squared_difference, dim=pooling_dim, keepdim=True) /
                torch.sum(seq_mask, dim=pooling_dim, keepdim=True))
    variance = torch.sum(squared_difference, dim=pooling_dim, keepdim=True) / torch.sum(
        seq_mask, dim=pooling_dim, keepdim=True
    )
    for i in reversed(pooling_dim):
        mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
@@ -68,11 +71,11 @@
def windowed_statistic_pooling(
        xs_pad: torch.Tensor,
        ilens: torch.Tensor = None,
        pooling_dim: Tuple = (2, 3),
        pooling_size: int = 20,
        pooling_stride: int = 1
    xs_pad: torch.Tensor,
    ilens: torch.Tensor = None,
    pooling_dim: Tuple = (2, 3),
    pooling_size: int = 20,
    pooling_stride: int = 1,
) -> Tuple[torch.Tensor, int]:
    # xs_pad in (Batch, Channel, Time, Frequency)
@@ -87,8 +90,8 @@
    for i in range(num_chunk):
        # B x C
        st, ed = i*pooling_stride, i*pooling_stride+pooling_size
        stat = statistic_pooling(features[:, :, st: ed], pooling_dim=pooling_dim)
        st, ed = i * pooling_stride, i * pooling_stride + pooling_size
        stat = statistic_pooling(features[:, :, st:ed], pooling_dim=pooling_dim)
        stat_list.append(stat.unsqueeze(2))
    # B x C x T