雾聪
2023-06-02 eee6af2ece605035b0a0835eb9dbed5ae872c755
funasr/models/frontend/default.py
@@ -101,8 +101,8 @@
        if input_stft.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                if self.use_channel == None:
                    input_stft = input_stft[:, :, 0, :]
                if self.use_channel is not None:
                    input_stft = input_stft[:, :, self.use_channel, :]
                else:
                    # Select 1ch randomly
                    ch = np.random.randint(input_stft.size(2))