From adcee8828ef5d78b575043954deb662a35e318f7 Mon Sep 17 00:00:00 2001
From: huangmingming <huangmingming@deepscience.cn>
Date: 星期一, 30 一月 2023 16:02:54 +0800
Subject: [PATCH] update the minimum size of audio
---
funasr/models/decoder/sanm_decoder.py | 774 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 772 insertions(+), 2 deletions(-)
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index a5db353..ab03f0b 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -1,8 +1,10 @@
from typing import List
from typing import Tuple
-
+import logging
import torch
import torch.nn as nn
+import numpy as np
+
from funasr.modules.streaming_utils import utils as myutils
from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
from typeguard import check_argument_types
@@ -136,6 +138,9 @@
sanm_shfit: int = None,
concat_embeds: bool = False,
attention_dim: int = None,
+ tf2torch_tensor_name_prefix_torch: str = "decoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+ embed_tensor_name_prefix_tf: str = None,
):
assert check_argument_types()
super().__init__(
@@ -241,6 +246,9 @@
else:
self.embed_concat_ffn = None
self.concat_embeds = concat_embeds
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+ self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
def forward(
self,
@@ -382,6 +390,387 @@
return y, new_cache
+ def gen_tf2torch_map_dict(self):
+
+ tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+ tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+ embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
+ map_dict_local = {
+
+ ## decoder
+ # ffn
+ "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (1024,256),(1,256,1024)
+ "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,1024),(1,1024,256)
+
+ # fsmn
+ "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
+ tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
+ tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
+ tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 2, 0),
+ }, # (256,1,31),(1,31,256,1)
+ # src att
+ "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,256),(1,256,256)
+ "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (1024,256),(1,256,1024)
+ "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,256),(1,256,256)
+ "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ # dnn
+ "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (1024,256),(1,256,1024)
+ "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,1024),(1,1024,256)
+
+ # embed_concat_ffn
+ "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (1024,256),(1,256,1024)
+ "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,1024),(1,1024,256)
+
+ # out norm
+ "{}.after_norm.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.after_norm.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+
+ # in embed
+ "{}.embed.0.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (4235,256),(4235,256)
+
+ # out layer
+ "{}.output_layer.weight".format(tensor_name_prefix_torch):
+ {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
+ "{}/w_embs".format(embed_tensor_name_prefix_tf)],
+ "squeeze": [None, None],
+ "transpose": [(1, 0), None],
+ }, # (4235,256),(256,4235)
+ "{}.output_layer.bias".format(tensor_name_prefix_torch):
+ {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
+ "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
+ "squeeze": [None, None],
+ "transpose": [None, None],
+ }, # (4235,),(4235,)
+
+ }
+ return map_dict_local
+
+ def convert_tf2torch(self,
+ var_dict_tf,
+ var_dict_torch,
+ ):
+
+ map_dict = self.gen_tf2torch_map_dict()
+ var_dict_torch_update = dict()
+ decoder_layeridx_sets = set()
+ for name in sorted(var_dict_torch.keys(), reverse=False):
+ names = name.split('.')
+ if names[0] == self.tf2torch_tensor_name_prefix_torch:
+ if names[1] == "decoders":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "decoders2":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+ name_q = name_q.replace("decoders2", "decoders")
+ layeridx_bias = len(decoder_layeridx_sets)
+
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "decoders3":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "embed" or names[1] == "output_layer":
+ name_tf = map_dict[name]["name"]
+ if isinstance(name_tf, list):
+ idx_list = 0
+ if name_tf[idx_list] in var_dict_tf.keys():
+ pass
+ else:
+ idx_list = 1
+ data_tf = var_dict_tf[name_tf[idx_list]]
+ if map_dict[name]["squeeze"][idx_list] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
+ if map_dict[name]["transpose"][idx_list] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
+ name_tf[idx_list],
+ var_dict_tf[name_tf[
+ idx_list]].shape))
+
+ else:
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+ if map_dict[name]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "after_norm":
+ name_tf = map_dict[name]["name"]
+ data_tf = var_dict_tf[name_tf]
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "embed_concat_ffn":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ return var_dict_torch_update
+
+
class ParaformerSANMDecoder(BaseTransformerDecoder):
"""
author: Speech Lab, Alibaba Group, China
@@ -407,6 +796,8 @@
att_layer_num: int = 6,
kernel_size: int = 21,
sanm_shfit: int = 0,
+ tf2torch_tensor_name_prefix_torch: str = "decoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
):
assert check_argument_types()
super().__init__(
@@ -496,6 +887,8 @@
concat_after,
),
)
+ self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+ self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
def forward(
self,
@@ -613,4 +1006,381 @@
if self.output_layer is not None:
y = torch.log_softmax(self.output_layer(y), dim=-1)
- return y, new_cache
\ No newline at end of file
+ return y, new_cache
+
+ def gen_tf2torch_map_dict(self):
+
+ tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+ tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+ map_dict_local = {
+
+ ## decoder
+ # ffn
+ "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (1024,256),(1,256,1024)
+ "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,1024),(1,1024,256)
+
+ # fsmn
+ "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
+ tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
+ tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
+ tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 2, 0),
+ }, # (256,1,31),(1,31,256,1)
+ # src att
+ "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,256),(1,256,256)
+ "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (1024,256),(1,256,1024)
+ "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,256),(1,256,256)
+ "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ # dnn
+ "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (1024,256),(1,256,1024)
+ "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,1024),(1,1024,256)
+
+ # embed_concat_ffn
+ "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (1024,256),(1,256,1024)
+ "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (1024,),(1024,)
+ "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
+ "squeeze": 0,
+ "transpose": (1, 0),
+ }, # (256,1024),(1,1024,256)
+
+ # out norm
+ "{}.after_norm.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+ "{}.after_norm.bias".format(tensor_name_prefix_torch):
+ {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (256,),(256,)
+
+ # in embed
+ "{}.embed.0.weight".format(tensor_name_prefix_torch):
+ {"name": "{}/w_embs".format(tensor_name_prefix_tf),
+ "squeeze": None,
+ "transpose": None,
+ }, # (4235,256),(4235,256)
+
+ # out layer
+ "{}.output_layer.weight".format(tensor_name_prefix_torch):
+ {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
+ "squeeze": [None, None],
+ "transpose": [(1, 0), None],
+ }, # (4235,256),(256,4235)
+ "{}.output_layer.bias".format(tensor_name_prefix_torch):
+ {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
+ "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
+ "squeeze": [None, None],
+ "transpose": [None, None],
+ }, # (4235,),(4235,)
+
+ }
+ return map_dict_local
+
+ def convert_tf2torch(self,
+ var_dict_tf,
+ var_dict_torch,
+ ):
+ map_dict = self.gen_tf2torch_map_dict()
+ var_dict_torch_update = dict()
+ decoder_layeridx_sets = set()
+ for name in sorted(var_dict_torch.keys(), reverse=False):
+ names = name.split('.')
+ if names[0] == self.tf2torch_tensor_name_prefix_torch:
+ if names[1] == "decoders":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "decoders2":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+ name_q = name_q.replace("decoders2", "decoders")
+ layeridx_bias = len(decoder_layeridx_sets)
+
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "decoders3":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "embed" or names[1] == "output_layer":
+ name_tf = map_dict[name]["name"]
+ if isinstance(name_tf, list):
+ idx_list = 0
+ if name_tf[idx_list] in var_dict_tf.keys():
+ pass
+ else:
+ idx_list = 1
+ data_tf = var_dict_tf[name_tf[idx_list]]
+ if map_dict[name]["squeeze"][idx_list] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
+ if map_dict[name]["transpose"][idx_list] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
+ name_tf[idx_list],
+ var_dict_tf[name_tf[
+ idx_list]].shape))
+
+ else:
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+ if map_dict[name]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "after_norm":
+ name_tf = map_dict[name]["name"]
+ data_tf = var_dict_tf[name_tf]
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+ var_dict_tf[name_tf].shape))
+
+ elif names[1] == "embed_concat_ffn":
+ layeridx = int(names[2])
+ name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+ layeridx_bias = 0
+ layeridx += layeridx_bias
+ if "decoders." in name:
+ decoder_layeridx_sets.add(layeridx)
+ if name_q in map_dict.keys():
+ name_v = map_dict[name_q]["name"]
+ name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+ data_tf = var_dict_tf[name_tf]
+ if map_dict[name_q]["squeeze"] is not None:
+ data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+ if map_dict[name_q]["transpose"] is not None:
+ data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+ data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+ assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+ var_dict_torch[
+ name].size(),
+ data_tf.size())
+ var_dict_torch_update[name] = data_tf
+ logging.info(
+ "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+ var_dict_tf[name_tf].shape))
+
+ return var_dict_torch_update
--
Gitblit v1.9.1