| | |
| | | import torch |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | |
| | | from typeguard import check_argument_types |
| | | from torch.nn import functional as F |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | |
| | | |
| | |
| | | olens = None |
| | | |
| | | return output.to(input.dtype), olens |
| | | |
| | | |
class LabelAggregateMaxPooling(torch.nn.Module):
    """Aggregate frame-level labels along time via non-overlapping max pooling.

    The time axis is reduced by a factor of ``hop_length``: each output frame
    is the element-wise maximum over a window of ``hop_length`` input frames.
    Trailing frames that do not fill a complete window are dropped
    (``F.max_pool1d`` with ``ceil_mode=False``), which matches the floor
    division used to compute the output lengths.
    """

    def __init__(
        self,
        hop_length: int = 8,
    ):
        assert check_argument_types()
        super().__init__()

        # Pooling kernel size and stride: one output frame per hop_length inputs.
        self.hop_length = hop_length

    def extra_repr(self):
        # Shown inside torch.nn.Module's repr for this layer.
        return (
            f"hop_length={self.hop_length}, "
        )

    def forward(
        self, input: torch.Tensor, ilens: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """LabelAggregate forward function.

        Args:
            input: (Batch, Nsamples, Label_dim)
            ilens: (Batch) valid lengths per sample, or None when all samples
                are full length.
        Returns:
            output: (Batch, Frames, Label_dim) where Frames = Nsamples // hop_length
            olens: (Batch) pooled valid lengths, or None when ilens is None.
        """
        # max_pool1d pools over the last axis, so move time to dim -1 and back.
        # kernel == stride == hop_length -> non-overlapping windows.
        output = F.max_pool1d(
            input.transpose(1, 2), self.hop_length, self.hop_length
        ).transpose(1, 2)
        # Fix: the original unconditionally computed `ilens // self.hop_length`,
        # which raises TypeError for the documented default ilens=None.
        olens = None if ilens is None else ilens // self.hop_length

        return output.to(input.dtype), olens