From f77c5803f4d61099e572be8d877b1c4a4d6087cd Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: Wed, 10 May 2023 12:02:06 +0800
Subject: [PATCH] Merge pull request #485 from alibaba-damo-academy/main

---
 funasr/models/decoder/transformer_decoder.py |  428 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 427 insertions(+), 1 deletion(-)

diff --git a/funasr/models/decoder/transformer_decoder.py b/funasr/models/decoder/transformer_decoder.py
index aed7f20..45fdda8 100644
--- a/funasr/models/decoder/transformer_decoder.py
+++ b/funasr/models/decoder/transformer_decoder.py
@@ -13,6 +13,7 @@
 
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.modules.attention import MultiHeadedAttention
+from funasr.modules.attention import CosineDistanceAttention
 from funasr.modules.dynamic_conv import DynamicConvolution
 from funasr.modules.dynamic_conv2d import DynamicConvolution2D
 from funasr.modules.embedding import PositionalEncoding
@@ -763,4 +764,429 @@
                 normalize_before,
                 concat_after,
             ),
-        )
\ No newline at end of file
+        )
+
+class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
+    """Base class for speaker-attributed ASR (SA-ASR) Transformer decoders.
+
+    The concrete decoder stacks (decoder1..decoder4) are built by subclasses
+    such as SAAsrTransformerDecoder below.
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        spker_embedding_dim: int = 256,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        input_layer: str = "embed",
+        use_asr_output_layer: bool = True,
+        use_spk_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        attention_dim = encoder_output_size
+
+        if input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        elif input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_asr_output_layer:
+            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.asr_output_layer = None
+
+        if use_spk_output_layer:
+            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
+        else:
+            self.spk_output_layer = None
+
+        self.cos_distance_att = CosineDistanceAttention()
+
+        self.decoder1 = None
+        self.decoder2 = None
+        self.decoder3 = None
+        self.decoder4 = None
+
+    def forward(
+        self,
+        asr_hs_pad: torch.Tensor,
+        spk_hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        profile: torch.Tensor,
+        profile_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            asr_hs_pad: encoded ASR memory, (batch, maxlen_in, size)
+            spk_hs_pad: encoded speaker memory, (batch, maxlen_in, size)
+            hlens: encoder output lengths, (batch,)
+            ys_in_pad: input token ids, (batch, maxlen_out)
+            ys_in_lens: input lengths, (batch,)
+            profile: speaker profile embeddings, (batch, n_spk, spker_embedding_dim)
+            profile_lens: number of valid profiles, (batch,)
+        Returns:
+            token scores x, speaker attention weights, and output lengths olens
+        """
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        tgt_mask = tgt_mask & m
+
+        asr_memory = asr_hs_pad
+        spk_memory = spk_hs_pad
+        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
+        # Spk decoder
+        x = self.embed(tgt)
+
+        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
+            x, tgt_mask, asr_memory, spk_memory, memory_mask
+        )
+        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
+            x, tgt_mask, spk_memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.spk_output_layer is not None:
+            x = self.spk_output_layer(x)
+        dn, weights = self.cos_distance_att(x, profile, profile_lens)
+        # Asr decoder
+        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
+            z, tgt_mask, asr_memory, memory_mask, dn
+        )
+        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
+            x, tgt_mask, asr_memory, memory_mask
+        )
+
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.asr_output_layer is not None:
+            x = self.asr_output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, weights, olens
+
+    def forward_one_step(
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        asr_memory: torch.Tensor,
+        spk_memory: torch.Tensor,
+        profile: torch.Tensor,
+        cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Returns log-probabilities for the last position, the speaker attention
+        weights, and the updated per-layer cache.
+        """
+        x = self.embed(tgt)
+
+        if cache is None:
+            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
+        new_cache = []
+        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
+            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
+        )
+        new_cache.append(x)
+        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
+            x, tgt_mask, spk_memory, _ = decoder(
+                x, tgt_mask, spk_memory, None, cache=c
+            )
+            new_cache.append(x)
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.spk_output_layer is not None:
+            x = self.spk_output_layer(x)
+        dn, weights = self.cos_distance_att(x, profile, None)
+
+        x, tgt_mask, asr_memory, _ = self.decoder3(
+            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
+        )
+        new_cache.append(x)
+
+        for c, decoder in zip(cache[len(self.decoder2) + 2:], self.decoder4):
+            x, tgt_mask, asr_memory, _ = decoder(
+                x, tgt_mask, asr_memory, None, cache=c
+            )
+            new_cache.append(x)
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.asr_output_layer is not None:
+            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
+
+        return y, weights, new_cache
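+
+    # Cache layout (matching the indexing above): new_cache[0] is the output
+    # of decoder1, the next len(self.decoder2) entries are the speaker-branch
+    # layers, then decoder3's output, then the decoder4 (ASR-branch) layers.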
+
+    def score(self, ys, state, asr_enc, spk_enc, profile):
+        """Score."""
+        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
+        logp, weights, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
+        )
+        return logp.squeeze(0), weights.squeeze(), state
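+
+    # A minimal usage sketch (hypothetical sizes, not part of this patch):
+    #
+    #   decoder = SAAsrTransformerDecoder(vocab_size=5000, encoder_output_size=256)
+    #   logp, weights, state = decoder.score(ys, None, asr_enc, spk_enc, profile)
+    #
+    # where ys is a 1-D LongTensor of previously decoded token ids, asr_enc and
+    # spk_enc are (T, D) encoder outputs for one utterance, and profile is an
+    # (n_spk, spker_embedding_dim) matrix of speaker profile embeddings.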
+
+class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        spker_embedding_dim: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        asr_num_blocks: int = 6,
+        spk_num_blocks: int = 3,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_asr_output_layer: bool = True,
+        use_spk_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+    ):
+        assert check_argument_types()
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            spker_embedding_dim=spker_embedding_dim,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_asr_output_layer=use_asr_output_layer,
+            use_spk_output_layer=use_spk_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+
+        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
+            attention_dim,
+            MultiHeadedAttention(
+                attention_heads, attention_dim, self_attention_dropout_rate
+            ),
+            MultiHeadedAttention(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+            dropout_rate,
+            normalize_before,
+            concat_after,
+        )
+        self.decoder2 = repeat(
+            spk_num_blocks - 1,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
+            attention_dim,
+            spker_embedding_dim,
+            MultiHeadedAttention(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+            dropout_rate,
+            normalize_before,
+            concat_after,
+        )
+        self.decoder4 = repeat(
+            asr_num_blocks - 1,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
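+
+        # Layer layout: decoder1 is a single speaker-aware first layer;
+        # decoder2 stacks (spk_num_blocks - 1) standard DecoderLayers for the
+        # speaker branch; decoder3 injects the weighted speaker profile dn;
+        # decoder4 stacks (asr_num_blocks - 1) standard DecoderLayers for ASR.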
+
+class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
+    """First decoder layer of the speaker branch.
+
+    Self-attention over the target sequence, followed by a cross-attention
+    whose keys come from the ASR encoder memory and whose values come from
+    the speaker encoder memory. Also returns the post-self-attention state z
+    for the ASR branch.
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an DecoderLayer object."""
+        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
+        """Compute decoded features and the intermediate state z."""
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            # compute only the last frame query keeping dim: max_time_out -> 1
+            assert cache.shape == (
+                tgt.shape[0],
+                tgt.shape[1] - 1,
+                self.size,
+            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = None
+            if tgt_mask is not None:
+                tgt_q_mask = tgt_mask[:, -1:, :]
+
+        if self.concat_after:
+            tgt_concat = torch.cat(
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
+            )
+            x = residual + self.concat_linear1(tgt_concat)
+        else:
+            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+        if not self.normalize_before:
+            x = self.norm1(x)
+        # Keep the post-self-attention state for the ASR branch (returned as z).
+        z = x
+
+        residual = x
+        if self.normalize_before:
+            # Note: norm1 is shared by the self-attention and source-attention blocks.
+            x = self.norm1(x)
+
+        # Cross-attention: queries from the decoder state, keys from the ASR
+        # memory, values from the speaker memory.
+        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, skip), dim=-1)
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x = residual + self.dropout(skip)
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
+
+class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
+    """First decoder layer of the ASR branch.
+
+    Cross-attention over the ASR encoder memory, with the projected speaker
+    profile vector dn added before the feed-forward block.
+    """
+
+    def __init__(
+        self,
+        size,
+        d_size,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an DecoderLayer object."""
+        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
+        self.size = size
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.norm3 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        self.spk_linear = nn.Linear(d_size, size, bias=False)
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
+        """Compute decoded features with the speaker profile vector dn."""
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            # Compute only the last-frame query: max_time_out -> 1
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = None
+            if tgt_mask is not None:
+                tgt_q_mask = tgt_mask[:, -1:, :]
+
+        x = tgt_q
+        if self.normalize_before:
+            x = self.norm2(x)
+        if self.concat_after:
+            x_concat = torch.cat(
+                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
+            )
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
+        if not self.normalize_before:
+            x = self.norm2(x)
+        residual = x
+
+        # Add the projected speaker profile vector dn when it is provided.
+        if dn is not None:
+            x = x + self.spk_linear(dn)
+        if self.normalize_before:
+            x = self.norm3(x)
+
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm3(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        return x, tgt_mask, memory, memory_mask
\ No newline at end of file

--
Gitblit v1.9.1