From e6fe6065770ac92803913c949ab0a53e613064ed Mon Sep 17 00:00:00 2001
From: Vignesh Skanda <agvskanda@gmail.com>
Date: 星期三, 16 十月 2024 13:45:22 +0800
Subject: [PATCH] Update register.py (#2145)

---
 funasr/register.py |   40 +++++++++++++++-------------------------
 1 files changed, 15 insertions(+), 25 deletions(-)

diff --git a/funasr/register.py b/funasr/register.py
index 384bc58..9a04ef2 100644
--- a/funasr/register.py
+++ b/funasr/register.py
@@ -3,9 +3,9 @@
 from dataclasses import dataclass
 import re
 
-
 @dataclass
 class RegisterTables:
+    """Registry system for classes."""
     model_classes = {}
     frontend_classes = {}
     specaug_classes = {}
@@ -21,17 +21,14 @@
     dataset_classes = {}
     index_ds_classes = {}
 
-    def print(self, key=None):
+    def print(self, key: str = None) -> None:
+        """Print registered classes."""
         print("\ntables: \n")
         fields = vars(self)
+        headers = ["register name", "class name", "class location"]
         for classes_key, classes_dict in fields.items():
-
-            flag = True
-            if key is not None:
-                flag = key in classes_key
-            if classes_key.endswith("_meta") and flag:
+            if classes_key.endswith("_meta") and (key is None or key in classes_key):
                 print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
-                headers = ["register name", "class name", "class location"]
                 metas = []
                 for register_key, meta in classes_dict.items():
                     metas.append(meta)
@@ -47,45 +44,38 @@
                     )
         print("\n")
 
-    def register(self, register_tables_key: str, key=None):
+    def register(self, register_tables_key: str, key: str = None) -> callable:
+        """Decorator to register a class."""
         def decorator(target_class):
-
             if not hasattr(self, register_tables_key):
                 setattr(self, register_tables_key, {})
-                logging.info("new registry table has been added: {}".format(register_tables_key))
+                logging.info(f"New registry table added: {register_tables_key}")
 
             registry = getattr(self, register_tables_key)
             registry_key = key if key is not None else target_class.__name__
 
-            # assert not registry_key in registry, "(key: {} / class: {}) has been registered already锛宨n {}".format(
-            #     registry_key, target_class, register_tables_key)
+            if registry_key in registry:
+                raise ValueError(f"Key {registry_key} already exists in {register_tables_key}")
 
             registry[registry_key] = target_class
 
-            # meta锛� headers = ["class name", "register name", "class location"]
             register_tables_key_meta = register_tables_key + "_meta"
             if not hasattr(self, register_tables_key_meta):
                 setattr(self, register_tables_key_meta, {})
             registry_meta = getattr(self, register_tables_key_meta)
-            # doc = target_class.__doc__
+
             class_file = inspect.getfile(target_class)
             class_line = inspect.getsourcelines(target_class)[1]
             pattern = r"^.+/funasr/"
             class_file = re.sub(pattern, "funasr/", class_file)
-            # meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
-            meata_data = [
-                f"{registry_key}",
-                f"{target_class.__name__}",
+            meta_data = [
+                registry_key,
+                target_class.__name__,
                 f"{class_file}:{class_line}",
             ]
-            registry_meta[registry_key] = meata_data
-            # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
+            registry_meta[registry_key] = meta_data
             return target_class
 
         return decorator
 
-
 tables = RegisterTables()
-
-
-import funasr

--
Gitblit v1.9.1