游雁
2024-06-09 b75d1e89bb2f513a79bb07e9100ba1cd2bbcf40c
funasr/models/sense_voice/encoder.py
@@ -8,60 +8,50 @@
def sense_voice_encode_forward(
   self,
   x: torch.Tensor,
   ilens: torch.Tensor = None,
   **kwargs,
    self,
    x: torch.Tensor,
    ilens: torch.Tensor = None,
    **kwargs,
):
   use_padmask = self.use_padmask
   x = F.gelu(self.conv1(x))
   x = F.gelu(self.conv2(x))
   x = x.permute(0, 2, 1)
   n_frames = x.size(1)
   max_pos = self.positional_embedding.size(0)
   max_pos = n_frames if n_frames < max_pos else max_pos
   x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
   if ilens is not None:
      if self.downsample_rate == 4:
         olens = (
            1
            + (
               ilens
               - self.conv1.kernel_size[0]
               + 2 * self.conv1.padding[0]
            )
            // self.conv1.stride[0]
         )
      else:
         olens = ilens
      olens = (
         1
         + (
            olens
            - self.conv2.kernel_size[0]
            + 2 * self.conv2.padding[0]
         )
         // self.conv2.stride[0]
      )
      olens = torch.clamp(olens, max=max_pos)
   else:
      olens = None
   if use_padmask and olens is not None:
      padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
   else:
      padding_mask = None
   for layer, block in enumerate(self.blocks):
      x = block(x, mask=padding_mask, is_pad_mask=True)
    use_padmask = self.use_padmask
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)
   x = self.ln_post(x)
   if ilens is None:
      return x
   else:
      return x, olens
    n_frames = x.size(1)
    max_pos = self.positional_embedding.size(0)
    max_pos = n_frames if n_frames < max_pos else max_pos
    x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
    if ilens is not None:
        if self.downsample_rate == 4:
            olens = (
                1
                + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
                // self.conv1.stride[0]
            )
        else:
            olens = ilens
        olens = (
            1
            + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
            // self.conv2.stride[0]
        )
        olens = torch.clamp(olens, max=max_pos)
    else:
        olens = None
    if use_padmask and olens is not None:
        padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
    else:
        padding_mask = None
    for layer, block in enumerate(self.blocks):
        x = block(x, mask=padding_mask, is_pad_mask=True)
    x = self.ln_post(x)
    if ilens is None:
        return x
    else:
        return x, olens