zhifu gao
2024-12-25 3f8294b9d7deaa0cbdb0b2ef6f3802d46ae133a9
funasr/models/sond/encoder/fsmn_encoder.py
@@ -36,12 +36,12 @@
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
    def forward(self, inputs, mask, mask_shift_chunk=None):
    def forward(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shift_chunk is not None:
                mask = mask * mask_shift_chunk
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
        inputs = inputs * mask
        x = inputs.transpose(1, 2)