huangmingming
2023-01-30 adcee8828ef5d78b575043954deb662a35e318f7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
import logging
from typing import Dict
import sys
import torch
 
 
def load_ckpt(checkpoint_path: str) -> Dict[str, np.ndarray]:
    """Read all non-optimizer variables from a TensorFlow checkpoint.

    Variable names are visited in sorted order. Any name containing
    ``"optimizer"`` is skipped. Every kept variable is echoed to stdout
    together with its shape.

    Args:
        checkpoint_path: Path (prefix) of the TensorFlow checkpoint to read.

    Returns:
        Mapping from variable name to the value returned by the checkpoint
        reader's ``get_tensor``.
    """
    # TensorFlow is imported inside the function, so it is only required
    # when a checkpoint is actually read.
    from tensorflow.python import pywrap_tensorflow

    ckpt_reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    shapes_by_name = ckpt_reader.get_variable_to_shape_map()

    weights = {}
    for key in sorted(shapes_by_name):
        if "optimizer" not in key:
            value = ckpt_reader.get_tensor(key)
            print("in ckpt: {}, {}".format(key, value.shape))
            weights[key] = value
    return weights
 
 
def convert_parameter_name_for_asv_resnet34(
        var_dict: Dict[str, np.ndarray],
        old_prefix: str = "EAND/speech_encoder",
        new_prefix: str = "encoder",
        train_steps: int = 0
) -> Dict[str, torch.Tensor]:
    """Map TensorFlow checkpoint variables onto PyTorch state-dict entries.

    Variables under ``old_prefix`` are renamed to dotted PyTorch names under
    ``new_prefix``. Names containing ``resnet1``/``resnet2`` are additionally
    moved from ``encoder`` to ``decoder``. Values are converted with
    ``torch.Tensor`` (default float dtype). Per module kind:

    * module name contains ``bn``: ``gamma``/``beta``/``moving_mean``/
      ``moving_variance`` become ``weight``/``bias``/``running_mean``/
      ``running_var``. A ``num_batches_tracked`` entry (0-dim int64 tensor
      holding ``train_steps``) is added for the module.
    * module name contains ``dense``: ``kernel`` becomes ``weight``. 2-D, 3-D
      and 4-D kernels are transposed with ``[1, 0]``, ``[2, 1, 0]`` and
      ``[3, 2, 0, 1]`` respectively.
    * module name contains ``conv``: ``kernel`` becomes ``weight`` and is
      permuted with ``[3, 2, 0, 1]``, i.e. TF's (H, W, in, out) layout to
      PyTorch's (out, in, H, W).

    Outside ``old_prefix``, only ``softmax/output/kernel`` is kept, as the
    transposed ``decoder.output_dense.weight``. Everything else is dropped.

    Args:
        var_dict: TF variable name -> array (e.g. the output of ``load_ckpt``).
        old_prefix: TF name scope to convert.
        new_prefix: PyTorch module prefix that replaces ``old_prefix``.
        train_steps: Value stored in every BatchNorm ``num_batches_tracked``.

    Returns:
        Mapping from PyTorch parameter/buffer name to tensor.

    Raises:
        ValueError: If a variable under ``old_prefix`` belongs to a module
            that is neither batch-norm, dense nor conv.
    """
    bn_renames = {
        "gamma": "weight",
        "beta": "bias",
        "moving_mean": "running_mean",
        "moving_variance": "running_var",
    }
    # Axis permutations for dense kernels, keyed by rank. Rank 4 is the
    # dense0 layer.
    dense_perms = {2: [1, 0], 3: [2, 1, 0], 4: [3, 2, 0, 1]}

    new_dict = dict()
    model_size = 0
    for name, tensor in var_dict.items():
        if not name.startswith(old_prefix):
            if name == "softmax/output/kernel":
                new_name = "decoder.output_dense.weight"
                new_dict[new_name] = torch.Tensor(np.transpose(tensor, [1, 0]))
                print("{} -> {}".format(name, new_name))
                model_size += new_dict[new_name].numel()
            continue

        # Replace only the leading scope, then switch to PyTorch's dotted names.
        new_name = (new_prefix + name[len(old_prefix):]).replace("/", ".")
        if "resnet1" in new_name or "resnet2" in new_name:
            new_name = new_name.replace("encoder", "decoder")
        module_name, para_name = new_name.rsplit(".", 1)

        # process for batch normalization
        if "bn" in module_name:
            new_name = module_name + "." + bn_renames.get(para_name, para_name)
            new_dict[new_name] = torch.Tensor(tensor)
            # PyTorch BatchNorm keeps num_batches_tracked as a 0-dim int64
            # buffer. torch.Tensor(int) would instead allocate an
            # uninitialised float vector of length `train_steps`.
            new_dict[module_name + ".num_batches_tracked"] = torch.tensor(
                train_steps, dtype=torch.long)

        # process for dense layers
        elif "dense" in module_name:
            if para_name == "kernel":
                new_name = module_name + ".weight"
                perm = dense_perms.get(np.ndim(tensor))
                if perm is not None:
                    tensor = np.transpose(tensor, perm)
            new_dict[new_name] = torch.Tensor(tensor)

        # process for conv layers
        elif "conv" in module_name:
            if para_name == "kernel":
                new_name = module_name + ".weight"
                tensor = np.transpose(tensor, [3, 2, 0, 1])
            new_dict[new_name] = torch.Tensor(tensor)

        else:
            raise ValueError(
                "Don't know how to convert {!r} (module {!r} is not bn, dense "
                "or conv)".format(name, module_name))

        print("{} -> {}".format(name, new_name))
        model_size += new_dict[new_name].numel()
    print("Model size: {}".format(model_size))
    return new_dict
 
 
if __name__ == '__main__':
    # Usage: <script> <tf_checkpoint_path> <output_path> [train_steps]
    # train_steps defaults to 300000 and is written into every BatchNorm
    # num_batches_tracked buffer.
    if len(sys.argv) < 3:
        sys.exit(
            "usage: {} <tf_checkpoint_path> <output_path> [train_steps]".format(
                sys.argv[0]))
    checkpoint_path = sys.argv[1]
    pkl_path = sys.argv[2]
    train_steps = int(sys.argv[3]) if len(sys.argv) > 3 else 300000
    tf_dict = load_ckpt(checkpoint_path)
    torch_dict = convert_parameter_name_for_asv_resnet34(
        tf_dict,
        train_steps=train_steps,
    )
    torch.save(
        torch_dict, pkl_path
    )