funasr/models/sond/encoder/fsmn_encoder.py
@@ -36,12 +36,12 @@ right_padding = kernel_size - 1 - left_padding self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) def forward(self, inputs, mask, mask_shfit_chunk=None): def forward(self, inputs, mask, mask_shift_chunk=None): b, t, d = inputs.size() if mask is not None: mask = torch.reshape(mask, (b, -1, 1)) if mask_shfit_chunk is not None: mask = mask * mask_shfit_chunk if mask_shift_chunk is not None: mask = mask * mask_shift_chunk inputs = inputs * mask x = inputs.transpose(1, 2)