speech_asr
2023-03-10 e1dc7e311e3b469eeb75bc0d59e1e24e0d5bb7e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np
np.set_printoptions(threshold=np.inf)
import logging
 
def load_ckpt(checkpoint_path):
    import tensorflow as tf
    if tf.__version__.startswith('2'):
        import tensorflow.compat.v1 as tf
        tf.disable_v2_behavior()
        reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
    else:
        from tensorflow.python import pywrap_tensorflow
        reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
 
    var_dict = dict()
    for var_name in sorted(var_to_shape_map):
        if "Adam" in var_name:
            continue
        tensor = reader.get_tensor(var_name)
        # print("in ckpt: {}, {}".format(var_name, tensor.shape))
        # print(tensor)
        var_dict[var_name] = tensor
 
    return var_dict
 
 
 
def load_tf_pb_dict(pb_model):
    import tensorflow as tf
    if tf.__version__.startswith('2'):
        import tensorflow.compat.v1 as tf
        tf.disable_v2_behavior()
        # import tensorflow_addons as tfa
        # from tensorflow_addons.seq2seq.python.ops import beam_search_ops
    else:
        from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
    from tensorflow.python.ops import lookup_ops as lookup
    from tensorflow.python.framework import tensor_util
    from tensorflow.python.platform import gfile
    
    sess = tf.Session()
    with gfile.FastGFile(pb_model, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
    
    var_dict = dict()
    for node in sess.graph_def.node:
        if node.op == 'Const':
            value = tensor_util.MakeNdarray(node.attr['value'].tensor)
            if len(value.shape) >= 1:
                var_dict[node.name] = value
    return var_dict
 
def load_tf_dict(pb_model):
    if "model.ckpt-" in pb_model:
        var_dict = load_ckpt(pb_model)
    else:
        var_dict = load_tf_pb_dict(pb_model)
    return var_dict