| | |
| | | |
| | | return xs_pad, ilens, None |
| | | |
def gen_tf2torch_map_dict(self):
    """Return the torch-name -> TF-variable mapping used by ``convert_tf2torch``.

    Keys are PyTorch parameter names built from
    ``self.tf2torch_tensor_name_prefix_torch``; values are fresh dicts with:

    * ``"name"``: the TensorFlow variable name (built from
      ``self.tf2torch_tensor_name_prefix_tf``);
    * ``"squeeze"``: axis to pass to ``np.squeeze`` before loading, or ``None``;
    * ``"transpose"``: permutation to pass to ``np.transpose``, or ``None``.

    Per-layer entries keep the literal placeholder ``layeridx`` in both the
    torch key and the TF name; the caller substitutes the real layer index.
    """
    torch_prefix = self.tf2torch_tensor_name_prefix_torch
    tf_prefix = self.tf2torch_tensor_name_prefix_tf

    # (torch suffix, TF suffix, squeeze axis, transpose permutation).
    # Trailing comments give (torch shape),(TF shape) as recorded for a
    # 256-dim reference configuration; other configurations will differ.
    per_layer = [
        # self-attention sub-block
        ("norm1.weight", "multi_head/LayerNorm/gamma", None, None),  # (256,),(256,)
        ("norm1.bias", "multi_head/LayerNorm/beta", None, None),  # (256,),(256,)
        ("self_attn.linear_q_k_v.weight", "multi_head/conv1d/kernel", 0, (1, 0)),  # (768,256),(1,256,768)
        ("self_attn.linear_q_k_v.bias", "multi_head/conv1d/bias", None, None),  # (768,),(768,)
        ("self_attn.fsmn_block.weight", "multi_head/depth_conv_w", 0, (1, 2, 0)),  # (256,1,31),(1,31,256,1)
        ("self_attn.linear_out.weight", "multi_head/conv1d_1/kernel", 0, (1, 0)),  # (256,256),(1,256,256)
        ("self_attn.linear_out.bias", "multi_head/conv1d_1/bias", None, None),  # (256,),(256,)
        # feed-forward sub-block
        ("norm2.weight", "ffn/LayerNorm/gamma", None, None),  # (256,),(256,)
        ("norm2.bias", "ffn/LayerNorm/beta", None, None),  # (256,),(256,)
        ("feed_forward.w_1.weight", "ffn/conv1d/kernel", 0, (1, 0)),  # (1024,256),(1,256,1024)
        ("feed_forward.w_1.bias", "ffn/conv1d/bias", None, None),  # (1024,),(1024,)
        ("feed_forward.w_2.weight", "ffn/conv1d_1/kernel", 0, (1, 0)),  # (256,1024),(1,1024,256)
        ("feed_forward.w_2.bias", "ffn/conv1d_1/bias", None, None),  # (256,),(256,)
    ]
    # Encoder-level output normalisation (not indexed by layer).
    final_norm = [
        ("after_norm.weight", "LayerNorm/gamma", None, None),  # (256,),(256,)
        ("after_norm.bias", "LayerNorm/beta", None, None),  # (256,),(256,)
    ]

    table = {}
    for torch_suffix, tf_suffix, squeeze_axis, perm in per_layer:
        torch_key = "{}.encoders.layeridx.{}".format(torch_prefix, torch_suffix)
        tf_name = "{}/layer_layeridx/{}".format(tf_prefix, tf_suffix)
        table[torch_key] = {"name": tf_name, "squeeze": squeeze_axis, "transpose": perm}
    for torch_suffix, tf_suffix, squeeze_axis, perm in final_norm:
        torch_key = "{}.{}".format(torch_prefix, torch_suffix)
        tf_name = "{}/{}".format(tf_prefix, tf_suffix)
        table[torch_key] = {"name": tf_name, "squeeze": squeeze_axis, "transpose": perm}
    return table
| | | |
def convert_tf2torch(self,
                     var_dict_tf,
                     var_dict_torch,
                     ):
    """Load TensorFlow encoder weights as float32 CPU torch tensors.

    Each torch parameter whose first dotted component equals
    ``self.tf2torch_tensor_name_prefix_torch`` is resolved through
    ``self.gen_tf2torch_map_dict()``:

    * ``<prefix>.encoders0.<i>.*`` is read from TF ``layer_<i>``;
    * ``<prefix>.encoders.<i>.*`` is read from TF ``layer_<i + 1>``;
    * ``<prefix>.after_norm.*`` is read from its (non-layer) TF name.

    Per-layer parameters with no map entry are skipped; parameters under
    other prefixes or other submodules are ignored.

    Args:
        var_dict_tf: TF variable name -> array. Values are passed to
            ``np.squeeze``/``np.transpose``/``torch.from_numpy``, so they are
            presumably numpy ndarrays.
        var_dict_torch: torch parameter name -> tensor; only the names and
            ``.size()`` are used.

    Returns:
        dict mapping torch parameter names to converted float32 CPU tensors.

    Raises:
        KeyError: a required TF variable, or an ``after_norm`` map entry, is
            missing.
        AssertionError: a converted tensor's shape differs from the shape of
            the torch parameter it is meant to replace.
    """
    map_dict = self.gen_tf2torch_map_dict()
    prefix_torch = self.tf2torch_tensor_name_prefix_torch
    # Offset added to the torch layer index to obtain the TF layer index:
    # ``encoders0`` maps onto the first TF layer(s), ``encoders`` follows it.
    # NOTE(review): the +1 assumes ``encoders0`` holds exactly one layer — confirm.
    layer_offsets = {"encoders0": 0, "encoders": 1}

    def _load(name, name_q, name_tf):
        """Reshape TF tensor ``name_tf`` per map entry ``name_q``; check it fits ``name``."""
        spec = map_dict[name_q]
        raw = var_dict_tf[name_tf]
        data = raw
        if spec["squeeze"] is not None:
            data = np.squeeze(data, axis=spec["squeeze"])
        if spec["transpose"] is not None:
            data = np.transpose(data, spec["transpose"])
        data = torch.from_numpy(data).type(torch.float32).to("cpu")
        expected = var_dict_torch[name].size()
        if expected != data.size():
            # Explicit raise rather than ``assert`` so the check survives ``python -O``.
            raise AssertionError(
                "{}, {}, {} != {}".format(name, name_tf, expected, data.size()))
        logging.info("torch tensor: %s, %s, loading from tf tensor: %s, %s",
                     name, data.size(), name_tf, raw.shape)
        return data

    var_dict_torch_update = dict()
    for name in sorted(var_dict_torch):
        names = name.split('.')
        if names[0] != prefix_torch or len(names) < 2:
            continue
        if names[1] in layer_offsets:
            layeridx = int(names[2])
            # Rewrite only the module and index components; a whole-string
            # str.replace could also hit a matching ".<i>." further along.
            name_q = ".".join([names[0], "encoders", "layeridx"] + names[3:])
            if name_q not in map_dict:
                continue
            name_tf = map_dict[name_q]["name"].replace(
                "layeridx", str(layeridx + layer_offsets[names[1]]))
            var_dict_torch_update[name] = _load(name, name_q, name_tf)
        elif names[1] == "after_norm":
            var_dict_torch_update[name] = _load(name, name, map_dict[name]["name"])

    return var_dict_torch_update
| | | |