From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 19 十二月 2023 21:58:14 +0800
Subject: [PATCH] funasr2

---
 funasr/models/sa_asr/transformer_decoder.py |  442 +-----------------------------------------------------
 1 files changed, 11 insertions(+), 431 deletions(-)

diff --git a/funasr/models/transformer/transformer_decoder.py b/funasr/models/sa_asr/transformer_decoder.py
similarity index 65%
rename from funasr/models/transformer/transformer_decoder.py
rename to funasr/models/sa_asr/transformer_decoder.py
index b2bea68..3319212 100644
--- a/funasr/models/transformer/transformer_decoder.py
+++ b/funasr/models/sa_asr/transformer_decoder.py
@@ -10,9 +10,9 @@
 import torch
 from torch import nn
 
-from funasr.models.decoder.abs_decoder import AbsDecoder
+
 from funasr.models.transformer.attention import MultiHeadedAttention
-from funasr.models.transformer.attention import CosineDistanceAttention
+from funasr.models.sa_asr.attention import CosineDistanceAttention
 from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
 from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
 from funasr.models.transformer.embedding import PositionalEncoding
@@ -24,9 +24,10 @@
 from funasr.models.transformer.positionwise_feed_forward import (
     PositionwiseFeedForward,  # noqa: H301
 )
-from funasr.models.transformer.repeat import repeat
+from funasr.models.transformer.utils.repeat import repeat
 from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
 
+from funasr.utils.register import register_class, registry_tables
 
 class DecoderLayer(nn.Module):
     """Single decoder layer module.
@@ -150,7 +151,7 @@
         return x, tgt_mask, memory, memory_mask
 
 
-class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
+class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
     """Base class of Transfomer decoder module.
 
     Args:
@@ -352,7 +353,7 @@
         state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
         return logp, state_list
 
-
+@register_class("decoder_classes", "TransformerDecoder")
 class TransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -401,6 +402,7 @@
         )
 
 
+@register_class("decoder_classes", "ParaformerDecoderSAN")
 class ParaformerDecoderSAN(BaseTransformerDecoder):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -514,7 +516,7 @@
         else:
             return x, olens
 
-
+@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
 class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -575,7 +577,7 @@
             ),
         )
 
-
+@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
 class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -637,6 +639,7 @@
         )
 
 
+@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
 class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -697,7 +700,7 @@
             ),
         )
 
-
+@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
 class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -757,426 +760,3 @@
                 concat_after,
             ),
         )
