From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 19 十二月 2023 21:58:14 +0800
Subject: [PATCH] funasr2

---
 funasr/models/scama/sanm_decoder.py |  552 +++++++++++++++++++++++++++++++-----------------------
 1 files changed, 314 insertions(+), 238 deletions(-)

diff --git a/funasr/models/paraformer/contextual_decoder.py b/funasr/models/scama/sanm_decoder.py
similarity index 70%
copy from funasr/models/paraformer/contextual_decoder.py
copy to funasr/models/scama/sanm_decoder.py
index 626cdef..53423d0 100644
--- a/funasr/models/paraformer/contextual_decoder.py
+++ b/funasr/models/scama/sanm_decoder.py
@@ -6,17 +6,38 @@
 import numpy as np
 
 from funasr.models.scama import utils as myutils
-from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
+from funasr.models.transformer.decoder import BaseTransformerDecoder
 
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
 from funasr.models.transformer.embedding import PositionalEncoding
 from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.transformer.repeat import repeat
-from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayerSANM(nn.Module):
+    """Single decoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
 
 
-class ContextualDecoderLayer(nn.Module):
+    """
+
     def __init__(
         self,
         size,
@@ -28,7 +49,7 @@
         concat_after=False,
     ):
         """Construct an DecoderLayer object."""
-        super(ContextualDecoderLayer, self).__init__()
+        super(DecoderLayerSANM, self).__init__()
         self.size = size
         self.self_attn = self_attn
         self.src_attn = src_attn
@@ -45,85 +66,161 @@
             self.concat_linear1 = nn.Linear(size + size, size)
             self.concat_linear2 = nn.Linear(size + size, size)
 
-    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
         # tgt = self.dropout(tgt)
-        if isinstance(tgt, Tuple):
-            tgt, _ = tgt
         residual = tgt
         if self.normalize_before:
             tgt = self.norm1(tgt)
         tgt = self.feed_forward(tgt)
 
         x = tgt
-        if self.normalize_before:
-            tgt = self.norm2(tgt)
-        if self.training:
-            cache = None
-        x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
-        x = residual + self.dropout(x)
-        x_self_attn = x
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, _ = self.self_attn(tgt, tgt_mask)
+            x = residual + self.dropout(x)
 
-        residual = x
-        if self.normalize_before:
-            x = self.norm3(x)
-        x = self.src_attn(x, memory, memory_mask)
-        x_src_attn = x
-
-        x = residual + self.dropout(x)
-        return x, tgt_mask, x_self_attn, x_src_attn
-
-
-class ContextualBiasDecoder(nn.Module):
-    def __init__(
-        self,
-        size,
-        src_attn,
-        dropout_rate,
-        normalize_before=True,
-    ):
-        """Construct an DecoderLayer object."""
-        super(ContextualBiasDecoder, self).__init__()
-        self.size = size
-        self.src_attn = src_attn
-        if src_attn is not None:
-            self.norm3 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.normalize_before = normalize_before
-
-    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
-        x = tgt
         if self.src_attn is not None:
+            residual = x
             if self.normalize_before:
                 x = self.norm3(x)
-            x =  self.dropout(self.src_attn(x, memory, memory_mask))
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
         return x, tgt_mask, memory, memory_mask, cache
 
+    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
 
