| | |
| | | |
| | | return y, new_cache |
| | | |
def gen_tf2torch_map_dict(self):
    """Build the torch-parameter -> TF-variable lookup table for this decoder.

    Keys are torch parameter names prefixed with
    ``self.tf2torch_tensor_name_prefix_torch``; per-layer parameters carry the
    literal placeholder ``layeridx`` where the layer number belongs.

    Each value is a dict with three fields:
      * ``"name"``      -- TF variable name (may also contain ``layeridx``),
                           or a list of candidate names for the output layer;
      * ``"squeeze"``   -- axis to squeeze from the TF array, or ``None``;
      * ``"transpose"`` -- axis permutation applied after squeezing, or ``None``.
    For the output layer, ``"squeeze"`` and ``"transpose"`` are lists aligned
    with the candidate names.

    Returns:
        dict: the mapping described above, in a fixed insertion order.
    """
    torch_prefix = self.tf2torch_tensor_name_prefix_torch
    tf_prefix = self.tf2torch_tensor_name_prefix_tf
    # Embedding variables fall back to the decoder's TF prefix when no
    # dedicated prefix is configured.
    embed_tf_prefix = self.embed_tensor_name_prefix_tf
    if embed_tf_prefix is None:
        embed_tf_prefix = tf_prefix

    # (squeeze, transpose) recipes.
    as_is = (None, None)
    # Original annotations give e.g. torch (1024,256) from TF (1,256,1024).
    kernel = (0, (1, 0))
    # Original annotation: torch (256,1,31) from TF (1,31,256,1).
    depth_kernel = (0, (1, 2, 0))

    # Layout shared by every feed-forward style sub-block.
    ffn_layout = (
        ("norm1.weight", "LayerNorm/gamma", as_is),
        ("norm1.bias", "LayerNorm/beta", as_is),
        ("feed_forward.w_1.weight", "conv1d/kernel", kernel),
        ("feed_forward.w_1.bias", "conv1d/bias", as_is),
        ("feed_forward.norm.weight", "LayerNorm_1/gamma", as_is),
        ("feed_forward.norm.bias", "LayerNorm_1/beta", as_is),
        ("feed_forward.w_2.weight", "conv1d_1/kernel", kernel),
    )
    fsmn_layout = (
        ("norm2.weight", "LayerNorm/gamma", as_is),
        ("norm2.bias", "LayerNorm/beta", as_is),
        ("self_attn.fsmn_block.weight", "depth_conv_w", depth_kernel),
    )
    src_att_layout = (
        ("norm3.weight", "LayerNorm/gamma", as_is),
        ("norm3.bias", "LayerNorm/beta", as_is),
        ("src_attn.linear_q.weight", "conv1d/kernel", kernel),
        ("src_attn.linear_q.bias", "conv1d/bias", as_is),
        ("src_attn.linear_k_v.weight", "conv1d_1/kernel", kernel),
        ("src_attn.linear_k_v.bias", "conv1d_1/bias", as_is),
        ("src_attn.linear_out.weight", "conv1d_2/kernel", kernel),
        ("src_attn.linear_out.bias", "conv1d_2/bias", as_is),
    )
    out_norm_layout = (
        ("after_norm.weight", "LayerNorm/gamma", as_is),
        ("after_norm.bias", "LayerNorm/beta", as_is),
    )

    # (torch scope, TF scope, layout) in the order the entries must appear.
    sections = (
        # decoder: ffn
        ("decoders.layeridx", "decoder_fsmn_layer_layeridx/decoder_ffn", ffn_layout),
        # decoder: fsmn
        ("decoders.layeridx", "decoder_fsmn_layer_layeridx/decoder_memory_block", fsmn_layout),
        # decoder: src att
        ("decoders.layeridx", "decoder_fsmn_layer_layeridx/multi_head", src_att_layout),
        # dnn
        ("decoders3.layeridx", "decoder_dnn_layer_layeridx", ffn_layout),
        # embed_concat_ffn
        ("embed_concat_ffn.layeridx", "cif_concat", ffn_layout),
    )

    map_dict_local = {}
    for torch_scope, tf_scope, layout in sections:
        for torch_leaf, tf_leaf, (squeeze, transpose) in layout:
            torch_key = "{}.{}.{}".format(torch_prefix, torch_scope, torch_leaf)
            map_dict_local[torch_key] = {
                "name": "{}/{}/{}".format(tf_prefix, tf_scope, tf_leaf),
                "squeeze": squeeze,
                "transpose": transpose,
            }

    # out norm
    for torch_leaf, tf_leaf, (squeeze, transpose) in out_norm_layout:
        map_dict_local["{}.{}".format(torch_prefix, torch_leaf)] = {
            "name": "{}/{}".format(tf_prefix, tf_leaf),
            "squeeze": squeeze,
            "transpose": transpose,
        }

    # in embed
    embed_tf_name = "{}/w_embs".format(embed_tf_prefix)
    map_dict_local["{}.embed.0.weight".format(torch_prefix)] = {
        "name": embed_tf_name,
        "squeeze": None,
        "transpose": None,
    }

    # out layer: two candidate TF sources each, tried in list order by the
    # consumer; the second weight candidate is the embedding matrix.
    map_dict_local["{}.output_layer.weight".format(torch_prefix)] = {
        "name": ["{}/dense/kernel".format(tf_prefix), embed_tf_name],
        "squeeze": [None, None],
        "transpose": [(1, 0), None],
    }
    if tf_prefix == "seq2seq/decoder/inputter_1":
        fallback_bias_name = "seq2seq/2bias"
    else:
        fallback_bias_name = "seq2seq/bias"
    map_dict_local["{}.output_layer.bias".format(torch_prefix)] = {
        "name": ["{}/dense/bias".format(tf_prefix), fallback_bias_name],
        "squeeze": [None, None],
        "transpose": [None, None],
    }

    return map_dict_local
