zhifu gao
2024-12-25 3f8294b9d7deaa0cbdb0b2ef6f3802d46ae133a9
funasr/models/sanm/decoder.py
@@ -13,13 +13,17 @@
from funasr.models.scama import utils as myutils
from funasr.models.transformer.decoder import BaseTransformerDecoder
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import (
+    MultiHeadedAttentionSANMDecoder,
+    MultiHeadedAttentionCrossAtt,
+)
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables
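For orientation: in ESPnet-derived codebases such as FunASR, the repeat helper imported above builds a stack of identically configured layers by calling a constructor once per layer index and wrapping the results in a sequential container that threads multiple tensors through each layer. A minimal sketch of that idea (the real utility lives in funasr/models/transformer/utils/repeat.py; names here are illustrative):

import torch

class MultiSequential(torch.nn.Sequential):
    # Sequential container whose layers take and return multiple tensors,
    # matching DecoderLayerSANM's (x, tgt_mask, memory, memory_mask, cache) tuple.
    def forward(self, *args):
        for module in self:
            args = module(*args)
        return args

def repeat(num_layers, layer_fn):
    # layer_fn receives the layer index, so layers could differ if needed.
    return MultiSequential(*[layer_fn(i) for i in range(num_layers)])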
class DecoderLayerSANM(nn.Module):
    """Single decoder layer module.
@@ -151,10 +155,11 @@
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
-    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+    def forward_chunk(
+        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
+    ):
        """Compute decoded features.
        Args:
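The fsmn_cache and look_back arguments point at chunk-wise (streaming) decoding: the FSMN memory block is a convolution over time, so each chunk must carry enough left context over from the previous chunk. A rough sketch of such a rolling cache, purely as an assumption about the mechanics rather than code from this file (assumes kernel_size > 1):

import torch

def roll_fsmn_cache(cache, new_frames, kernel_size):
    # Keep the last kernel_size - 1 frames so the depthwise convolution
    # sees full left context at the next chunk boundary.
    buf = new_frames if cache is None else torch.cat([cache, new_frames], dim=1)
    return buf[:, -(kernel_size - 1):, :]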
@@ -202,7 +207,7 @@
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """
    def __init__(
        self,
        vocab_size: int,
@@ -240,7 +245,7 @@
        )
        if attention_dim is None:
            attention_dim = encoder_output_size
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
@@ -255,7 +260,7 @@
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
@@ -263,7 +268,7 @@
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None
        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
@@ -276,7 +281,10 @@
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
+                    attention_heads,
+                    attention_dim,
+                    src_attention_dropout_rate,
+                    encoder_output_size=encoder_output_size,
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
@@ -292,7 +300,10 @@
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
-                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                        attention_dim,
+                        self_attention_dropout_rate,
+                        kernel_size,
+                        sanm_shfit=sanm_shfit,
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -301,7 +312,7 @@
                    concat_after,
                ),
            )
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
@@ -321,8 +332,12 @@
                    attention_dim + encoder_output_size,
                    None,
                    None,
-                    PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
-                                                       adim=attention_dim),
+                    PositionwiseFeedForwardDecoderSANM(
+                        attention_dim + encoder_output_size,
+                        linear_units,
+                        dropout_rate,
+                        adim=attention_dim,
+                    ),
                    dropout_rate,
                    normalize_before,
                    concat_after,
@@ -334,7 +349,7 @@
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
    def forward(
        self,
        hs_pad: torch.Tensor,
@@ -363,47 +378,54 @@
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            if tgt_mask.size(1) != memory_mask.size(1):
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
        x = self.embed(tgt)
        if pre_acoustic_embeds is not None and self.concat_embeds:
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
-        x, tgt_mask, memory, memory_mask, _ = self.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
        if self.decoders2 is not None:
-            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
-                x, tgt_mask, memory, memory_mask
-            )
-        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
+            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
+        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        return x, olens
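For reference, a hypothetical call to this forward pass; the shapes and the decoder construction are assumptions for illustration (other constructor arguments are assumed to have workable defaults), not values from this commit:

import torch

decoder = FsmnDecoder(vocab_size=4000, encoder_output_size=512)
hs_pad = torch.randn(2, 50, 512)             # encoder output: (batch, frames, dim)
hlens = torch.tensor([50, 38])               # valid frames per utterance
ys_in_pad = torch.randint(0, 4000, (2, 12))  # padded target token ids
ys_in_lens = torch.tensor([12, 9])           # valid target lengths
logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
# logits: (2, 12, 4000) token scores; olens: valid lengths from tgt_mask.sum(1)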
-    def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
+    def score(
+        self,
+        ys,
+        state,
+        x,
+        x_mask=None,
+        pre_acoustic_embeds: torch.Tensor = None,
+    ):
        """Score."""
-        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+        ys_mask = myutils.sequence_mask(
+            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
+        )[:, :, None]
        logp, state = self.forward_one_step(
-            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
-            cache=state
+            ys.unsqueeze(0),
+            ys_mask,
+            x.unsqueeze(0),
+            memory_mask=x_mask,
+            pre_acoustic_embeds=pre_acoustic_embeds,
+            cache=state,
        )
        return logp.squeeze(0), state
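score is the single-hypothesis scorer interface used by beam search: given a partial token sequence ys and the encoder output x for one utterance, it returns next-token log-probabilities plus the updated per-layer cache. A hedged usage sketch (sos_id and enc_out are illustrative, not from this file):

ys = torch.tensor([sos_id, 42, 7])       # partial hypothesis, 1-D token ids
logp, state = decoder.score(ys, state=None, x=enc_out)  # enc_out: (frames, dim)
next_token = int(logp.argmax())          # greedy pick, for illustration only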
    def forward_one_step(
        self,
        tgt: torch.Tensor,
@@ -426,15 +448,15 @@
            y, cache: NN output value and cache per `self.decoders`.
            `y.shape` is (batch, maxlen_out, token)
        """
        x = tgt[:, -1:]
        tgt_mask = None
        x = self.embed(x)
        if pre_acoustic_embeds is not None and self.concat_embeds:
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
        if cache is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
@@ -449,7 +471,7 @@
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(c_ret)
        if self.num_blocks - self.att_layer_num >= 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
@@ -459,12 +481,12 @@
                    x, tgt_mask, memory, memory_mask, cache=c
                )
                new_cache.append(c_ret)
        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=None
            )
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
@@ -472,7 +494,5 @@
        if self.output_layer is not None:
            y = self.output_layer(y)
            y = torch.log_softmax(y, dim=-1)
        return y, new_cache
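Putting forward_one_step together: each call returns log-probabilities for the next token plus the refreshed per-layer cache, which supports an incremental decoding loop like the sketch below (sos_id, eos_id, max_len, and enc_out are assumptions for illustration):

ys = torch.tensor([[sos_id]])            # (batch=1, 1)
cache = None
enc_out_b = enc_out.unsqueeze(0)         # (1, frames, dim)
for _ in range(max_len):
    logp, cache = decoder.forward_one_step(ys, None, enc_out_b, cache=cache)
    next_tok = logp.argmax(-1, keepdim=True)  # logp: (1, vocab)
    ys = torch.cat([ys, next_tok], dim=-1)
    if int(next_tok) == eos_id:
        break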