aky15
2023-03-21 fc9595625855be5b63f86a38ac785e49c142c0ae
funasr/models_transducer/encoder/blocks/conv_input.py
@@ -147,25 +147,26 @@
            mask = self.create_new_mask(mask)
            olens = max(mask.eq(0).sum(1))
        
        b, t_input, f = x.size()
        b, t, f = x.size()
        x = x.unsqueeze(1) # (b. 1. t. f)
        if chunk_size is not None:
            max_input_length = int(
                chunk_size * self.subsampling_factor * (math.ceil(float(t_input) / (chunk_size * self.subsampling_factor) ))
                chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
            )
            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
            x = list(x)
            x = torch.stack(x, dim=0)
            N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
            x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
        x = self.conv(x)
        _, c, t, f = x.size()
        _, c, _, f = x.size()
        if chunk_size is not None:
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
        else:
            x = x.transpose(1, 2).contiguous().view(b, t, c * f)
            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
        if self.output is not None:
            x = self.output(x)