志浩
2023-08-01 4bc6db3ef88795eb570f92f9576f8bc7c56f96bc
funasr/layers/label_aggregation.py
@@ -1,7 +1,8 @@
import torch
from typing import Optional
from typing import Tuple
from typeguard import check_argument_types
from torch.nn import functional as F
from funasr.modules.nets_utils import make_pad_mask
@@ -78,3 +79,37 @@
            olens = None
        return output.to(input.dtype), olens
class LabelAggregateMaxPooling(torch.nn.Module):
    def __init__(
        self,
        hop_length: int = 8,
    ):
        assert check_argument_types()
        super().__init__()
        self.hop_length = hop_length
    def extra_repr(self):
        return (
            f"hop_length={self.hop_length}, "
        )
    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """LabelAggregate forward function.
        Args:
            input: (Batch, Nsamples, Label_dim)
            ilens: (Batch)
        Returns:
            output: (Batch, Frames, Label_dim)
        """
        output = F.max_pool1d(input.transpose(1, 2), self.hop_length, self.hop_length).transpose(1, 2)
        olens = ilens // self.hop_length
        return output.to(input.dtype), olens