zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/sense_voice/encoder.py
@@ -23,27 +23,18 @@
   max_pos = n_frames if n_frames < max_pos else max_pos
   x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
   
   if ilens is not None:
      if self.downsample_rate == 4:
         olens = (
            1
            + (
               ilens
               - self.conv1.kernel_size[0]
               + 2 * self.conv1.padding[0]
            )
                + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
            // self.conv1.stride[0]
         )
      else:
         olens = ilens
      olens = (
         1
         + (
            olens
            - self.conv2.kernel_size[0]
            + 2 * self.conv2.padding[0]
         )
            + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
         // self.conv2.stride[0]
      )
      olens = torch.clamp(olens, max=max_pos)
@@ -57,7 +48,6 @@
   
   for layer, block in enumerate(self.blocks):
      x = block(x, mask=padding_mask, is_pad_mask=True)
   x = self.ln_post(x)