| | |
| | | from funasr.models.transformer.utils.repeat import repeat |
| | | from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface |
| | | |
| | | from funasr.utils.register import register_class, registry_tables |
| | | from funasr.register import tables |
| | | |
| | | class DecoderLayer(nn.Module): |
| | | """Single decoder layer module. |
| | |
| | | state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] |
| | | return logp, state_list |
| | | |
| | | @register_class("decoder_classes", "TransformerDecoder") |
| | | @tables.register("decoder_classes", "TransformerDecoder") |
| | | class TransformerDecoder(BaseTransformerDecoder): |
| | | def __init__( |
| | | self, |
| | |
| | | ) |
| | | |
| | | |
| | | @register_class("decoder_classes", "ParaformerDecoderSAN") |
| | | @tables.register("decoder_classes", "ParaformerDecoderSAN") |
| | | class ParaformerDecoderSAN(BaseTransformerDecoder): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | |
| | | else: |
| | | return x, olens |
| | | |
| | | @register_class("decoder_classes", "LightweightConvolutionTransformerDecoder") |
| | | @tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder") |
| | | class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): |
| | | def __init__( |
| | | self, |
| | |
| | | ), |
| | | ) |
| | | |
| | | @register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder") |
| | | @tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder") |
| | | class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): |
| | | def __init__( |
| | | self, |
| | |
| | | ) |
| | | |
| | | |
| | | @register_class("decoder_classes", "DynamicConvolutionTransformerDecoder") |
| | | @tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder") |
| | | class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): |
| | | def __init__( |
| | | self, |
| | |
| | | ), |
| | | ) |
| | | |
| | | @register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder") |
| | | @tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder") |
| | | class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): |
| | | def __init__( |
| | | self, |