From bc723ea200144bd6fa8a5dff4b9a780feda144fc Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 29 六月 2023 18:55:01 +0800
Subject: [PATCH] dcos
---
funasr/models/decoder/sanm_decoder.py | 9 +++------
1 files changed, 3 insertions(+), 6 deletions(-)
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index 508eb73..d83f89f 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -7,7 +7,6 @@
from funasr.modules.streaming_utils import utils as myutils
from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
-from typeguard import check_argument_types
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.modules.embedding import PositionalEncoding
@@ -181,7 +180,6 @@
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
embed_tensor_name_prefix_tf: str = None,
):
- assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@@ -838,7 +836,6 @@
tf2torch_tensor_name_prefix_torch: str = "decoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
):
- assert check_argument_types()
super().__init__(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
@@ -956,13 +953,13 @@
"""
tgt = ys_in_pad
tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
if chunk_mask is not None:
memory_mask = memory_mask * chunk_mask
if tgt_mask.size(1) != memory_mask.size(1):
memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
-
- memory = hs_pad
- memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
x = tgt
x, tgt_mask, memory, memory_mask, _ = self.decoders(
--
Gitblit v1.9.1