funasr/models/mossformer/mossformer_decoder.py
@@ -40,9 +40,7 @@ """ if x.dim() not in [2, 3]: raise RuntimeError( "{} accept 3/4D tensor as input".format(self.__name__) ) raise RuntimeError("{} accept 3/4D tensor as input".format(self.__name__)) x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) if torch.squeeze(x).dim() == 1: @@ -50,4 +48,3 @@ else: x = torch.squeeze(x) return x