-class ContextualParaformerDecoder(ParaformerSANMDecoder):
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            if self.training:
+                cache = None
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+            x = residual + x
+
+        return x, memory, fsmn_cache, opt_cache
+
+@register_class("decoder_classes", "FsmnDecoderSCAMAOpt")
+class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
-    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
+
     """
     def __init__(
-        self,
-        vocab_size: int,
-        encoder_output_size: int,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        self_attention_dropout_rate: float = 0.0,
-        src_attention_dropout_rate: float = 0.0,
-        input_layer: str = "embed",
-        use_output_layer: bool = True,
-        pos_enc_class=PositionalEncoding,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        att_layer_num: int = 6,
-        kernel_size: int = 21,
-        sanm_shfit: int = 0,
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            self_attention_dropout_rate: float = 0.0,
+            src_attention_dropout_rate: float = 0.0,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            att_layer_num: int = 6,
+            kernel_size: int = 21,
+            sanm_shfit: int = None,
+            concat_embeds: bool = False,
+            attention_dim: int = None,
+            tf2torch_tensor_name_prefix_torch: str = "decoder",
+            tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+            embed_tensor_name_prefix_tf: str = None,
     ):
         super().__init__(
             vocab_size=vocab_size,
@@ -135,14 +232,12 @@
             pos_enc_class=pos_enc_class,
             normalize_before=normalize_before,
         )
+        if attention_dim is None:
+            attention_dim = encoder_output_size
 
-        attention_dim = encoder_output_size
-        if input_layer == 'none':
-            self.embed = None
         if input_layer == "embed":
             self.embed = torch.nn.Sequential(
                 torch.nn.Embedding(vocab_size, attention_dim),
-                # pos_enc_class(attention_dim, positional_dropout_rate),
             )
         elif input_layer == "linear":
             self.embed = torch.nn.Sequential(
@@ -168,14 +263,14 @@
         if sanm_shfit is None:
             sanm_shfit = (kernel_size - 1) // 2
         self.decoders = repeat(
-            att_layer_num - 1,
+            att_layer_num,
             lambda lnum: DecoderLayerSANM(
                 attention_dim,
                 MultiHeadedAttentionSANMDecoder(
                     attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                 ),
                 MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate
+                    attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
                 ),
                 PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
@@ -183,29 +278,6 @@
                 concat_after,
             ),
         )
-        self.dropout = nn.Dropout(dropout_rate)
-        self.bias_decoder = ContextualBiasDecoder(
-            size=attention_dim,
-            src_attn=MultiHeadedAttentionCrossAtt(
-                attention_heads, attention_dim, src_attention_dropout_rate
-            ),
-            dropout_rate=dropout_rate,
-            normalize_before=True,
-        )
-        self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False)
-        self.last_decoder = ContextualDecoderLayer(
-                attention_dim,
-                MultiHeadedAttentionSANMDecoder(
-                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
-                ),
-                MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
-                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            )
         if num_blocks - att_layer_num <= 0:
             self.decoders2 = None
         else:
@@ -214,7 +286,7 @@
                 lambda lnum: DecoderLayerSANM(
                     attention_dim,
                     MultiHeadedAttentionSANMDecoder(
-                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                     ),
                     None,
                     PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -236,16 +308,35 @@
                 concat_after,
             ),
         )
+        if concat_embeds:
+            self.embed_concat_ffn = repeat(
+                1,
+                lambda lnum: DecoderLayerSANM(
+                    attention_dim + encoder_output_size,
+                    None,
+                    None,
+                    PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
+                                                      adim=attention_dim),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                ),
+            )
+        else:
+            self.embed_concat_ffn = None
+        self.concat_embeds = concat_embeds
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
 
     def forward(
-        self,
-        hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-        contextual_info: torch.Tensor,
-        clas_scale: float = 1.0,
-        return_hidden: bool = False,
+            self,
+            hs_pad: torch.Tensor,
+            hlens: torch.Tensor,
+            ys_in_pad: torch.Tensor,
+            ys_in_lens: torch.Tensor,
+            chunk_mask: torch.Tensor = None,
+            pre_acoustic_embeds: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -269,46 +360,122 @@
 
         memory = hs_pad
         memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        if chunk_mask is not None:
+            memory_mask = memory_mask * chunk_mask
+            if tgt_mask.size(1) != memory_mask.size(1):
+                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
 
-        x = tgt
+        x = self.embed(tgt)
+
+        if pre_acoustic_embeds is not None and self.concat_embeds:
+            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+
         x, tgt_mask, memory, memory_mask, _ = self.decoders(
             x, tgt_mask, memory, memory_mask
         )
-        _, _, x_self_attn, x_src_attn = self.last_decoder(
-            x, tgt_mask, memory, memory_mask
-        )
-
-        # contextual paraformer related
-        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
-        contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
-        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
-
-        if self.bias_output is not None:
-            x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
-            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
-            x = x_self_attn + self.dropout(x)
-
         if self.decoders2 is not None:
             x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                 x, tgt_mask, memory, memory_mask
             )
-
         x, tgt_mask, memory, memory_mask, _ = self.decoders3(
             x, tgt_mask, memory, memory_mask
         )
         if self.normalize_before:
             x = self.after_norm(x)
-        olens = tgt_mask.sum(1)
-        if self.output_layer is not None and return_hidden is False:
+        if self.output_layer is not None:
             x = self.output_layer(x)
+
+        olens = tgt_mask.sum(1)
         return x, olens
 
-    def gen_tf2torch_map_dict(self):
+    def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
+        """Score."""
+        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+        logp, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
+            cache=state
+        )
+        return logp.squeeze(0), state
 
+    def forward_one_step(
+            self,
+            tgt: torch.Tensor,
+            tgt_mask: torch.Tensor,
+            memory: torch.Tensor,
+            memory_mask: torch.Tensor = None,
+            pre_acoustic_embeds: torch.Tensor = None,
+            cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Args:
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask,  (batch, maxlen_out)
+                      dtype=torch.uint8 in PyTorch 1.2-
+                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+            memory: encoded memory, float32  (batch, maxlen_in, feat)
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+            y.shape` is (batch, maxlen_out, token)
+        """
+
+        x = tgt[:, -1:]
+        tgt_mask = None
+        x = self.embed(x)
+
+        if pre_acoustic_embeds is not None and self.concat_embeds:
+            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+
+        if cache is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            cache = [None] * cache_layer_num
+        new_cache = []
+        # for c, decoder in zip(cache, self.decoders):
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            c = cache[i]
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                x, tgt_mask, memory, memory_mask, cache=c
+            )
+            new_cache.append(c_ret)
+
+        if self.num_blocks - self.att_layer_num >= 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                c = cache[j]
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                    x, tgt_mask, memory, memory_mask, cache=c
+                )
+                new_cache.append(c_ret)
+
+        for decoder in self.decoders3:
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+                x, tgt_mask, memory, None, cache=None
+            )
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.output_layer is not None:
+            y = self.output_layer(y)
+            y = torch.log_softmax(y, dim=-1)
+
+        return y, new_cache
+
+    def gen_tf2torch_map_dict(self):
+    
         tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
         tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+        embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
         map_dict_local = {
-
+        
             ## decoder
             # ffn
             "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
@@ -346,7 +513,7 @@
                  "squeeze": 0,
                  "transpose": (1, 0),
                  },  # (256,1024),(1,1024,256)
