From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/scama/decoder.py |  520 +++++++++------------------------------------------------
 1 files changed, 81 insertions(+), 439 deletions(-)

diff --git a/funasr/models/scama/decoder.py b/funasr/models/scama/decoder.py
index 9dcb9da..31b2357 100644
--- a/funasr/models/scama/decoder.py
+++ b/funasr/models/scama/decoder.py
@@ -13,13 +13,17 @@
 from funasr.models.scama import utils as myutils
 from funasr.models.transformer.decoder import BaseTransformerDecoder
 
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import (
+    MultiHeadedAttentionSANMDecoder,
+    MultiHeadedAttentionCrossAtt,
+)
 from funasr.models.transformer.embedding import PositionalEncoding
 from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
 from funasr.models.transformer.utils.repeat import repeat
 
 from funasr.register import tables
+
 
 class DecoderLayerSANM(nn.Module):
     """Single decoder layer module.
@@ -151,10 +155,11 @@
 
             x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
 
-
         return x, tgt_mask, memory, memory_mask, cache
 
-    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+    def forward_chunk(
+        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
+    ):
         """Compute decoded features.
 
         Args:
@@ -194,6 +199,7 @@
 
         return x, memory, fsmn_cache, opt_cache
 
+
 @tables.register("decoder_classes", "FsmnDecoderSCAMAOpt")
 class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
     """
@@ -201,31 +207,31 @@
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
     https://arxiv.org/abs/2006.01712
     """
-    
+
     def __init__(
-            self,
-            vocab_size: int,
-            encoder_output_size: int,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            self_attention_dropout_rate: float = 0.0,
-            src_attention_dropout_rate: float = 0.0,
-            input_layer: str = "embed",
-            use_output_layer: bool = True,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            att_layer_num: int = 6,
-            kernel_size: int = 21,
-            sanm_shfit: int = None,
-            concat_embeds: bool = False,
-            attention_dim: int = None,
-            tf2torch_tensor_name_prefix_torch: str = "decoder",
-            tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
-            embed_tensor_name_prefix_tf: str = None,
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        att_layer_num: int = 6,
+        kernel_size: int = 21,
+        sanm_shfit: int = None,
+        concat_embeds: bool = False,
+        attention_dim: int = None,
+        tf2torch_tensor_name_prefix_torch: str = "decoder",
+        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+        embed_tensor_name_prefix_tf: str = None,
     ):
         super().__init__(
             vocab_size=vocab_size,
@@ -275,7 +281,10 @@
                     attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                 ),
                 MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
+                    attention_heads,
+                    attention_dim,
+                    src_attention_dropout_rate,
+                    encoder_output_size=encoder_output_size,
                 ),
                 PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
@@ -291,7 +300,10 @@
                 lambda lnum: DecoderLayerSANM(
                     attention_dim,
                     MultiHeadedAttentionSANMDecoder(
-                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                        attention_dim,
+                        self_attention_dropout_rate,
+                        kernel_size,
+                        sanm_shfit=sanm_shfit,
                     ),
                     None,
                     PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -320,8 +332,12 @@
                     attention_dim + encoder_output_size,
                     None,
                     None,
-                    PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
-                                                      adim=attention_dim),
+                    PositionwiseFeedForwardDecoderSANM(
+                        attention_dim + encoder_output_size,
+                        linear_units,
+                        dropout_rate,
+                        adim=attention_dim,
+                    ),
                     dropout_rate,
                     normalize_before,
                     concat_after,
@@ -335,13 +351,13 @@
         self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
 
     def forward(
-            self,
-            hs_pad: torch.Tensor,
-            hlens: torch.Tensor,
-            ys_in_pad: torch.Tensor,
-            ys_in_lens: torch.Tensor,
-            chunk_mask: torch.Tensor = None,
-            pre_acoustic_embeds: torch.Tensor = None,
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        chunk_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -376,16 +392,10 @@
             x = torch.cat((x, pre_acoustic_embeds), dim=-1)
             x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
 
-        x, tgt_mask, memory, memory_mask, _ = self.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
         if self.decoders2 is not None:
-            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
-                x, tgt_mask, memory, memory_mask
-            )
-        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
+            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
+        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
         if self.normalize_before:
             x = self.after_norm(x)
         if self.output_layer is not None:
@@ -394,23 +404,36 @@
         olens = tgt_mask.sum(1)
         return x, olens
 
-    def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
+    def score(
+        self,
+        ys,
+        state,
+        x,
+        x_mask=None,
+        pre_acoustic_embeds: torch.Tensor = None,
+    ):
         """Score."""
-        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+        ys_mask = myutils.sequence_mask(
+            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
+        )[:, :, None]
         logp, state = self.forward_one_step(
-            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
-            cache=state
+            ys.unsqueeze(0),
+            ys_mask,
+            x.unsqueeze(0),
+            memory_mask=x_mask,
+            pre_acoustic_embeds=pre_acoustic_embeds,
+            cache=state,
         )
         return logp.squeeze(0), state
 
     def forward_one_step(
-            self,
-            tgt: torch.Tensor,
-            tgt_mask: torch.Tensor,
-            memory: torch.Tensor,
-            memory_mask: torch.Tensor = None,
-            pre_acoustic_embeds: torch.Tensor = None,
-            cache: List[torch.Tensor] = None,
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        memory: torch.Tensor,
+        memory_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+        cache: List[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         """Forward one step.
 
@@ -473,384 +496,3 @@
             y = torch.log_softmax(y, dim=-1)
 
         return y, new_cache
-
-    def gen_tf2torch_map_dict(self):
-    
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
-        map_dict_local = {
-        
-            ## decoder
-            # ffn
-            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-        
-            # fsmn
-            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
-                    tensor_name_prefix_tf),
-                    "squeeze": None,
-                    "transpose": None,
-                },  # (256,),(256,)
-            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
-                    tensor_name_prefix_tf),
-                    "squeeze": None,
-                    "transpose": None,
-                },  # (256,),(256,)
-            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
-                    tensor_name_prefix_tf),
-                    "squeeze": 0,
-                    "transpose": (1, 2, 0),
-                },  # (256,1,31),(1,31,256,1)
-            # src att
-            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # dnn
-            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-        
-            # embed_concat_ffn
-            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-        
-            # out norm
-            "{}.after_norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.after_norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-        
-            # in embed
-            "{}.embed.0.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (4235,256),(4235,256)
-        
-            # out layer
-            "{}.output_layer.weight".format(tensor_name_prefix_torch):
-                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
-                          "{}/w_embs".format(embed_tensor_name_prefix_tf)],
-                 "squeeze": [None, None],
-                 "transpose": [(1, 0), None],
-                 },  # (4235,256),(256,4235)
-            "{}.output_layer.bias".format(tensor_name_prefix_torch):
-                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
-                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
-                 "squeeze": [None, None],
-                 "transpose": [None, None],
-                 },  # (4235,),(4235,)
-        
-        }
-        return map_dict_local
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-    
-        map_dict = self.gen_tf2torch_map_dict()
-        var_dict_torch_update = dict()
-        decoder_layeridx_sets = set()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                if names[1] == "decoders":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "decoders2":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    name_q = name_q.replace("decoders2", "decoders")
-                    layeridx_bias = len(decoder_layeridx_sets)
-                
-                    layeridx += layeridx_bias
-                    if "decoders." in name:
-                        decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "decoders3":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if "decoders." in name:
-                        decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "embed" or names[1] == "output_layer":
-                    name_tf = map_dict[name]["name"]
-                    if isinstance(name_tf, list):
-                        idx_list = 0
-                        if name_tf[idx_list] in var_dict_tf.keys():
-                            pass
-                        else:
-                            idx_list = 1
-                        data_tf = var_dict_tf[name_tf[idx_list]]
-                        if map_dict[name]["squeeze"][idx_list] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
-                        if map_dict[name]["transpose"][idx_list] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
-                                                                                                   name_tf[idx_list],
-                                                                                                   var_dict_tf[name_tf[
-                                                                                                       idx_list]].shape))
-                
-                    else:
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
-                        if map_dict[name]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "after_norm":
-                    name_tf = map_dict[name]["name"]
-                    data_tf = var_dict_tf[name_tf]
-                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                    var_dict_torch_update[name] = data_tf
-                    logging.info(
-                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                      var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "embed_concat_ffn":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if "decoders." in name:
-                        decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-    
-        return var_dict_torch_update
-

--
Gitblit v1.9.1