-
-class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
-    
-    def __init__(
-        self,
-        vocab_size: int,
-        encoder_output_size: int,
-        spker_embedding_dim: int = 256,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        input_layer: str = "embed",
-        use_asr_output_layer: bool = True,
-        use_spk_output_layer: bool = True,
-        pos_enc_class=PositionalEncoding,
-        normalize_before: bool = True,
-    ):
-        super().__init__()
-        attention_dim = encoder_output_size
-
-        if input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(vocab_size, attention_dim),
-                pos_enc_class(attention_dim, positional_dropout_rate),
-            )
-        elif input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(vocab_size, attention_dim),
-                torch.nn.LayerNorm(attention_dim),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(attention_dim, positional_dropout_rate),
-            )
-        else:
-            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
-        self.normalize_before = normalize_before
-        if self.normalize_before:
-            self.after_norm = LayerNorm(attention_dim)
-        if use_asr_output_layer:
-            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
-        else:
-            self.asr_output_layer = None
-
-        if use_spk_output_layer:
-            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
-        else:
-            self.spk_output_layer = None
-
-        self.cos_distance_att = CosineDistanceAttention()
-
-        self.decoder1 = None
-        self.decoder2 = None
-        self.decoder3 = None
-        self.decoder4 = None
-
-    def forward(
-        self,
-        asr_hs_pad: torch.Tensor,
-        spk_hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-        profile: torch.Tensor,
-        profile_lens: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        
-        tgt = ys_in_pad
-        # tgt_mask: (B, 1, L)
-        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
-        # m: (1, L, L)
-        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
-        # tgt_mask: (B, L, L)
-        tgt_mask = tgt_mask & m
-
-        asr_memory = asr_hs_pad
-        spk_memory = spk_hs_pad
-        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
-        # Spk decoder
-        x = self.embed(tgt)
-
-        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
-            x, tgt_mask, asr_memory, spk_memory, memory_mask
-        )
-        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
-            x, tgt_mask, spk_memory, memory_mask
-        )
-        if self.normalize_before:
-            x = self.after_norm(x)
-        if self.spk_output_layer is not None:
-            x = self.spk_output_layer(x)
-        dn, weights = self.cos_distance_att(x, profile, profile_lens)
-        # Asr decoder
-        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
-            z, tgt_mask, asr_memory, memory_mask, dn
-        )
-        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
-            x, tgt_mask, asr_memory, memory_mask
-        )
-
-        if self.normalize_before:
-            x = self.after_norm(x)
-        if self.asr_output_layer is not None:
-            x = self.asr_output_layer(x)
-
-        olens = tgt_mask.sum(1)
-        return x, weights, olens
-
-
-    def forward_one_step(
-        self,
-        tgt: torch.Tensor,
-        tgt_mask: torch.Tensor,
-        asr_memory: torch.Tensor,
-        spk_memory: torch.Tensor,
-        profile: torch.Tensor,
-        cache: List[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-        
-        x = self.embed(tgt)
-
-        if cache is None:
-            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
-        new_cache = []
-        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
-                x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
-        )
-        new_cache.append(x)
-        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
-            x, tgt_mask, spk_memory, _ = decoder(
-                x, tgt_mask, spk_memory, None, cache=c
-            )
-            new_cache.append(x)
-        if self.normalize_before:
-            x = self.after_norm(x)
-        else:
-            x = x
-        if self.spk_output_layer is not None:
-            x = self.spk_output_layer(x)
-        dn, weights = self.cos_distance_att(x, profile, None)
-
-        x, tgt_mask, asr_memory, _ = self.decoder3(
-            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
-        )
-        new_cache.append(x)
-
-        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
-            x, tgt_mask, asr_memory, _ = decoder(
-                x, tgt_mask, asr_memory, None, cache=c
-            )
-            new_cache.append(x)
-
-        if self.normalize_before:
-            y = self.after_norm(x[:, -1])
-        else:
-            y = x[:, -1]
-        if self.asr_output_layer is not None:
-            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
-
-        return y, weights, new_cache
-
-    def score(self, ys, state, asr_enc, spk_enc, profile):
-        """Score."""
-        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
-        logp, weights, state = self.forward_one_step(
-            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
-        )
-        return logp.squeeze(0), weights.squeeze(), state
-
-class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
-    def __init__(
-        self,
-        vocab_size: int,
-        encoder_output_size: int,
-        spker_embedding_dim: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        asr_num_blocks: int = 6,
-        spk_num_blocks: int = 3,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        self_attention_dropout_rate: float = 0.0,
-        src_attention_dropout_rate: float = 0.0,
-        input_layer: str = "embed",
-        use_asr_output_layer: bool = True,
-        use_spk_output_layer: bool = True,
-        pos_enc_class=PositionalEncoding,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-    ):
-        super().__init__(
-            vocab_size=vocab_size,
-            encoder_output_size=encoder_output_size,
-            spker_embedding_dim=spker_embedding_dim,
-            dropout_rate=dropout_rate,
-            positional_dropout_rate=positional_dropout_rate,
-            input_layer=input_layer,
-            use_asr_output_layer=use_asr_output_layer,
-            use_spk_output_layer=use_spk_output_layer,
-            pos_enc_class=pos_enc_class,
-            normalize_before=normalize_before,
-        )
-
-        attention_dim = encoder_output_size
-
-        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
-            attention_dim,
-            MultiHeadedAttention(
-                attention_heads, attention_dim, self_attention_dropout_rate
-            ),
-            MultiHeadedAttention(
-                attention_heads, attention_dim, src_attention_dropout_rate
-            ),
-            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-            dropout_rate,
-            normalize_before,
-            concat_after,
-        )
-        self.decoder2 = repeat(
-            spk_num_blocks - 1,
-            lambda lnum: DecoderLayer(
-                attention_dim,
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, self_attention_dropout_rate
-                ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
-                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        
-        
-        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
-            attention_dim,
-            spker_embedding_dim,
-            MultiHeadedAttention(
-                attention_heads, attention_dim, src_attention_dropout_rate
-            ),
-            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-            dropout_rate,
-            normalize_before,
-            concat_after,
-        )
-        self.decoder4 = repeat(
-            asr_num_blocks - 1,
-            lambda lnum: DecoderLayer(
-                attention_dim,
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, self_attention_dropout_rate
-                ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
-                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-
-class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
-
-    def __init__(
-        self,
-        size,
-        self_attn,
-        src_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-    ):
-        """Construct an DecoderLayer object."""
-        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
-        self.size = size
-        self.self_attn = self_attn
-        self.src_attn = src_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(size)
-        self.norm2 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear1 = nn.Linear(size + size, size)
-            self.concat_linear2 = nn.Linear(size + size, size)
-
-    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
-        
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-
-        if cache is None:
-            tgt_q = tgt
-            tgt_q_mask = tgt_mask
-        else:
-            # compute only the last frame query keeping dim: max_time_out -> 1
-            assert cache.shape == (
-                tgt.shape[0],
-                tgt.shape[1] - 1,
-                self.size,
-            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
-            tgt_q = tgt[:, -1:, :]
-            residual = residual[:, -1:, :]
-            tgt_q_mask = None
-            if tgt_mask is not None:
-                tgt_q_mask = tgt_mask[:, -1:, :]
-
-        if self.concat_after:
-            tgt_concat = torch.cat(
-                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
-            )
-            x = residual + self.concat_linear1(tgt_concat)
-        else:
-            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
-        if not self.normalize_before:
-            x = self.norm1(x)
-        z = x
-        
-        residual = x
-        if self.normalize_before:
-            x = self.norm1(x)
-
-        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
-
-        if self.concat_after:
-            x_concat = torch.cat(
-                (x, skip), dim=-1
-            )
-            x = residual + self.concat_linear2(x_concat)
-        else:
-            x = residual + self.dropout(skip)
-        if not self.normalize_before:
-            x = self.norm1(x)
-        
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        x = residual + self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-        if cache is not None:
-            x = torch.cat([cache, x], dim=1)
-            
-        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
-
-class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
-    
-    def __init__(
-        self,
-        size,
-        d_size,
-        src_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-    ):
-        """Construct an DecoderLayer object."""
-        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
-        self.size = size
-        self.src_attn = src_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(size)
-        self.norm2 = LayerNorm(size)
-        self.norm3 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        self.spk_linear = nn.Linear(d_size, size, bias=False)
-        if self.concat_after:
-            self.concat_linear1 = nn.Linear(size + size, size)
-            self.concat_linear2 = nn.Linear(size + size, size)
-
-    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
-        
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-
-        if cache is None:
-            tgt_q = tgt
-            tgt_q_mask = tgt_mask
-        else:
-            
-            tgt_q = tgt[:, -1:, :]
-            residual = residual[:, -1:, :]
-            tgt_q_mask = None
-            if tgt_mask is not None:
-                tgt_q_mask = tgt_mask[:, -1:, :]
-
-        x = tgt_q
-        if self.normalize_before:
-            x = self.norm2(x)
-        if self.concat_after:
-            x_concat = torch.cat(
-                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
-            )
-            x = residual + self.concat_linear2(x_concat)
-        else:
-            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
-        if not self.normalize_before:
-            x = self.norm2(x)
-        residual = x
-
-        if dn!=None:
-            x = x + self.spk_linear(dn)
-        if self.normalize_before:
-            x = self.norm3(x)
-        
-        x = residual + self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm3(x)
-
-        if cache is not None:
-            x = torch.cat([cache, x], dim=1)
-
-        return x, tgt_mask, memory, memory_mask
\ No newline at end of file

--
Gitblit v1.9.1