-
+        
             # fsmn
             "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                 {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
@@ -443,7 +610,7 @@
                  "squeeze": 0,
                  "transpose": (1, 0),
                  },  # (256,1024),(1,1024,256)
-
+        
             # embed_concat_ffn
             "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                 {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
@@ -480,7 +647,7 @@
                  "squeeze": 0,
                  "transpose": (1, 0),
                  },  # (256,1024),(1,1024,256)
-
+        
             # out norm
             "{}.after_norm.weight".format(tensor_name_prefix_torch):
                 {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
@@ -492,17 +659,18 @@
                  "squeeze": None,
                  "transpose": None,
                  },  # (256,),(256,)
-
+        
             # in embed
             "{}.embed.0.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/w_embs".format(tensor_name_prefix_tf),
+                {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
                  "squeeze": None,
                  "transpose": None,
                  },  # (4235,256),(4235,256)
-
+        
             # out layer
             "{}.output_layer.weight".format(tensor_name_prefix_torch):
-                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
+                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
+                          "{}/w_embs".format(embed_tensor_name_prefix_tf)],
                  "squeeze": [None, None],
                  "transpose": [(1, 0), None],
                  },  # (4235,256),(256,4235)
@@ -512,56 +680,7 @@
                  "squeeze": [None, None],
                  "transpose": [None, None],
                  },  # (4235,),(4235,)
-
-            ## clas decoder
-            # src att
-            "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # dnn
-            "{}.bias_output.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": (2, 1, 0),
-                 },  # (1024,256),(1,256,1024)
-
+        
         }
         return map_dict_local
 
@@ -569,6 +688,7 @@
                          var_dict_tf,
                          var_dict_torch,
                          ):
+    
         map_dict = self.gen_tf2torch_map_dict()
         var_dict_torch_update = dict()
         decoder_layeridx_sets = set()
@@ -598,37 +718,13 @@
                         logging.info(
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                           var_dict_tf[name_tf].shape))
-                elif names[1] == "last_decoder":
-                    layeridx = 15
-                    name_q = name.replace("last_decoder", "decoders.layeridx")
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-
-
+            
                 elif names[1] == "decoders2":
                     layeridx = int(names[2])
                     name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                     name_q = name_q.replace("decoders2", "decoders")
                     layeridx_bias = len(decoder_layeridx_sets)
-
+                
                     layeridx += layeridx_bias
                     if "decoders." in name:
                         decoder_layeridx_sets.add(layeridx)
@@ -649,11 +745,11 @@
                         logging.info(
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                           var_dict_tf[name_tf].shape))
-
+            
                 elif names[1] == "decoders3":
                     layeridx = int(names[2])
                     name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
+                
                     layeridx_bias = 0
                     layeridx += layeridx_bias
                     if "decoders." in name:
@@ -675,29 +771,8 @@
                         logging.info(
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                           var_dict_tf[name_tf].shape))
-                elif names[1] == "bias_decoder":
-                    name_q = name
-
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-
-
-                elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
+            
+                elif names[1] == "embed" or names[1] == "output_layer":
                     name_tf = map_dict[name]["name"]
                     if isinstance(name_tf, list):
                         idx_list = 0
@@ -720,7 +795,7 @@
                                                                                                    name_tf[idx_list],
                                                                                                    var_dict_tf[name_tf[
                                                                                                        idx_list]].shape))
-
+                
                     else:
                         data_tf = var_dict_tf[name_tf]
                         if map_dict[name]["squeeze"] is not None:
@@ -736,7 +811,7 @@
                         logging.info(
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                           var_dict_tf[name_tf].shape))
-
+            
                 elif names[1] == "after_norm":
                     name_tf = map_dict[name]["name"]
                     data_tf = var_dict_tf[name_tf]
@@ -745,11 +820,11 @@
                     logging.info(
                         "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                       var_dict_tf[name_tf].shape))
-
+            
                 elif names[1] == "embed_concat_ffn":
                     layeridx = int(names[2])
                     name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
+                
                     layeridx_bias = 0
                     layeridx += layeridx_bias
                     if "decoders." in name:
@@ -771,5 +846,6 @@
                         logging.info(
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                           var_dict_tf[name_tf].shape))
-
+    
         return var_dict_torch_update
+

--
Gitblit v1.9.1