From 7dadb793e639d2b7f918f2f915e928a63e016ea5 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 23 Nov 2023 16:04:37 +0800
Subject: [PATCH] Dev gzf funasr2 (#1111)

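Rename funasr.text to funasr.tokenizer and update all imports accordingly;
add a jsonl-indexed AudioDataset plus a length-sorting BatchSampler with
example- or token-based batch sizes; add a single-sample fast path to
WavFrontend padding; comment out the unfinished Conformer ONNX export and
remove funasr/export/models/e2e_asr_conformer.py.

A minimal usage sketch of the new dataset/sampler API, mirroring
funasr/datasets/dataloader_fn.py (the manifest path, tokenizer, and batch
size below are illustrative placeholders, not fixed defaults):

    import torch
    from funasr.datasets.dataset_jsonl import AudioDataset
    from funasr.datasets.data_sampler import BatchSampler
    from funasr.models.frontend.wav_frontend import WavFrontend

    # tokenizer: any object built via funasr.tokenizer.build_tokenizer
    dataset = AudioDataset("train.jsonl", frontend=WavFrontend(), tokenizer=tokenizer)
    sampler = BatchSampler(dataset, batch_size_type="token", batch_size=6000)
    loader = torch.utils.data.DataLoader(dataset,
                                         collate_fn=dataset.collator,
                                         batch_sampler=sampler,
                                         num_workers=0,
                                         pin_memory=True)
    for batch in loader:
        # dict with padded speech/speech_lengths/text/text_lengths tensors
        ...
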
---
 funasr/datasets/small_datasets/preprocessor.py     |    6 
 funasr/datasets/large_datasets/build_dataloader.py |    2 
 funasr/datasets/data_sampler.py                    |   54 +++++---
 funasr/tasks/sa_asr.py                             |    2 
 funasr/export/models/__init__.py                   |    6 
 funasr/tasks/punctuation.py                        |    2 
 funasr/tokenizer/sentencepiece_tokenizer.py        |    2 
 funasr/bin/train.py                                |    2 
 funasr/datasets/preprocessor.py                    |    6 
 funasr/tokenizer/build_tokenizer.py                |   10 
 funasr/datasets/dataloader_fn.py                   |   53 ++++++++
 funasr/datasets/dataset_jsonl.py                   |   89 ++++++++++++++
 funasr/tokenizer/word_tokenizer.py                 |    2 
 funasr/tokenizer/korean_cleaner.py                 |    0 
 funasr/tokenizer/char_tokenizer.py                 |    2 
 funasr/tasks/data2vec.py                           |    2 
 funasr/tasks/lm.py                                 |    2 
 funasr/tasks/whisper.py                            |    2 
 funasr/tasks/asr.py                                |    2 
 funasr/tokenizer/cleaner.py                        |    0 
 funasr/tokenizer/__init__.py                       |    0 
 funasr/models/frontend/wav_frontend.py             |    9 +
 funasr/bin/asr_infer.py                            |    4 
 /dev/null                                          |   69 -----------
 funasr/bin/tp_infer.py                             |    2 
 funasr/tokenizer/abs_tokenizer.py                  |    0 
 funasr/tokenizer/phoneme_tokenizer.py              |    4 
 funasr/tokenizer/token_id_converter.py             |    0 
 funasr/bin/build_trainer.py                        |    2 
 funasr/bin/tokenize_text.py                        |    6 
 30 files changed, 212 insertions(+), 130 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index c1d08df..a1cede1 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -34,8 +34,8 @@
 from funasr.modules.scorers.ctc import CTCPrefixScorer
 from funasr.modules.scorers.length_bonus import LengthBonus
 from funasr.build_utils.build_asr_model import frontend_choices
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.token_id_converter import TokenIDConverter
 from funasr.torch_utils.device_funcs import to_device
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index bda83ec..c03bdf3 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -18,7 +18,7 @@
 from funasr.build_utils.build_scheduler import build_scheduler
 from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope
 from funasr.modules.lora.utils import mark_only_lora_as_trainable
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py
index 6ec83a8..674c1b9 100755
--- a/funasr/bin/tokenize_text.py
+++ b/funasr/bin/tokenize_text.py
@@ -9,9 +9,9 @@
 
 
 from funasr.utils.cli_utils import get_commandline_args
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.cleaner import TextCleaner
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.utils.types import str2bool
 from funasr.utils.types import str_or_none
 
diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py
index ede579c..cfe534f 100644
--- a/funasr/bin/tp_infer.py
+++ b/funasr/bin/tp_infer.py
@@ -11,7 +11,7 @@
 import torch
 from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.token_id_converter import TokenIDConverter
 from funasr.torch_utils.device_funcs import to_device
 
 
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index f5d10c4..6aebf8a 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -17,7 +17,7 @@
 from funasr.build_utils.build_optimizer import build_optimizer
 from funasr.build_utils.build_scheduler import build_scheduler
 from funasr.build_utils.build_trainer import build_trainer
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
diff --git a/funasr/datasets/data_sampler.py b/funasr/datasets/data_sampler.py
index 2875d8d..6b3407c 100644
--- a/funasr/datasets/data_sampler.py
+++ b/funasr/datasets/data_sampler.py
@@ -1,29 +1,42 @@
 import torch
 
+import numpy as np
+
 class BatchSampler(torch.utils.data.BatchSampler):
 	
-	def __init__(self, dataset=None, args=None, drop_last=True, ):
+	def __init__(self, dataset, batch_size_type: str="example", batch_size: int=14, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
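+		"""Length-sorting batch sampler.
+
+		Samples are read in windows of ``sort_size`` and sorted by
+		``source_len + target_len`` before batching. ``batch_size`` counts
+		examples when ``batch_size_type == "example"`` and padded tokens
+		(examples * longest sample) when it is ``"token"``; samples longer
+		than ``max_length_token`` are skipped.
+		"""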
 		
 		self.drop_last = drop_last
 		self.pre_idx = -1
 		self.dataset = dataset
-		self.batch_size_type = args.batch_size_type
-		self.batch_size = args.batch_size
-		self.sort_size = args.sort_size
-		self.max_length_token = args.max_length_token
 		self.total_samples = len(dataset)
+		self.batch_size_type = batch_size_type
+		self.batch_size = batch_size
+		self.sort_size = sort_size
+		self.max_length_token = kwargs.get("max_length_token", 5000)
+		self.shuffle_idx = np.arange(self.total_samples)
+		self.shuffle = shuffle
 
 	
 	def __len__(self):
 		return self.total_samples
 
-	
 	def __iter__(self):
+		print("in sampler")
+		
+		if self.shuffle:
+			np.random.shuffle(self.shuffle_idx)
+			
 		batch = []
 		max_token = 0
 		num_sample = 0
-		
+
 		iter_num = (self.total_samples-1) // self.sort_size + 1
+		print("iter_num: ", iter_num)
 		for iter in range(self.pre_idx + 1, iter_num):
 			datalen_with_index = []
 			for i in range(self.sort_size):
@@ -31,30 +44,31 @@
 				if idx >= self.total_samples:
 					continue
 
-				if self.batch_size_type == "example":
-					sample_len_cur = 1
-				else:
-					idx_map = self.dataset.shuffle_idx[idx]
-					# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-					sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
-					                 self.dataset.indexed_dataset[idx_map]["target_len"]
+				idx_map = self.shuffle_idx[idx]
+				# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+				sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
+				                 self.dataset.indexed_dataset[idx_map]["target_len"]
 
 				datalen_with_index.append([idx, sample_len_cur])
 			
 			datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
 			for item in datalen_with_index_sort:
-				idx, sample_len_cur = item
-				if sample_len_cur > self.max_length_token:
+				idx, sample_len_cur_raw = item
+				if sample_len_cur_raw > self.max_length_token:
 					continue
-				max_token_cur = max(max_token, sample_len_cur)
-				max_token_padding = (1 + num_sample) * max_token_cur
+
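+				# batch cost = example count, or padded token count
+				# (examples * longest sample so far) in "token" mode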
+				max_token_cur = max(max_token, sample_len_cur_raw)
+				max_token_padding = 1 + num_sample
+				if self.batch_size_type == 'token':
+					max_token_padding *= max_token_cur
 				if max_token_padding <= self.batch_size:
 					batch.append(idx)
 					max_token = max_token_cur
 					num_sample += 1
 				else:
 					yield batch
-					max_token = sample_len_cur
-					num_sample = 1
 					batch = [idx]
+					max_token = sample_len_cur_raw
+					num_sample = 1
+		
+		# emit the final partial batch unless drop_last is set
+		if batch and not self.drop_last:
+			yield batch
 		
\ No newline at end of file
diff --git a/funasr/datasets/dataloader_fn.py b/funasr/datasets/dataloader_fn.py
new file mode 100644
index 0000000..8e8e423
--- /dev/null
+++ b/funasr/datasets/dataloader_fn.py
@@ -0,0 +1,53 @@
+
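+# Debug/demo script: build an AudioDataset over a jsonl manifest and drive it
+# through a DataLoader with the length-sorting BatchSampler. The jsonl path
+# below is developer-local and purely illustrative.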
+import torch
+from funasr.datasets.dataset_jsonl import AudioDataset
+from funasr.datasets.data_sampler import BatchSampler
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.token_id_converter import TokenIDConverter
+
+jsonl = "/Users/zhifu/funasr_github/test_local/all_task_debug_len.jsonl"
+
+frontend = WavFrontend()
+token_type = 'char'
+bpemodel = None
+delimiter = None
+space_symbol = "<space>"
+non_linguistic_symbols = None
+g2p_type = None
+
+tokenizer = build_tokenizer(
+    token_type=token_type,
+    bpemodel=bpemodel,
+    delimiter=delimiter,
+    space_symbol=space_symbol,
+    non_linguistic_symbols=non_linguistic_symbols,
+    g2p_type=g2p_type,
+)
+token_list = ""
+unk_symbol = "<unk>"
+
+token_id_converter = TokenIDConverter(
+    token_list=token_list,
+    unk_symbol=unk_symbol,
+)
+
+dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer)
+batch_sampler = BatchSampler(dataset)
+dataloader_tr = torch.utils.data.DataLoader(dataset,
+                           collate_fn=dataset.collator,
+                           batch_sampler=batch_sampler,
+                           shuffle=False,
+                           num_workers=0,
+                           pin_memory=True)
+
+print(len(dataset))
+for i in range(3):
+    print(i)
+    for data in dataloader_tr:
+        print(len(data), data)
diff --git a/funasr/datasets/dataset_jsonl.py b/funasr/datasets/dataset_jsonl.py
index 283fbd9..72d9a99 100644
--- a/funasr/datasets/dataset_jsonl.py
+++ b/funasr/datasets/dataset_jsonl.py
@@ -1,12 +1,41 @@
 import torch
 import json
 import torch.distributed as dist
+import numpy as np
+import kaldiio
+import librosa
 
-class AudioDatasetJsonl(torch.utils.data.Dataset):
+
+
+def load_audio(audio_path: str, fs: int=16000):
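+	# "oss:"/"odps:" remote-storage branches are unimplemented placeholders;
+	# kaldi ark entries are read via kaldiio, everything else via librosa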
+	audio = None
+	if audio_path.startswith("oss:"):
+		pass
+	elif audio_path.startswith("odps:"):
+		pass
+	else:
+		if ".ark:" in audio_path:
+			audio = kaldiio.load_mat(audio_path)
+		else:
+			audio, fs = librosa.load(audio_path, sr=fs)
+	return audio
+
+def extract_features(data, data_type: str="sound", frontend=None):
+	if data_type == "sound":
+		# frontend.forward expects a (batch, time) float tensor plus lengths
+		waveform = torch.from_numpy(data).to(torch.float32)[None, :]
+		feat, feats_lens = frontend(waveform, torch.tensor([waveform.shape[1]], dtype=torch.int32))
+		feat = feat[0, :, :]
+	else:
+		feat, feats_lens = torch.from_numpy(data).to(torch.float32), torch.tensor([data.shape[0]], dtype=torch.int32)
+	return feat, feats_lens
 	
-	def __init__(self, path, data_parallel_rank=0, data_parallel_size=1):
+	
+
+class IndexedDatasetJsonl(torch.utils.data.Dataset):
+	
+	def __init__(self, path):
 		super().__init__()
-		data_parallel_size = dist.get_world_size()
+		# data_parallel_size = dist.get_world_size()
+		data_parallel_size = 1
 		contents = []
 		with open(path, encoding='utf-8') as fin:
 			for line in fin:
@@ -31,7 +60,8 @@
 		self.contents = []
 		total_num = len(contents)
 		num_per_rank = total_num // data_parallel_size
-		rank = dist.get_rank()
+		# rank = dist.get_rank()
+		rank = 0
 		# import ipdb; ipdb.set_trace()
 		self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
 
@@ -41,3 +71,54 @@
 	
 	def __getitem__(self, index):
 		return self.contents[index]
+
+
+class AudioDataset(torch.utils.data.Dataset):
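+	"""Map-style dataset over a jsonl manifest: each item loads the source
+	audio, extracts frontend features, and tokenizes the target text."""
+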
+	def __init__(self, path, frontend=None, tokenizer=None):
+		super().__init__()
+		self.indexed_dataset = IndexedDatasetJsonl(path)
+		self.frontend = None if frontend is None else frontend.forward
+		self.fs = 16000 if frontend is None else frontend.fs
+		self.data_type = "sound"
+		self.tokenizer = tokenizer
+		self.int_pad_value = -1
+		self.float_pad_value = 0.0
+	
+	def __len__(self):
+		return len(self.indexed_dataset)
+	
+	def __getitem__(self, index):
+		item = self.indexed_dataset[index]
+		source = item["source"]
+		data_src = load_audio(source, fs=self.fs)
+		speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
+		target = item["target"]
+		text = self.tokenizer.encode(target)
+		text_lengths = len(text)
+		text, text_lengths = torch.tensor(text, dtype=torch.int64), torch.tensor([text_lengths], dtype=torch.int32)
+		return {"speech": speech,
+		        "speech_lengths": speech_lengths,
+		        "text": text,
+		        "text_lengths": text_lengths,
+		        }
+	
+	
+	def collator(self, samples: list=None):
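+		"""Collate samples into a dict of padded batch tensors; integer
+		tensors pad with ``int_pad_value`` (-1), floats with 0.0."""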
+		
+		outputs = {}
+		for sample in samples:
+			for key in sample.keys():
+				if key not in outputs:
+					outputs[key] = []
+				outputs[key].append(sample[key])
+		
+		for key, data_list in outputs.items():
+			# torch dtypes have no numpy-style .kind; test for integer tensors directly
+			if not torch.is_floating_point(data_list[0]):
+				pad_value = self.int_pad_value
+			else:
+				pad_value = self.float_pad_value
+			outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+		return outputs
\ No newline at end of file
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 6c2da2a..134b20a 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -9,7 +9,7 @@
 
 from funasr.datasets.large_datasets.dataset import Dataset
 from funasr.iterators.abs_iter_factory import AbsIterFactory
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
 
 
 def read_symbol_table(symbol_table_file):
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 26e062c..b303418 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -13,9 +13,9 @@
 import librosa
 import jieba
 
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.cleaner import TextCleaner
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.tokenizer.token_id_converter import TokenIDConverter
 
 
 class AbsPreprocessor(ABC):
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
index f0d3c9a..01a8c6f 100644
--- a/funasr/datasets/small_datasets/preprocessor.py
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -11,9 +11,9 @@
 import scipy.signal
 import librosa
 
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.cleaner import TextCleaner
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.tokenizer.token_id_converter import TokenIDConverter
 
 
 class AbsPreprocessor(ABC):
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index 94447dc..b7b0889 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -1,7 +1,7 @@
 from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
-from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
+# from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
 
 from funasr.models.e2e_vad import E2EVadModel
 from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
@@ -30,8 +30,8 @@
         return [encoder, decoder]
     elif isinstance(model, Paraformer):
         return Paraformer_export(model, **export_config)
-    elif isinstance(model, Conformer_export):
-        return Conformer_export(model, **export_config)
+    # elif isinstance(model, Conformer_export):
+    #     return Conformer_export(model, **export_config)
     elif isinstance(model, E2EVadModel):
         return E2EVadModel_export(model, **export_config)
     elif isinstance(model, PunctuationModel):
diff --git a/funasr/export/models/e2e_asr_conformer.py b/funasr/export/models/e2e_asr_conformer.py
deleted file mode 100644
index 45feda5..0000000
--- a/funasr/export/models/e2e_asr_conformer.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import os
-import logging
-import torch
-import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
-from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
-
-class Conformer(nn.Module):
-    """
-    export conformer into onnx format
-    """
-
-    def __init__(
-            self,
-            model,
-            max_seq_len=512,
-            feats_dim=560,
-            model_name='model',
-            **kwargs,
-    ):
-        super().__init__()
-        onnx = False
-        if "onnx" in kwargs:
-            onnx = kwargs["onnx"]
-        if isinstance(model.encoder, ConformerEncoder):
-            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
-        elif isinstance(model.decoder, TransformerDecoder):
-            self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
-        
-        self.feats_dim = feats_dim
-        self.model_name = model_name
-
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
-    def _export_model(self, model, verbose, path):
-        dummy_input = model.get_dummy_inputs()
-        model_script = model
-        model_path = os.path.join(path, f'{model.model_name}.onnx')
-        if not os.path.exists(model_path):
-            torch.onnx.export(
-                model_script,
-                dummy_input,
-                model_path,
-                verbose=verbose,
-                opset_version=14,
-                input_names=model.get_input_names(),
-                output_names=model.get_output_names(),
-                dynamic_axes=model.get_dynamic_axes()
-            )
-
-    def _export_encoder_onnx(self, verbose, path):
-        model_encoder = self.encoder
-        self._export_model(model_encoder, verbose, path)
-
-    def _export_decoder_onnx(self, verbose, path):
-        model_decoder = self.decoder
-        self._export_model(model_decoder, verbose, path)
-
-    def _export_onnx(self, verbose, path):
-        self._export_encoder_onnx(verbose, path)
-        self._export_decoder_onnx(verbose, path)
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index ca5aed6..f92f322 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -145,9 +145,12 @@
             feats_lens.append(feat_length)
 
         feats_lens = torch.as_tensor(feats_lens)
-        feats_pad = pad_sequence(feats,
-                                 batch_first=True,
-                                 padding_value=0.0)
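+        # fast path: a single sample only needs a batch dimension, no padding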
+        if batch_size == 1:
+            feats_pad = feats[0][None, :, :]
+        else:
+            feats_pad = pad_sequence(feats,
+                                     batch_first=True,
+                                     padding_value=0.0)
         return feats_pad, feats_lens
 
     def forward_fbank(
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 59d78e9..ce316f7 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -76,7 +76,7 @@
 from funasr.models.specaug.specaug import SpecAugLFR
 from funasr.modules.subsampling import Conv1dSubsampling
 from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
 from funasr.models.base_model import FunASRModel
 from funasr.train.class_choices import ClassChoices
diff --git a/funasr/tasks/data2vec.py b/funasr/tasks/data2vec.py
index b11d7de..80368f1 100644
--- a/funasr/tasks/data2vec.py
+++ b/funasr/tasks/data2vec.py
@@ -25,7 +25,7 @@
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.specaug.specaug import SpecAug
 from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index c0259a8..d5445b2 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -17,7 +17,7 @@
 from funasr.models.seq_rnn_lm import SequentialRNNLM
 from funasr.models.transformer_lm import TransformerLM
 from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index de5c897..dd5fe57 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -16,7 +16,7 @@
 from funasr.models.target_delay_transformer import TargetDelayTransformer
 from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
 from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
index e7ee5a3..6bf918f 100644
--- a/funasr/tasks/sa_asr.py
+++ b/funasr/tasks/sa_asr.py
@@ -71,7 +71,7 @@
 from funasr.models.base_model import FunASRModel
 from funasr.modules.subsampling import Conv1dSubsampling
 from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
diff --git a/funasr/tasks/whisper.py b/funasr/tasks/whisper.py
index 7eef01e..e26227c 100644
--- a/funasr/tasks/whisper.py
+++ b/funasr/tasks/whisper.py
@@ -76,7 +76,7 @@
 from funasr.models.specaug.specaug import SpecAugLFR
 from funasr.modules.subsampling import Conv1dSubsampling
 from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
 from funasr.models.base_model import FunASRModel
 from funasr.train.class_choices import ClassChoices
diff --git a/funasr/text/__init__.py b/funasr/tokenizer/__init__.py
similarity index 100%
rename from funasr/text/__init__.py
rename to funasr/tokenizer/__init__.py
diff --git a/funasr/text/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
similarity index 100%
rename from funasr/text/abs_tokenizer.py
rename to funasr/tokenizer/abs_tokenizer.py
diff --git a/funasr/text/build_tokenizer.py b/funasr/tokenizer/build_tokenizer.py
similarity index 85%
rename from funasr/text/build_tokenizer.py
rename to funasr/tokenizer/build_tokenizer.py
index c60a335..9d1cdc3 100644
--- a/funasr/text/build_tokenizer.py
+++ b/funasr/tokenizer/build_tokenizer.py
@@ -3,11 +3,11 @@
 from typing import Union
 
 
-from funasr.text.abs_tokenizer import AbsTokenizer
-from funasr.text.char_tokenizer import CharTokenizer
-from funasr.text.phoneme_tokenizer import PhonemeTokenizer
-from funasr.text.sentencepiece_tokenizer import SentencepiecesTokenizer
-from funasr.text.word_tokenizer import WordTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.char_tokenizer import CharTokenizer
+from funasr.tokenizer.phoneme_tokenizer import PhonemeTokenizer
+from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
+from funasr.tokenizer.word_tokenizer import WordTokenizer
 
 
 def build_tokenizer(
diff --git a/funasr/text/char_tokenizer.py b/funasr/tokenizer/char_tokenizer.py
similarity index 97%
rename from funasr/text/char_tokenizer.py
rename to funasr/tokenizer/char_tokenizer.py
index 8d1daf4..6c9a5a5 100644
--- a/funasr/text/char_tokenizer.py
+++ b/funasr/tokenizer/char_tokenizer.py
@@ -5,7 +5,7 @@
 import warnings
 
 
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
 
 
 class CharTokenizer(AbsTokenizer):
diff --git a/funasr/text/cleaner.py b/funasr/tokenizer/cleaner.py
similarity index 100%
rename from funasr/text/cleaner.py
rename to funasr/tokenizer/cleaner.py
diff --git a/funasr/text/korean_cleaner.py b/funasr/tokenizer/korean_cleaner.py
similarity index 100%
rename from funasr/text/korean_cleaner.py
rename to funasr/tokenizer/korean_cleaner.py
diff --git a/funasr/text/phoneme_tokenizer.py b/funasr/tokenizer/phoneme_tokenizer.py
similarity index 98%
rename from funasr/text/phoneme_tokenizer.py
rename to funasr/tokenizer/phoneme_tokenizer.py
index ad3d81c..0117c6a 100644
--- a/funasr/text/phoneme_tokenizer.py
+++ b/funasr/tokenizer/phoneme_tokenizer.py
@@ -10,7 +10,7 @@
 # import g2p_en
 import jamo
 
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
 
 
 g2p_choices = [
@@ -107,7 +107,7 @@
         List[str]: List of phoneme + prosody symbols.
 
     Examples:
-        >>> from funasr.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody
+        >>> from funasr.tokenizer.phoneme_tokenizer import pyopenjtalk_g2p_prosody
         >>> pyopenjtalk_g2p_prosody("こんにちは。")
         ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']
 
diff --git a/funasr/text/sentencepiece_tokenizer.py b/funasr/tokenizer/sentencepiece_tokenizer.py
similarity index 95%
rename from funasr/text/sentencepiece_tokenizer.py
rename to funasr/tokenizer/sentencepiece_tokenizer.py
index e393cee..9a65920 100644
--- a/funasr/text/sentencepiece_tokenizer.py
+++ b/funasr/tokenizer/sentencepiece_tokenizer.py
@@ -5,7 +5,7 @@
 
 import sentencepiece as spm
 
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
 
 
 class SentencepiecesTokenizer(AbsTokenizer):
diff --git a/funasr/text/token_id_converter.py b/funasr/tokenizer/token_id_converter.py
similarity index 100%
rename from funasr/text/token_id_converter.py
rename to funasr/tokenizer/token_id_converter.py
diff --git a/funasr/text/word_tokenizer.py b/funasr/tokenizer/word_tokenizer.py
similarity index 96%
rename from funasr/text/word_tokenizer.py
rename to funasr/tokenizer/word_tokenizer.py
index f4d33d5..cbd0673 100644
--- a/funasr/text/word_tokenizer.py
+++ b/funasr/tokenizer/word_tokenizer.py
@@ -5,7 +5,7 @@
 import warnings
 
 
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
 
 
 class WordTokenizer(AbsTokenizer):

--
Gitblit v1.9.1