zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/sanm/decoder.py
@@ -13,13 +13,17 @@
from funasr.models.scama import utils as myutils
from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.sanm.attention import (
    MultiHeadedAttentionSANMDecoder,
    MultiHeadedAttentionCrossAtt,
)
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables
class DecoderLayerSANM(nn.Module):
    """Single decoder layer module.
@@ -151,10 +155,11 @@
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
    def forward_chunk(
        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
    ):
        """Compute decoded features.
        Args:
@@ -276,7 +281,10 @@
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads,
                    attention_dim,
                    src_attention_dropout_rate,
                    encoder_output_size=encoder_output_size,
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
@@ -292,7 +300,10 @@
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim,
                        self_attention_dropout_rate,
                        kernel_size,
                        sanm_shfit=sanm_shfit,
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -321,8 +332,12 @@
                    attention_dim + encoder_output_size,
                    None,
                    None,
                    PositionwiseFeedForwardDecoderSANM(
                        attention_dim + encoder_output_size,
                        linear_units,
                        dropout_rate,
                        adim=attention_dim,
                    ),
                    dropout_rate,
                    normalize_before,
                    concat_after,
@@ -377,16 +392,10 @@
            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
        
        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
@@ -395,12 +404,25 @@
        olens = tgt_mask.sum(1)
        return x, olens
    
    def score(
        self,
        ys,
        state,
        x,
        x_mask=None,
        pre_acoustic_embeds: torch.Tensor = None,
    ):
        """Score."""
        ys_mask = myutils.sequence_mask(
            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
        )[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0),
            ys_mask,
            x.unsqueeze(0),
            memory_mask=x_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
            cache=state,
        )
        return logp.squeeze(0), state
    
@@ -474,5 +496,3 @@
            y = torch.log_softmax(y, dim=-1)
        
        return y, new_cache