游雁
2023-08-30 c2e4e3c2e9be855277d9f4fa9cd0544892ff829a
funasr/layers/label_aggregation.py
@@ -1,8 +1,7 @@
import torch
from typeguard import check_argument_types
from typing import Optional
from typing import Tuple
from torch.nn import functional as F
from funasr.modules.nets_utils import make_pad_mask
@@ -13,7 +12,6 @@
        hop_length: int = 128,
        center: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.win_length = win_length
@@ -80,3 +78,36 @@
            olens = None
        return output.to(input.dtype), olens
class LabelAggregateMaxPooling(torch.nn.Module):
    def __init__(
        self,
        hop_length: int = 8,
    ):
        super().__init__()
        self.hop_length = hop_length
    def extra_repr(self):
        return (
            f"hop_length={self.hop_length}, "
        )
    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """LabelAggregate forward function.
        Args:
            input: (Batch, Nsamples, Label_dim)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Label_dim)
        """
        output = F.max_pool1d(input.transpose(1, 2), self.hop_length, self.hop_length).transpose(1, 2)
        olens = ilens // self.hop_length
        return output.to(input.dtype), olens