speech_asr
2023-03-10 6052e1e7c23c43d495cb5689d6d17450d2d8eb8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import collections
import sys
 
from torch import multiprocessing
 
 
def get_size(obj, seen=None):
    """Recursively finds size of objects
 
    Taken from https://github.com/bosswissam/pysize
 
    """
 
    size = sys.getsizeof(obj)
    if seen is None:
        seen = set()
 
    obj_id = id(obj)
    if obj_id in seen:
        return 0
 
    # Important mark as seen *before* entering recursion to gracefully handle
    # self-referential objects
    seen.add(obj_id)
 
    if isinstance(obj, dict):
        size += sum([get_size(v, seen) for v in obj.values()])
        size += sum([get_size(k, seen) for k in obj.keys()])
    elif hasattr(obj, "__dict__"):
        size += get_size(obj.__dict__, seen)
    elif isinstance(obj, (list, set, tuple)):
        size += sum([get_size(i, seen) for i in obj])
 
    return size
 
 
class SizedDict(collections.abc.MutableMapping):
    def __init__(self, shared: bool = False, data: dict = None):
        if data is None:
            data = {}
 
        if shared:
            # NOTE(kamo): Don't set manager as a field because Manager, which includes
            # weakref object, causes following error with method="spawn",
            # "TypeError: can't pickle weakref objects"
            self.cache = multiprocessing.Manager().dict(**data)
        else:
            self.manager = None
            self.cache = dict(**data)
        self.size = 0
 
    def __setitem__(self, key, value):
        if key in self.cache:
            self.size -= get_size(self.cache[key])
        else:
            self.size += sys.getsizeof(key)
        self.size += get_size(value)
        self.cache[key] = value
 
    def __getitem__(self, key):
        return self.cache[key]
 
    def __delitem__(self, key):
        self.size -= get_size(self.cache[key])
        self.size -= sys.getsizeof(key)
        del self.cache[key]
 
    def __iter__(self):
        return iter(self.cache)
 
    def __contains__(self, key):
        return key in self.cache
 
    def __len__(self):
        return len(self.cache)