From 6427c834dfd97b1f05c6659cdc7ccf010bf82fe1 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 24 四月 2023 19:50:07 +0800
Subject: [PATCH] update

---
 funasr/models/decoder/sanm_decoder.py |  886 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 877 insertions(+), 9 deletions(-)

diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index a5db353..463918a 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -1,8 +1,10 @@
 from typing import List
 from typing import Tuple
-
+import logging
 import torch
 import torch.nn as nn
+import numpy as np
+
 from funasr.modules.streaming_utils import utils as myutils
 from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
 from typeguard import check_argument_types
@@ -92,6 +94,46 @@
         if self.self_attn:
             if self.normalize_before:
                 tgt = self.norm2(tgt)
+            x, _ = self.self_attn(tgt, tgt_mask)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
             if self.training:
                 cache = None
             x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
@@ -106,7 +148,6 @@
 
 
         return x, tgt_mask, memory, memory_mask, cache
-
 
 class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
     """
@@ -136,6 +177,9 @@
             sanm_shfit: int = None,
             concat_embeds: bool = False,
             attention_dim: int = None,
+            tf2torch_tensor_name_prefix_torch: str = "decoder",
+            tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+            embed_tensor_name_prefix_tf: str = None,
     ):
         assert check_argument_types()
         super().__init__(
@@ -241,6 +285,9 @@
         else:
             self.embed_concat_ffn = None
         self.concat_embeds = concat_embeds
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
 
     def forward(
             self,
@@ -352,7 +399,7 @@
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                 x, tgt_mask, memory, memory_mask, cache=c
             )
             new_cache.append(c_ret)
@@ -362,13 +409,13 @@
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
                 c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder(
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                     x, tgt_mask, memory, memory_mask, cache=c
                 )
                 new_cache.append(c_ret)
 
         for decoder in self.decoders3:
-            x, tgt_mask, memory, memory_mask, _ = decoder(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=None
             )
 
@@ -381,6 +428,387 @@
             y = torch.log_softmax(y, dim=-1)
 
         return y, new_cache
+
+    def gen_tf2torch_map_dict(self):
+    
+        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+        embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
+        map_dict_local = {
+        
+            ## decoder
+            # ffn
+            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,1024),(1,1024,256)
+        
+            # fsmn
+            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
+                    tensor_name_prefix_tf),
+                    "squeeze": None,
+                    "transpose": None,
+                },  # (256,),(256,)
+            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
+                    tensor_name_prefix_tf),
+                    "squeeze": None,
+                    "transpose": None,
+                },  # (256,),(256,)
+            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
+                    tensor_name_prefix_tf),
+                    "squeeze": 0,
+                    "transpose": (1, 2, 0),
+                },  # (256,1,31),(1,31,256,1)
+            # src att
+            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,256),(1,256,256)
+            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,256),(1,256,256)
+            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            # dnn
+            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,1024),(1,1024,256)
+        
+            # embed_concat_ffn
+            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,1024),(1,1024,256)
+        
+            # out norm
+            "{}.after_norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.after_norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+        
+            # in embed
+            "{}.embed.0.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (4235,256),(4235,256)
+        
+            # out layer
+            "{}.output_layer.weight".format(tensor_name_prefix_torch):
+                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
+                          "{}/w_embs".format(embed_tensor_name_prefix_tf)],
+                 "squeeze": [None, None],
+                 "transpose": [(1, 0), None],
+                 },  # (4235,256),(256,4235)
+            "{}.output_layer.bias".format(tensor_name_prefix_torch):
+                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
+                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
+                 "squeeze": [None, None],
+                 "transpose": [None, None],
+                 },  # (4235,),(4235,)
+        
+        }
+        return map_dict_local
+
+    def convert_tf2torch(self,
+                         var_dict_tf,
+                         var_dict_torch,
+                         ):
+    
+        map_dict = self.gen_tf2torch_map_dict()
+        var_dict_torch_update = dict()
+        decoder_layeridx_sets = set()
+        for name in sorted(var_dict_torch.keys(), reverse=False):
+            names = name.split('.')
+            if names[0] == self.tf2torch_tensor_name_prefix_torch:
+                if names[1] == "decoders":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "decoders2":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    name_q = name_q.replace("decoders2", "decoders")
+                    layeridx_bias = len(decoder_layeridx_sets)
+                
+                    layeridx += layeridx_bias
+                    if "decoders." in name:
+                        decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "decoders3":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    if "decoders." in name:
+                        decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "embed" or names[1] == "output_layer":
+                    name_tf = map_dict[name]["name"]
+                    if isinstance(name_tf, list):
+                        idx_list = 0
+                        if name_tf[idx_list] in var_dict_tf.keys():
+                            pass
+                        else:
+                            idx_list = 1
+                        data_tf = var_dict_tf[name_tf[idx_list]]
+                        if map_dict[name]["squeeze"][idx_list] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
+                        if map_dict[name]["transpose"][idx_list] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
+                                                                                                   name_tf[idx_list],
+                                                                                                   var_dict_tf[name_tf[
+                                                                                                       idx_list]].shape))
+                
+                    else:
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+                        if map_dict[name]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "after_norm":
+                    name_tf = map_dict[name]["name"]
+                    data_tf = var_dict_tf[name_tf]
+                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                    var_dict_torch_update[name] = data_tf
+                    logging.info(
+                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+                                                                                      var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "embed_concat_ffn":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    if "decoders." in name:
+                        decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+    
+        return var_dict_torch_update
+
 
 class ParaformerSANMDecoder(BaseTransformerDecoder):
     """
