游雁
2024-04-29 2779602177ae5374547c7a7e17de0b11a166326d
funasr/register.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass
import re
@dataclass
class RegisterTables:
    model_classes = {}
@@ -15,6 +16,7 @@
    predictor_classes = {}
    stride_conv_classes = {}
    tokenizer_classes = {}
    dataloader_classes = {}
    batch_sampler_classes = {}
    dataset_classes = {}
    index_ds_classes = {}
@@ -23,7 +25,7 @@
        print("\ntables: \n")
        fields = vars(self)
        for classes_key, classes_dict in fields.items():
            flag = True
            if key is not None:
                flag = key in classes_key
@@ -36,27 +38,30 @@
                metas.sort(key=lambda x: x[0])
                data = [headers] + metas
                col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
                for row in data:
                    print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |")
        print("\n")
                for row in data:
                    print(
                        "| "
                        + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                        + " |"
                    )
        print("\n")
    def register(self, register_tables_key: str, key=None):
        def decorator(target_class):
            if not hasattr(self, register_tables_key):
                setattr(self, register_tables_key, {})
                logging.info("new registry table has been added: {}".format(register_tables_key))
            registry = getattr(self, register_tables_key)
            registry_key = key if key is not None else target_class.__name__
            # assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(
            #     registry_key, target_class, register_tables_key)
            registry[registry_key] = target_class
            # meta, headers = ["class name", "register name", "class location"]
            register_tables_key_meta = register_tables_key + "_meta"
            if not hasattr(self, register_tables_key_meta):
@@ -65,14 +70,18 @@
            # doc = target_class.__doc__
            class_file = inspect.getfile(target_class)
            class_line = inspect.getsourcelines(target_class)[1]
            pattern = r'^.+/funasr/'
            class_file = re.sub(pattern, 'funasr/', class_file)
            pattern = r"^.+/funasr/"
            class_file = re.sub(pattern, "funasr/", class_file)
            # meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
            meata_data = [f"{registry_key}", f"{target_class.__name__}", f"{class_file}:{class_line}"]
            meata_data = [
                f"{registry_key}",
                f"{target_class.__name__}",
                f"{class_file}:{class_line}",
            ]
            registry_meta[registry_key] = meata_data
            # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
            return target_class
        return decorator
@@ -80,4 +89,3 @@
import funasr