雾聪
2024-03-29 9ba0dbd98bf69c830dfcfde8f109a400cb65e4e5
funasr/models/transformer/decoder.py
@@ -26,7 +26,7 @@
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
from funasr.utils.register import register_class, registry_tables
from funasr.register import tables
class DecoderLayer(nn.Module):
    """Single decoder layer module.
@@ -147,6 +147,35 @@
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, memory, memory_mask
class DecoderLayerExport(nn.Module):
    """Export-friendly wrapper around a trained decoder layer.

    Shares (does not copy) the sub-modules of ``model`` — self-attention,
    source attention, feed-forward and the three layer norms — and
    re-implements the forward pass without dropout and without the
    ``normalize_before`` / ``concat_after`` switches, so the exported graph
    contains no training-only or config-dependent branches.

    NOTE(review): every norm is applied *before* its sub-layer and a plain
    residual is added afterwards, so this reproduces the wrapped layer only if
    it was built with ``normalize_before=True`` and ``concat_after=False`` and
    is run in eval mode — confirm for every model routed through this wrapper.

    Args:
        model: Decoder layer providing ``self_attn``, ``src_attn``,
            ``feed_forward``, ``norm1``, ``norm2`` and ``norm3``.
    """

    def __init__(self, model):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.norm3 = model.norm3

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Run the decoder layer once.

        Args:
            tgt: Decoder input. It is sliced as ``tgt[:, -1:, :]`` when
                ``cache`` is given, so presumably (batch, length, dim) —
                confirm against callers.
            tgt_mask: Self-attention mask for ``tgt``, or ``None``.
            memory: Encoder output, used as keys/values of ``src_attn``.
            memory_mask: Mask for ``memory``.
            cache: Optional output of this layer for every target position
                except the last. When given, only the last position of
                ``tgt`` is used as the query, and the new frame is appended
                after ``cache`` along dim 1. This mirrors the
                ``torch.cat([cache, x], dim=1)`` step of
                ``DecoderLayer.forward``. Previously this argument was
                silently ignored.

        Returns:
            Tuple ``(x, tgt_mask, memory, memory_mask)``. The masks and
            memory are passed through unchanged so layers can be chained.
        """
        residual = tgt
        tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # Incremental decoding: earlier positions are already in `cache`,
            # so only the newest frame is computed as the query. Keys/values
            # still span the whole (normalized) `tgt`.
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            # assumes a (batch, length, length) mask — TODO confirm
            tgt_q_mask = None if tgt_mask is None else tgt_mask[:, -1:, :]

        x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)

        residual = x
        x = self.norm2(x)
        x = residual + self.src_attn(x, memory, memory, memory_mask)

        residual = x
        x = self.norm3(x)
        x = residual + self.feed_forward(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
@@ -352,7 +381,7 @@
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
@register_class("decoder_classes", "TransformerDecoder")
@tables.register("decoder_classes", "TransformerDecoder")
class TransformerDecoder(BaseTransformerDecoder):
    def __init__(
            self,
@@ -401,7 +430,7 @@
        )
@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
@tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder")
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
    def __init__(
            self,
@@ -462,7 +491,7 @@
            ),
        )
@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
@tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder")
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
    def __init__(
            self,
@@ -524,7 +553,7 @@
        )
@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
@tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder")
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
    def __init__(
            self,
@@ -585,7 +614,7 @@
            ),
        )
@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
@tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder")
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
    def __init__(
            self,