From 2a0b2c795b161a0bd56e026c53eb605fea9e142c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 15 一月 2024 11:51:26 +0800
Subject: [PATCH] funasr1.0

---
 funasr/bin/inference.py                    |    8 ++--
 funasr/models/paraformer/model.py          |   10 ++--
 funasr/models/seaco_paraformer/model.py    |    2 
 funasr/models/fsmn_vad_streaming/model.py  |    2 
 funasr/bin/train.py                        |   10 ++--
 funasr/register.py                         |   13 ++++--
 funasr/models/ct_transformer/model.py      |    2 
 funasr/models/monotonic_aligner/model.py   |    8 ++--
 funasr/models/transformer/model.py         |   10 ++--
 funasr/datasets/audio_datasets/datasets.py |    9 +++-
 10 files changed, 40 insertions(+), 34 deletions(-)

diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index 3aab31a..48957dd 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -175,7 +175,7 @@
         # build tokenizer
         tokenizer = kwargs.get("tokenizer", None)
         if tokenizer is not None:
-            tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
+            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
             tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
             kwargs["tokenizer"] = tokenizer
             kwargs["token_list"] = tokenizer.token_list
@@ -186,13 +186,13 @@
         # build frontend
         frontend = kwargs.get("frontend", None)
         if frontend is not None:
-            frontend_class = tables.frontend_classes.get(frontend.lower())
+            frontend_class = tables.frontend_classes.get(frontend)
             frontend = frontend_class(**kwargs["frontend_conf"])
             kwargs["frontend"] = frontend
             kwargs["input_size"] = frontend.output_size()
         
         # build model
-        model_class = tables.model_classes.get(kwargs["model"].lower())
+        model_class = tables.model_classes.get(kwargs["model"])
         model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
         model.eval()
         model.to(device)
@@ -443,7 +443,7 @@
         # build frontend
         frontend = kwargs.get("frontend", None)
         if frontend is not None:
-            frontend_class = tables.frontend_classes.get(frontend.lower())
+            frontend_class = tables.frontend_classes.get(frontend)
             frontend = frontend_class(**kwargs["frontend_conf"])
 
         self.frontend = frontend
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index af3e8af..878eb24 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -64,14 +64,14 @@
 
 	tokenizer = kwargs.get("tokenizer", None)
 	if tokenizer is not None:
-		tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
+		tokenizer_class = tables.tokenizer_classes.get(tokenizer)
 		tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
 		kwargs["tokenizer"] = tokenizer
 	
 	# build frontend if frontend is none None
 	frontend = kwargs.get("frontend", None)
 	if frontend is not None:
-		frontend_class = tables.frontend_classes.get(frontend.lower())
+		frontend_class = tables.frontend_classes.get(frontend)
 		frontend = frontend_class(**kwargs["frontend_conf"])
 		kwargs["frontend"] = frontend
 		kwargs["input_size"] = frontend.output_size()
@@ -79,7 +79,7 @@
 	# import pdb;
 	# pdb.set_trace()
 	# build model
-	model_class = tables.model_classes.get(kwargs["model"].lower())
+	model_class = tables.model_classes.get(kwargs["model"])
 	model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
 
 
@@ -141,12 +141,12 @@
 	# import pdb;
 	# pdb.set_trace()
 	# dataset
-	dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower())
+	dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
 	dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
 
 	# dataloader
 	batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-	batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler.lower())
+	batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
 	if batch_sampler is not None:
 		batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
 	dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index ff82856..7839ff9 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -13,6 +13,9 @@
 
 @tables.register("dataset_classes", "AudioDataset")
 class AudioDataset(torch.utils.data.Dataset):
+	"""
+	AudioDataset
+	"""
 	def __init__(self,
 	             path,
 	             index_ds: str = None,
@@ -22,16 +25,16 @@
 	             float_pad_value: float = 0.0,
 	              **kwargs):
 		super().__init__()
-		index_ds_class = tables.index_ds_classes.get(index_ds.lower())
+		index_ds_class = tables.index_ds_classes.get(index_ds)
 		self.index_ds = index_ds_class(path)
 		preprocessor_speech = kwargs.get("preprocessor_speech", None)
 		if preprocessor_speech:
-			preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
+			preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
 			preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
 		self.preprocessor_speech = preprocessor_speech
 		preprocessor_text = kwargs.get("preprocessor_text", None)
 		if preprocessor_text:
-			preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text.lower())
+			preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
 			preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
 		self.preprocessor_text = preprocessor_text
 		
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index 7187f45..5fb3ed4 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -46,7 +46,7 @@
         
         
         self.embed = nn.Embedding(vocab_size, embed_unit)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
 
         self.decoder = nn.Linear(att_unit, punc_size)
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 544fab8..c87558c 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -268,7 +268,7 @@
         super().__init__()
         self.vad_opts = VADXOptions(**kwargs)
 
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
         self.encoder = encoder
 
