From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 27 三月 2024 16:05:29 +0800
Subject: [PATCH] train update
---
funasr/register.py | 18 +++++++++++-------
1 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/funasr/register.py b/funasr/register.py
index 15363c0..45e2a85 100644
--- a/funasr/register.py
+++ b/funasr/register.py
@@ -1,7 +1,7 @@
import logging
import inspect
from dataclasses import dataclass
-
+import re
@dataclass
class RegisterTables:
@@ -15,6 +15,7 @@
predictor_classes = {}
stride_conv_classes = {}
tokenizer_classes = {}
+ dataloader_classes = {}
batch_sampler_classes = {}
dataset_classes = {}
index_ds_classes = {}
@@ -29,7 +30,7 @@
flag = key in classes_key
if classes_key.endswith("_meta") and flag:
print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
- headers = ["class name", "register name", "class location"]
+ headers = ["register name", "class name", "class location"]
metas = []
for register_key, meta in classes_dict.items():
metas.append(meta)
@@ -51,10 +52,9 @@
registry = getattr(self, register_tables_key)
registry_key = key if key is not None else target_class.__name__
- registry_key = registry_key.lower()
- # import pdb; pdb.set_trace()
- assert not registry_key in registry, "(key: {} / class: {}) has been registered already锛宨n {}".format(
- registry_key, target_class, register_tables_key)
+
+ # assert not registry_key in registry, "(key: {} / class: {}) has been registered already锛宨n {}".format(
+ # registry_key, target_class, register_tables_key)
registry[registry_key] = target_class
@@ -63,9 +63,13 @@
if not hasattr(self, register_tables_key_meta):
setattr(self, register_tables_key_meta, {})
registry_meta = getattr(self, register_tables_key_meta)
+ # doc = target_class.__doc__
class_file = inspect.getfile(target_class)
class_line = inspect.getsourcelines(target_class)[1]
- meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
+ pattern = r'^.+/funasr/'
+ class_file = re.sub(pattern, 'funasr/', class_file)
+ # meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
+ meata_data = [f"{registry_key}", f"{target_class.__name__}", f"{class_file}:{class_line}"]
registry_meta[registry_key] = meata_data
# print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
return target_class
--
Gitblit v1.9.1