游雁
2023-12-21 a1b0cd33d50cee3e4612d1e787399e508b453a4a
funasr/models/sa_asr/transformer_decoder.py
@@ -27,7 +27,7 @@
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
from funasr.utils.register import register_class, registry_tables
from funasr.register import tables
class DecoderLayer(nn.Module):
    """Single decoder layer module.
@@ -353,7 +353,7 @@
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
@register_class("decoder_classes", "TransformerDecoder")
@tables.register("decoder_classes", "TransformerDecoder")
class TransformerDecoder(BaseTransformerDecoder):
    def __init__(
            self,
@@ -402,7 +402,7 @@
        )
@register_class("decoder_classes", "ParaformerDecoderSAN")
@tables.register("decoder_classes", "ParaformerDecoderSAN")
class ParaformerDecoderSAN(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -516,7 +516,7 @@
        else:
            return x, olens
@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
@tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder")
class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
    def __init__(
            self,
@@ -577,7 +577,7 @@
            ),
        )
@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
@tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder")
class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
    def __init__(
            self,
@@ -639,7 +639,7 @@
        )
@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
@tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder")
class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
    def __init__(
            self,
@@ -700,7 +700,7 @@
            ),
        )
@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
@tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder")
class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
    def __init__(
            self,