游雁
2024-01-15 2a0b2c795b161a0bd56e026c53eb605fea9e142c
funasr/register.py
@@ -1,7 +1,7 @@
import logging
import inspect
from dataclasses import dataclass
import re
@dataclass
class RegisterTables:
@@ -29,7 +29,7 @@
                flag = key in classes_key
            if classes_key.endswith("_meta") and flag:
                print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
                headers = ["class name", "register name", "class location"]
                headers = ["class name", "class location"]
                metas = []
                for register_key, meta in classes_dict.items():
                    metas.append(meta)
@@ -51,8 +51,7 @@
            
            registry = getattr(self, register_tables_key)
            registry_key = key if key is not None else target_class.__name__
            registry_key = registry_key.lower()
            # import pdb; pdb.set_trace()
            assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(
                registry_key, target_class, register_tables_key)
            
@@ -63,9 +62,13 @@
            if not hasattr(self, register_tables_key_meta):
                setattr(self, register_tables_key_meta, {})
            registry_meta = getattr(self, register_tables_key_meta)
            # doc = target_class.__doc__
            class_file = inspect.getfile(target_class)
            class_line = inspect.getsourcelines(target_class)[1]
            meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
            pattern = r'^.+/funasr/'
            class_file = re.sub(pattern, 'funasr/', class_file)
            meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
            # meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
            registry_meta[registry_key] = meata_data
            # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
            return target_class