zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/paraformer/decoder.py
@@ -17,7 +17,10 @@
from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.models.sanm.attention import (
    MultiHeadedAttentionSANMDecoder,
    MultiHeadedAttentionCrossAtt,
)
class DecoderLayerSANM(torch.nn.Module):
@@ -173,10 +176,11 @@
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
    def forward_chunk(
        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
    ):
        """Compute decoded features.
        Args:
@@ -224,6 +228,7 @@
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """
    def __init__(
        self,
        vocab_size: int,
@@ -303,7 +308,13 @@
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
                    attention_heads,
                    attention_dim,
                    src_attention_dropout_rate,
                    lora_list,
                    lora_rank,
                    lora_alpha,
                    lora_dropout,
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
@@ -383,16 +394,10 @@
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            hidden = self.after_norm(x)
@@ -407,10 +412,10 @@
    def score(self, ys, state, x):
        """Score."""
        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        ys_mask = myutils.sequence_mask(
            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
        )[:, :, None]
        logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
        return logp.squeeze(0), state
    
    def forward_asf2(
@@ -494,8 +499,12 @@
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
                x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
                chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
                x,
                memory,
                fsmn_cache=fsmn_cache[i],
                opt_cache=opt_cache[i],
                chunk_size=cache["chunk_size"],
                look_back=cache["decoder_chunk_look_back"],
            )
        if self.num_blocks - self.att_layer_num > 1:
@@ -507,9 +516,7 @@
                )
        for decoder in self.decoders3:
            x, memory, _, _ = decoder.forward_chunk(
                x, memory
            )
            x, memory, _, _ = decoder.forward_chunk(x, memory)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
@@ -581,21 +588,18 @@
        return y, new_cache
class DecoderLayerSANMExport(torch.nn.Module):
    def __init__(
        self,
        model
    ):
    def __init__(self, model):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
        self.norm2 = model.norm2 if hasattr(model, "norm2") else None
        self.norm3 = model.norm3 if hasattr(model, "norm3") else None
        self.size = model.size
    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -613,7 +617,6 @@
            residual = x
            x = self.norm3(x)
            x = residual + self.src_attn(x, memory, memory_mask)
        return x, tgt_mask, memory, memory_mask, cache
    
@@ -636,12 +639,7 @@
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
class ParaformerSANMDecoderExport(torch.nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,
                 **kwargs
                 ):
    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
@@ -653,7 +651,6 @@
        
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
        
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
@@ -706,16 +703,12 @@
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        
        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(x, tgt_mask, memory, memory_mask)
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
        hidden = self.after_norm(x)
        # x = self.output_layer(x)
        
@@ -742,7 +735,9 @@
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
            tgt, tgt_mask, memory, memory_mask
        )
        attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    
@@ -761,15 +756,25 @@
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](
            tgt, tgt_mask, memory, memory_mask
        )
        attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    
    '''
    """
    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
@@ -818,14 +823,12 @@
            for d in range(cache_num)
        })
        return ret
    '''
    """
    
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True, **kwargs):
    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
@@ -903,9 +906,7 @@
                    x, tgt_mask, memory, memory_mask, cache=in_cache
                )
                out_caches.append(out_cache)
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
        x = self.after_norm(x)
        x = self.output_layer(x)
        
@@ -917,62 +918,63 @@
        acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
        acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        cache = [
            torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
                        dtype=torch.float32)
            torch.zeros(
                (2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
                dtype=torch.float32,
            )
            for _ in range(cache_num)
        ]
        return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
    
    def get_input_names(self):
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
               + ['in_cache_%d' % i for i in range(cache_num)]
        return ["enc", "enc_len", "acoustic_embeds", "acoustic_embeds_len"] + [
            "in_cache_%d" % i for i in range(cache_num)
        ]
    
    def get_output_names(self):
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        return ['logits', 'sample_ids'] \
               + ['out_cache_%d' % i for i in range(cache_num)]
        return ["logits", "sample_ids"] + ["out_cache_%d" % i for i in range(cache_num)]
    
    def get_dynamic_axes(self):
        ret = {
            'enc': {
                0: 'batch_size',
                1: 'enc_length'
            "enc": {0: "batch_size", 1: "enc_length"},
            "acoustic_embeds": {0: "batch_size", 1: "token_length"},
            "enc_len": {
                0: "batch_size",
            },
            'acoustic_embeds': {
                0: 'batch_size',
                1: 'token_length'
            "acoustic_embeds_len": {
                0: "batch_size",
            },
            'enc_len': {
                0: 'batch_size',
            },
            'acoustic_embeds_len': {
                0: 'batch_size',
            },
        }
        cache_num = len(self.model.decoders)
        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        ret.update({
            'in_cache_%d' % d: {
                0: 'batch_size',
        ret.update(
            {
                "in_cache_%d"
                % d: {
                    0: "batch_size",
            }
            for d in range(cache_num)
        })
        ret.update({
            'out_cache_%d' % d: {
                0: 'batch_size',
            }
        )
        ret.update(
            {
                "out_cache_%d"
                % d: {
                    0: "batch_size",
            }
            for d in range(cache_num)
        })
            }
        )
        return ret
@@ -983,6 +985,7 @@
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """
    def __init__(
            self,
            vocab_size: int,
@@ -1017,12 +1020,8 @@
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
@@ -1060,23 +1059,17 @@
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
            memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
        # x = self.embed(tgt)
        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask
            )
            x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, memory_mask)
            if layer_id == self.embeds_id:
                embeds_outputs = x
        if self.normalize_before:
@@ -1090,23 +1083,25 @@
        else:
            return x, olens
@tables.register("decoder_classes", "ParaformerDecoderSANExport")
class ParaformerDecoderSANExport(torch.nn.Module):
    def __init__(self, model,
    def __init__(
        self,
        model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True, ):
        model_name="decoder",
        onnx: bool = True,
    ):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.transformer.decoder import DecoderLayerExport
        from funasr.models.transformer.attention import MultiHeadedAttentionExport
@@ -1149,9 +1144,7 @@
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        
        x = tgt
        x, tgt_mask, memory, memory_mask = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        x, tgt_mask, memory, memory_mask = self.model.decoders(x, tgt_mask, memory, memory_mask)
        x = self.after_norm(x)
        x = self.output_layer(x)
        
@@ -1163,7 +1156,9 @@
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            torch.zeros(
                (1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size)
            )
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)
@@ -1173,36 +1168,26 @@
    
    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]
        return ["tgt", "memory", "pre_acoustic_embeds"] + ["cache_%d" % i for i in range(cache_num)]
    
    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]
        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
    
    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            "tgt": {0: "tgt_batch", 1: "tgt_length"},
            "memory": {0: "memory_batch", 1: "memory_length"},
            "pre_acoustic_embeds": {
                0: "acoustic_embeds_batch",
                1: "acoustic_embeds_length",
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
        ret.update(
            {
                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
            for d in range(cache_num)
        })
            }
        )
        return ret