1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import logging
import inspect
from dataclasses import dataclass
import re
 
 
@dataclass
class RegisterTables:
    """Registry system for classes."""
 
    model_classes = {}
    frontend_classes = {}
    specaug_classes = {}
    normalize_classes = {}
    encoder_classes = {}
    decoder_classes = {}
    joint_network_classes = {}
    predictor_classes = {}
    stride_conv_classes = {}
    tokenizer_classes = {}
    dataloader_classes = {}
    batch_sampler_classes = {}
    dataset_classes = {}
    index_ds_classes = {}
 
    def print(self, key: str = None) -> None:
        """Print registered classes."""
        print("\ntables: \n")
        fields = vars(self)
        headers = ["register name", "class name", "class location"]
        for classes_key, classes_dict in fields.items():
            if classes_key.endswith("_meta") and (key is None or key in classes_key):
                print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
                metas = []
                for register_key, meta in classes_dict.items():
                    metas.append(meta)
                metas.sort(key=lambda x: x[0])
                data = [headers] + metas
                col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
 
                for row in data:
                    print(
                        "| "
                        + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                        + " |"
                    )
        print("\n")
 
    def register(self, register_tables_key: str, key: str = None) -> callable:
        """Decorator to register a class."""
 
        def decorator(target_class):
            if not hasattr(self, register_tables_key):
                setattr(self, register_tables_key, {})
                logging.debug(f"New registry table added: {register_tables_key}")
 
            registry = getattr(self, register_tables_key)
            registry_key = key if key is not None else target_class.__name__
 
            if registry_key in registry:
                logging.debug(
                    f"Key {registry_key} already exists in {register_tables_key}, re-register"
                )
 
            registry[registry_key] = target_class
 
            register_tables_key_meta = register_tables_key + "_meta"
            if not hasattr(self, register_tables_key_meta):
                setattr(self, register_tables_key_meta, {})
            registry_meta = getattr(self, register_tables_key_meta)
 
            class_file = inspect.getfile(target_class)
            class_line = inspect.getsourcelines(target_class)[1]
            pattern = r"^.+/funasr/"
            class_file = re.sub(pattern, "funasr/", class_file)
            meta_data = [
                registry_key,
                target_class.__name__,
                f"{class_file}:{class_line}",
            ]
            registry_meta[registry_key] = meta_data
            return target_class
 
        return decorator
 
 
tables = RegisterTables()