funasr/models/decoder/sanm_decoder.py
@@ -935,6 +935,7 @@ hlens: torch.Tensor, ys_in_pad: torch.Tensor, ys_in_lens: torch.Tensor, chunk_mask: torch.Tensor = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward decoder. @@ -958,6 +959,10 @@ memory = hs_pad memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :] if chunk_mask is not None: memory_mask = memory_mask * chunk_mask if tgt_mask.size(1) != memory_mask.size(1): memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1) x = tgt x, tgt_mask, memory, memory_mask, _ = self.decoders(