zhifu gao
2024-04-23 2ac38adbe5f4e1374a079e032ed4b504351a207c
funasr/models/sense_voice/decoder.py
@@ -245,29 +245,7 @@
      self.register_buffer("mask", mask, persistent=False)
      
      self.use_padmask = kwargs.get("use_padmask", True)
   # def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
   #    """
   #    x : torch.LongTensor, shape = (batch_size, <= n_ctx)
   #       the text tokens
   #    xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
   #       the encoded audio features to be attended on
   #    """
   #    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
   #    x = (
   #       self.token_embedding(x)
   #       + self.positional_embedding[offset: offset + x.shape[-1]]
   #    )
   #    x = x.to(xa.dtype)
   #
   #    for block in self.blocks:
   #       x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
   #
   #    x = self.ln(x)
   #    logits = (
   #       x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
   #    ).float()
   #
   #    return logits
   
   def forward(