From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/register.py | 62 ++++++++++++++++++-------------
1 files changed, 36 insertions(+), 26 deletions(-)
diff --git a/funasr/register.py b/funasr/register.py
index 145a698..a8b1fd3 100644
--- a/funasr/register.py
+++ b/funasr/register.py
@@ -1,10 +1,13 @@
import logging
import inspect
from dataclasses import dataclass
+import re
@dataclass
class RegisterTables:
+ """Registry system for classes."""
+
model_classes = {}
frontend_classes = {}
specaug_classes = {}
@@ -15,63 +18,70 @@
predictor_classes = {}
stride_conv_classes = {}
tokenizer_classes = {}
+ dataloader_classes = {}
batch_sampler_classes = {}
dataset_classes = {}
index_ds_classes = {}
- def print(self,):
+ def print(self, key: str = None) -> None:
+ """Print registered classes."""
print("\ntables: \n")
fields = vars(self)
+ headers = ["register name", "class name", "class location"]
for classes_key, classes_dict in fields.items():
- print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
-
- if classes_key.endswith("_meta"):
- headers = ["class name", "register name", "class location"]
+ if classes_key.endswith("_meta") and (key is None or key in classes_key):
+ print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
metas = []
for register_key, meta in classes_dict.items():
metas.append(meta)
metas.sort(key=lambda x: x[0])
data = [headers] + metas
col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
-
+
for row in data:
- print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |")
+ print(
+ "| "
+ + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
+ + " |"
+ )
print("\n")
+ def register(self, register_tables_key: str, key: str = None) -> callable:
+ """Decorator to register a class."""
- def register(self, register_tables_key: str, key=None):
def decorator(target_class):
-
if not hasattr(self, register_tables_key):
setattr(self, register_tables_key, {})
- logging.info("new registry table has been added: {}".format(register_tables_key))
-
+ logging.debug(f"New registry table added: {register_tables_key}")
+
registry = getattr(self, register_tables_key)
registry_key = key if key is not None else target_class.__name__
- registry_key = registry_key.lower()
- # import pdb; pdb.set_trace()
- assert not registry_key in registry, "(key: {} / class: {}) has been registered already锛宨n {}".format(
- registry_key, target_class, register_tables_key)
-
+
+ if registry_key in registry:
+ logging.debug(
+ f"Key {registry_key} already exists in {register_tables_key}, re-register"
+ )
+
registry[registry_key] = target_class
-
- # meta锛� headers = ["class name", "register name", "class location"]
+
register_tables_key_meta = register_tables_key + "_meta"
if not hasattr(self, register_tables_key_meta):
setattr(self, register_tables_key_meta, {})
registry_meta = getattr(self, register_tables_key_meta)
+
class_file = inspect.getfile(target_class)
class_line = inspect.getsourcelines(target_class)[1]
- meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
- registry_meta[registry_key] = meata_data
- # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
+ pattern = r"^.+/funasr/"
+ class_file = re.sub(pattern, "funasr/", class_file)
+ meta_data = [
+ registry_key,
+ target_class.__name__,
+ f"{class_file}:{class_line}",
+ ]
+ registry_meta[registry_key] = meta_data
return target_class
-
+
return decorator
tables = RegisterTables()
-
-
-import funasr
-
--
Gitblit v1.9.1