zhifu gao
2024-05-06 00d0df3a1018c63ec8c5d13e611f53c564c0a7e2
funasr/models/sond/label_aggregation.py
@@ -90,9 +90,7 @@
        self.hop_length = hop_length
    def extra_repr(self):
        return (
            f"hop_length={self.hop_length}, "
        )
        return f"hop_length={self.hop_length}, "
    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor = None
@@ -107,7 +105,9 @@
        """
        output = F.max_pool1d(input.transpose(1, 2), self.hop_length, self.hop_length).transpose(1, 2)
        output = F.max_pool1d(input.transpose(1, 2), self.hop_length, self.hop_length).transpose(
            1, 2
        )
        olens = ilens // self.hop_length
        return output.to(input.dtype), olens
        return output.to(input.dtype), olens