zhifu gao
2024-05-08 4adb76a6edbca93aae7caa83382e764d7b058f07
funasr/models/mossformer/mossformer_decoder.py
@@ -40,9 +40,7 @@
        """
        if x.dim() not in [2, 3]:
            raise RuntimeError(
                "{} accept 3/4D tensor as input".format(self.__name__)
            )
            raise RuntimeError("{} accept 3/4D tensor as input".format(self.__name__))
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
        if torch.squeeze(x).dim() == 1:
@@ -50,4 +48,3 @@
        else:
            x = torch.squeeze(x)
        return x