diff --git a/funasr/models/monotonic_aligner/model.py b/funasr/models/monotonic_aligner/model.py
index 1b43c2f..6309732 100644
--- a/funasr/models/monotonic_aligner/model.py
+++ b/funasr/models/monotonic_aligner/model.py
@@ -41,15 +41,15 @@
         super().__init__()
 
         if specaug is not None:
-            specaug_class = tables.specaug_classes.get(specaug.lower())
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get(normalize.lower())
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
-        predictor_class = tables.predictor_classes.get(predictor.lower())
+        predictor_class = tables.predictor_classes.get(predictor)
         predictor = predictor_class(**predictor_conf)
         self.specaug = specaug
         self.normalize = normalize
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index f60bead..2cd9c88 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -79,17 +79,17 @@
         super().__init__()
 
         if specaug is not None:
-            specaug_class = tables.specaug_classes.get(specaug.lower())
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get(normalize.lower())
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
 
         if decoder is not None:
-            decoder_class = tables.decoder_classes.get(decoder.lower())
+            decoder_class = tables.decoder_classes.get(decoder)
             decoder = decoder_class(
                 vocab_size=vocab_size,
                 encoder_output_size=encoder_output_size,
@@ -104,7 +104,7 @@
                 odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
             )
         if predictor is not None:
-            predictor_class = tables.predictor_classes.get(predictor.lower())
+            predictor_class = tables.predictor_classes.get(predictor)
             predictor = predictor_class(**predictor_conf)
         
         # note that eos is the same as sos (equivalent ID)
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 4f6c176..db5c7dd 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -90,7 +90,7 @@
         seaco_decoder = kwargs.get("seaco_decoder", None)
         if seaco_decoder is not None:
             seaco_decoder_conf = kwargs.get("seaco_decoder_conf")
-            seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower())
+            seaco_decoder_class = tables.decoder_classes.get(seaco_decoder)
             self.seaco_decoder = seaco_decoder_class(
                 vocab_size=self.vocab_size,
                 encoder_output_size=self.inner_dim,
diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py
index 2f6e15a..e2367a7 100644
--- a/funasr/models/transformer/model.py
+++ b/funasr/models/transformer/model.py
@@ -60,19 +60,19 @@
         super().__init__()
 
         if frontend is not None:
-            frontend_class = tables.frontend_classes.get_class(frontend.lower())
+            frontend_class = tables.frontend_classes.get_class(frontend)
             frontend = frontend_class(**frontend_conf)
         if specaug is not None:
-            specaug_class = tables.specaug_classes.get_class(specaug.lower())
+            specaug_class = tables.specaug_classes.get_class(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get_class(normalize.lower())
+            normalize_class = tables.normalize_classes.get_class(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get_class(encoder.lower())
+        encoder_class = tables.encoder_classes.get_class(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
         if decoder is not None:
-            decoder_class = tables.decoder_classes.get_class(decoder.lower())
+            decoder_class = tables.decoder_classes.get_class(decoder)
             decoder = decoder_class(
                 vocab_size=vocab_size,
                 encoder_output_size=encoder_output_size,
diff --git a/funasr/register.py b/funasr/register.py
index 15363c0..454105f 100644
--- a/funasr/register.py
+++ b/funasr/register.py
@@ -1,7 +1,7 @@
 import logging
 import inspect
 from dataclasses import dataclass
-
+import re
 
 @dataclass
 class RegisterTables:
@@ -29,7 +29,7 @@
                 flag = key in classes_key
             if classes_key.endswith("_meta") and flag:
                 print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
-                headers = ["class name", "register name", "class location"]
+                headers = ["class name", "class location"]
                 metas = []
                 for register_key, meta in classes_dict.items():
                     metas.append(meta)
@@ -51,8 +51,7 @@
             
             registry = getattr(self, register_tables_key)
             registry_key = key if key is not None else target_class.__name__
-            registry_key = registry_key.lower()
-            # import pdb; pdb.set_trace()
+            
             assert not registry_key in registry, "(key: {} / class: {}) has been registered already锛宨n {}".format(
                 registry_key, target_class, register_tables_key)
             
@@ -63,9 +62,13 @@
             if not hasattr(self, register_tables_key_meta):
                 setattr(self, register_tables_key_meta, {})
             registry_meta = getattr(self, register_tables_key_meta)
+            # doc = target_class.__doc__
             class_file = inspect.getfile(target_class)
             class_line = inspect.getsourcelines(target_class)[1]
-            meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
+            pattern = r'^.+/funasr/'
+            class_file = re.sub(pattern, 'funasr/', class_file)
+            meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
+            # meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
             registry_meta[registry_key] = meata_data
             # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
             return target_class

--
Gitblit v1.9.1