"""Convert a TensorFlow checkpoint into a PyTorch state dict.

Usage: ``python <this script> <tf_checkpoint_path> <output_path>``
"""
import logging
import sys
from typing import Dict

import numpy as np
import torch


def load_ckpt(checkpoint_path: str) -> Dict[str, np.ndarray]:
    """Read every non-optimizer variable from a TensorFlow checkpoint.

    Args:
        checkpoint_path: Path handed to ``NewCheckpointReader``.

    Returns:
        Mapping from variable name to its value. Variables whose name
        contains ``"optimizer"`` are skipped.
    """
    # Imported lazily so that the conversion helper below can be used
    # without TensorFlow being installed.
    from tensorflow.python import pywrap_tensorflow

    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()

    var_dict = dict()
    for var_name in sorted(var_to_shape_map):
        if "optimizer" in var_name:
            continue
        tensor = reader.get_tensor(var_name)
        print("in ckpt: {}, {}".format(var_name, tensor.shape))
        var_dict[var_name] = tensor

    return var_dict


def convert_parameter_name_for_asv_resnet34(
        var_dict: Dict[str, np.ndarray],
        old_prefix: str = "EAND/speech_encoder",
        new_prefix: str = "encoder",
        train_steps: int = 0
) -> Dict[str, torch.Tensor]:
    """Rename and re-layout checkpoint variables into a PyTorch state dict.

    Variables whose name starts with ``old_prefix`` are renamed
    (``old_prefix`` -> ``new_prefix``, ``/`` -> ``.``; names containing
    ``resnet1``/``resnet2`` additionally get ``encoder`` replaced by
    ``decoder``) and converted according to the module they belong to:

    * ``bn``: gamma/beta/moving_mean/moving_variance are renamed to
      weight/bias/running_mean/running_var and a ``num_batches_tracked``
      entry is added for the module.
    * ``dense``: ``kernel`` is renamed to ``weight`` and its axes are
      reordered (2-D: ``[1, 0]``, 3-D: ``[2, 1, 0]``, 4-D: ``[3, 2, 0, 1]``).
    * ``conv``: ``kernel`` is renamed to ``weight`` and its axes are
      reordered with ``[3, 2, 0, 1]``.

    Of the variables outside ``old_prefix`` only ``softmax/output/kernel``
    is kept (transposed, stored as ``decoder.output_dense.weight``); all
    others are dropped.

    Args:
        var_dict: Variable name -> array, e.g. the result of ``load_ckpt``.
        old_prefix: Name prefix selecting the variables to convert.
        new_prefix: Replacement for ``old_prefix`` in the output names.
        train_steps: Value stored in every ``num_batches_tracked`` entry.

    Returns:
        Mapping from PyTorch parameter name to tensor.

    Raises:
        KeyError: If a variable under ``old_prefix`` belongs to none of the
            handled module kinds (bn / dense / conv).
    """
    new_dict = dict()
    model_size = 0
    for name, tensor in var_dict.items():
        if not name.startswith(old_prefix):
            if name == "softmax/output/kernel":
                new_name = "decoder.output_dense.weight"
                tensor = np.transpose(tensor, [1, 0])
                new_dict[new_name] = torch.Tensor(tensor)
            continue

        new_name = name.replace(old_prefix, new_prefix)
        new_name = new_name.replace("/", ".")
        if "resnet1" in new_name or "resnet2" in new_name:
            new_name = new_name.replace("encoder", "decoder")
        module_name, para_name = new_name.rsplit(".", 1)

        # process for batch normalization
        if "bn" in module_name:
            new_name = new_name.replace("gamma", "weight")
            new_name = new_name.replace("beta", "bias")
            new_name = new_name.replace("moving_mean", "running_mean")
            new_name = new_name.replace("moving_variance", "running_var")
            new_dict[new_name] = torch.Tensor(tensor)
            # torch.Tensor(int) would allocate an *uninitialised* tensor of
            # that many elements; torch.tensor() stores the value itself as
            # a scalar, and long is the dtype BatchNorm uses for this buffer.
            new_dict[module_name + ".num_batches_tracked"] = torch.tensor(
                train_steps, dtype=torch.long
            )

        # process for dense layers
        elif "dense" in module_name:
            new_name = new_name.replace("kernel", "weight")
            if para_name == "kernel":
                if len(tensor.shape) == 2:
                    tensor = np.transpose(tensor, [1, 0])
                elif len(tensor.shape) == 3:
                    tensor = np.transpose(tensor, [2, 1, 0])
                # for dense0
                elif len(tensor.shape) == 4:
                    tensor = np.transpose(tensor, [3, 2, 0, 1])
            new_dict[new_name] = torch.Tensor(tensor)

        # process for conv layers
        elif "conv" in module_name:
            new_name = new_name.replace("kernel", "weight")
            if para_name == "kernel":
                tensor = np.transpose(tensor, [3, 2, 0, 1])
            new_dict[new_name] = torch.Tensor(tensor)

        else:
            # Fail with a message that names the offending variable instead
            # of the bare KeyError the size accounting below would raise.
            raise KeyError(
                "no conversion rule for variable {!r} (mapped to {!r})".format(
                    name, new_name
                )
            )

        print("{} -> {}".format(name, new_name))
        model_size += new_dict[new_name].numel()

    print("Model size: {}".format(model_size))
    return new_dict


if __name__ == '__main__':
    checkpoint_path = sys.argv[1]
    pkl_path = sys.argv[2]
    tf_dict = load_ckpt(checkpoint_path)
    torch_dict = convert_parameter_name_for_asv_resnet34(
        tf_dict,
        train_steps=300000,
    )
    torch.save(
        torch_dict,
        pkl_path
    )