| | | |
def convert_tf2torch(self,
                     var_dict_tf,
                     var_dict_torch,
                     ):
    """Convert TF checkpoint variables into tensors for this decoder's torch parameters.

    Only entries of ``var_dict_torch`` whose first dotted component equals
    ``self.tf2torch_tensor_name_prefix_torch`` are considered. The second
    component selects the handling:

    * ``decoders`` / ``decoders2`` / ``decoders3`` / ``embed_concat_ffn``:
      per-layer parameters. The layer number (third component) is replaced by
      the ``layeridx`` placeholder to look up the mapping, then substituted
      into the TF variable name. ``decoders2`` shares the ``decoders`` mapping
      entries, with its layer number offset by the count of distinct
      ``decoders`` layers seen. Parameters with no mapping entry are skipped.
    * ``embed`` / ``output_layer``: looked up by exact name. If the mapping
      lists several candidate TF names, the first is used when present in
      ``var_dict_tf``, otherwise the second.
    * ``after_norm``: looked up by exact name.

    Args:
        var_dict_tf: mapping from TF variable name to array (passed to
            ``torch.from_numpy`` after optional squeeze/transpose).
        var_dict_torch: mapping from torch parameter name to an object with
            ``.size()``; used for name discovery and shape validation only.

    Returns:
        dict: torch parameter name -> float32 CPU tensor loaded from TF.

    Raises:
        AssertionError: if a converted tensor's shape differs from the torch
            parameter's shape.
        KeyError: if a required mapping entry or TF variable is missing.
    """
    map_dict = self.gen_tf2torch_map_dict()
    prefix_torch = self.tf2torch_tensor_name_prefix_torch
    var_dict_torch_update = {}

    def _load_tensor(name, name_tf, squeeze, transpose):
        # Fetch one TF variable, reshape it to torch layout, validate, store, log.
        data_tf = var_dict_tf[name_tf]
        if squeeze is not None:
            data_tf = np.squeeze(data_tf, axis=squeeze)
        if transpose is not None:
            data_tf = np.transpose(data_tf, transpose)
        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
            name, name_tf, var_dict_torch[name].size(), data_tf.size())
        var_dict_torch_update[name] = data_tf
        logging.info(
            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))

    layered_modules = ("decoders", "decoders2", "decoders3", "embed_concat_ffn")
    decoder_layeridx_sets = set()

    # Sorted iteration visits every "<prefix>.decoders.*" name before any
    # "<prefix>.decoders2.*" name ('.' sorts before '2'), so the decoders2
    # offset below is computed from the complete set of decoders layers.
    for name in sorted(var_dict_torch):
        names = name.split('.')
        if len(names) < 2 or names[0] != prefix_torch:
            continue
        module = names[1]

        if module in layered_modules:
            layeridx = int(names[2])
            map_module = module
            if module == "decoders":
                decoder_layeridx_sets.add(layeridx)
            elif module == "decoders2":
                # decoders2 layers continue the TF numbering after decoders
                # and reuse the "decoders" mapping entries.
                layeridx += len(decoder_layeridx_sets)
                map_module = "decoders"
            # Substitute only the layer-index component with the placeholder.
            name_q = ".".join([names[0], map_module, "layeridx"] + names[3:])
            spec = map_dict.get(name_q)
            if spec is not None:
                name_tf = spec["name"].replace("layeridx", "{}".format(layeridx))
                _load_tensor(name, name_tf, spec["squeeze"], spec["transpose"])

        elif module == "embed" or module == "output_layer":
            spec = map_dict[name]
            name_tf = spec["name"]
            if isinstance(name_tf, list):
                # Prefer the first candidate; fall back to the second.
                idx_list = 0 if name_tf[0] in var_dict_tf else 1
                _load_tensor(name, name_tf[idx_list],
                             spec["squeeze"][idx_list], spec["transpose"][idx_list])
            else:
                _load_tensor(name, name_tf, spec["squeeze"], spec["transpose"])

        elif module == "after_norm":
            # Goes through the shared path so the shape is validated like
            # every other parameter.
            spec = map_dict[name]
            _load_tensor(name, spec["name"], spec["squeeze"], spec["transpose"])

    return var_dict_torch_update
| | | |