@@ -407,6 +835,8 @@
         att_layer_num: int = 6,
         kernel_size: int = 21,
         sanm_shfit: int = 0,
+        tf2torch_tensor_name_prefix_torch: str = "decoder",
+        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
     ):
         assert check_argument_types()
         super().__init__(
@@ -496,6 +926,8 @@
                 concat_after,
             ),
         )
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
 
     def forward(
         self,
@@ -554,6 +986,65 @@
         )
         return logp.squeeze(0), state
 
+    def forward_chunk(
+        self,
+        memory: torch.Tensor,
+        tgt: torch.Tensor,
+        cache: dict = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        x = tgt
+        if cache["decode_fsmn"] is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            new_cache = [None] * cache_layer_num
+        else:
+            new_cache = cache["decode_fsmn"]
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
+                x, None, memory, None, cache=new_cache[i]
+            )
+            new_cache[i] = c_ret
+
+        if self.num_blocks - self.att_layer_num > 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
+                    x, None, memory, None, cache=new_cache[j]
+                )
+                new_cache[j] = c_ret
+
+        for decoder in self.decoders3:
+
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
+                x, None, memory, None, cache=None
+            )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+        cache["decode_fsmn"] = new_cache
+        return x
+
     def forward_one_step(
         self,
         tgt: torch.Tensor,
@@ -585,7 +1076,7 @@
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=c
             )
             new_cache.append(c_ret)
@@ -595,14 +1086,14 @@
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
                 c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder(
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                     x, tgt_mask, memory, None, cache=c
                 )
                 new_cache.append(c_ret)
 
         for decoder in self.decoders3:
 
-            x, tgt_mask, memory, memory_mask, _ = decoder(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=None
             )
 
@@ -613,4 +1104,381 @@
         if self.output_layer is not None:
             y = torch.log_softmax(self.output_layer(y), dim=-1)
 
-        return y, new_cache
\ No newline at end of file
+        return y, new_cache
+
+    def gen_tf2torch_map_dict(self):
+    
+        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+        map_dict_local = {
+        
+            ## decoder
+            # ffn
+            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,1024),(1,1024,256)
+        
+            # fsmn
+            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
+                    tensor_name_prefix_tf),
+                    "squeeze": None,
+                    "transpose": None,
+                },  # (256,),(256,)
+            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
+                    tensor_name_prefix_tf),
+                    "squeeze": None,
+                    "transpose": None,
+                },  # (256,),(256,)
+            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
+                    tensor_name_prefix_tf),
+                    "squeeze": 0,
+                    "transpose": (1, 2, 0),
+                },  # (256,1,31),(1,31,256,1)
+            # src att
+            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,256),(1,256,256)
+            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,256),(1,256,256)
+            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            # dnn
+            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,1024),(1,1024,256)
+        
+            # embed_concat_ffn
+            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,1024),(1,1024,256)
+        
+            # out norm
+            "{}.after_norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.after_norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+        
+            # in embed
+            "{}.embed.0.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/w_embs".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (4235,256),(4235,256)
+        
+            # out layer
+            "{}.output_layer.weight".format(tensor_name_prefix_torch):
+                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
+                 "squeeze": [None, None],
+                 "transpose": [(1, 0), None],
+                 },  # (4235,256),(256,4235)
+            "{}.output_layer.bias".format(tensor_name_prefix_torch):
+                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
+                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
+                 "squeeze": [None, None],
+                 "transpose": [None, None],
+                 },  # (4235,),(4235,)
+        
+        }
+        return map_dict_local
+
+    def convert_tf2torch(self,
+                         var_dict_tf,
+                         var_dict_torch,
+                         ):
+        map_dict = self.gen_tf2torch_map_dict()
+        var_dict_torch_update = dict()
+        decoder_layeridx_sets = set()
+        for name in sorted(var_dict_torch.keys(), reverse=False):
+            names = name.split('.')
+            if names[0] == self.tf2torch_tensor_name_prefix_torch:
+                if names[1] == "decoders":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "decoders2":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    name_q = name_q.replace("decoders2", "decoders")
+                    layeridx_bias = len(decoder_layeridx_sets)
+                
+                    layeridx += layeridx_bias
+                    if "decoders." in name:
+                        decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "decoders3":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    if "decoders." in name:
+                        decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "embed" or names[1] == "output_layer":
+                    name_tf = map_dict[name]["name"]
+                    if isinstance(name_tf, list):
+                        idx_list = 0
+                        if name_tf[idx_list] in var_dict_tf.keys():
+                            pass
+                        else:
+                            idx_list = 1
+                        data_tf = var_dict_tf[name_tf[idx_list]]
+                        if map_dict[name]["squeeze"][idx_list] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
+                        if map_dict[name]["transpose"][idx_list] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
+                                                                                                   name_tf[idx_list],
+                                                                                                   var_dict_tf[name_tf[
+                                                                                                       idx_list]].shape))
+                
+                    else:
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+                        if map_dict[name]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "after_norm":
+                    name_tf = map_dict[name]["name"]
+                    data_tf = var_dict_tf[name_tf]
+                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                    var_dict_torch_update[name] = data_tf
+                    logging.info(
+                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+                                                                                      var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "embed_concat_ffn":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    if "decoders." in name:
+                        decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+    
+        return var_dict_torch_update

--
Gitblit v1.9.1