From 7dadb793e639d2b7f918f2f915e928a63e016ea5 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 23 十一月 2023 16:04:37 +0800
Subject: [PATCH] Dev gzf funasr2 (#1111)
---
funasr/datasets/small_datasets/preprocessor.py | 6
funasr/datasets/large_datasets/build_dataloader.py | 2
funasr/datasets/data_sampler.py | 54 +++++---
funasr/tasks/sa_asr.py | 2
funasr/export/models/__init__.py | 6
funasr/tasks/punctuation.py | 2
funasr/tokenizer/sentencepiece_tokenizer.py | 2
funasr/bin/train.py | 2
funasr/datasets/preprocessor.py | 6
funasr/tokenizer/build_tokenizer.py | 10
funasr/datasets/dataloader_fn.py | 53 ++++++++
funasr/datasets/dataset_jsonl.py | 89 ++++++++++++++
funasr/tokenizer/word_tokenizer.py | 2
funasr/tokenizer/korean_cleaner.py | 0
funasr/tokenizer/char_tokenizer.py | 2
funasr/tasks/data2vec.py | 2
funasr/tasks/lm.py | 2
funasr/tasks/whisper.py | 2
funasr/tasks/asr.py | 2
funasr/tokenizer/cleaner.py | 0
funasr/tokenizer/__init__.py | 0
funasr/models/frontend/wav_frontend.py | 9 +
funasr/bin/asr_infer.py | 4
/dev/null | 69 -----------
funasr/bin/tp_infer.py | 2
funasr/tokenizer/abs_tokenizer.py | 0
funasr/tokenizer/phoneme_tokenizer.py | 4
funasr/tokenizer/token_id_converter.py | 0
funasr/bin/build_trainer.py | 2
funasr/bin/tokenize_text.py | 6
30 files changed, 212 insertions(+), 130 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index c1d08df..a1cede1 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -34,8 +34,8 @@
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.build_utils.build_asr_model import frontend_choices
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index bda83ec..c03bdf3 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -18,7 +18,7 @@
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope
from funasr.modules.lora.utils import mark_only_lora_as_trainable
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py
index 6ec83a8..674c1b9 100755
--- a/funasr/bin/tokenize_text.py
+++ b/funasr/bin/tokenize_text.py
@@ -9,9 +9,9 @@
from funasr.utils.cli_utils import get_commandline_args
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.cleaner import TextCleaner
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
diff --git a/funasr/bin/tp_infer.py b/funasr/bin/tp_infer.py
index ede579c..cfe534f 100644
--- a/funasr/bin/tp_infer.py
+++ b/funasr/bin/tp_infer.py
@@ -11,7 +11,7 @@
import torch
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index f5d10c4..6aebf8a 100755
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -17,7 +17,7 @@
from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
diff --git a/funasr/datasets/data_sampler.py b/funasr/datasets/data_sampler.py
index 2875d8d..6b3407c 100644
--- a/funasr/datasets/data_sampler.py
+++ b/funasr/datasets/data_sampler.py
@@ -1,29 +1,42 @@
import torch
+import numpy as np
+
class BatchSampler(torch.utils.data.BatchSampler):
- def __init__(self, dataset=None, args=None, drop_last=True, ):
+ def __init__(self, dataset, batch_size_type: str="example", batch_size: int=14, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
self.drop_last = drop_last
self.pre_idx = -1
self.dataset = dataset
- self.batch_size_type = args.batch_size_type
- self.batch_size = args.batch_size
- self.sort_size = args.sort_size
- self.max_length_token = args.max_length_token
self.total_samples = len(dataset)
+ # self.batch_size_type = args.batch_size_type
+ # self.batch_size = args.batch_size
+ # self.sort_size = args.sort_size
+ # self.max_length_token = args.max_length_token
+ self.batch_size_type = batch_size_type
+ self.batch_size = batch_size
+ self.sort_size = sort_size
+ self.max_length_token = kwargs.get("max_length_token", 5000)
+ self.shuffle_idx = np.arange(self.total_samples)
+ self.shuffle = shuffle
def __len__(self):
return self.total_samples
-
def __iter__(self):
+        # reshuffle the sample order once per epoch when shuffling is enabled
+
+ if self.shuffle:
+ np.random.shuffle(self.shuffle_idx)
+
batch = []
max_token = 0
num_sample = 0
-
+
iter_num = (self.total_samples-1) // self.sort_size + 1
+        # iter_num: number of sort-size windows needed to cover the dataset
for iter in range(self.pre_idx + 1, iter_num):
datalen_with_index = []
for i in range(self.sort_size):
@@ -31,30 +44,31 @@
if idx >= self.total_samples:
continue
- if self.batch_size_type == "example":
- sample_len_cur = 1
- else:
- idx_map = self.dataset.shuffle_idx[idx]
- # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
- sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
- self.dataset.indexed_dataset[idx_map]["target_len"]
+ idx_map = self.shuffle_idx[idx]
+ # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
+ sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
+ self.dataset.indexed_dataset[idx_map]["target_len"]
datalen_with_index.append([idx, sample_len_cur])
datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
for item in datalen_with_index_sort:
- idx, sample_len_cur = item
- if sample_len_cur > self.max_length_token:
+ idx, sample_len_cur_raw = item
+ if sample_len_cur_raw > self.max_length_token:
continue
- max_token_cur = max(max_token, sample_len_cur)
- max_token_padding = (1 + num_sample) * max_token_cur
+
+ max_token_cur = max(max_token, sample_len_cur_raw)
+ max_token_padding = 1 + num_sample
+ if self.batch_size_type == 'token':
+ max_token_padding *= max_token_cur
if max_token_padding <= self.batch_size:
batch.append(idx)
max_token = max_token_cur
num_sample += 1
else:
yield batch
- max_token = sample_len_cur
- num_sample = 1
batch = [idx]
+ max_token = sample_len_cur_raw
+ num_sample = 1
+
\ No newline at end of file
diff --git a/funasr/datasets/dataloader_fn.py b/funasr/datasets/dataloader_fn.py
new file mode 100644
index 0000000..8e8e423
--- /dev/null
+++ b/funasr/datasets/dataloader_fn.py
@@ -0,0 +1,53 @@
+
+import torch
+from funasr.datasets.dataset_jsonl import AudioDataset
+from funasr.datasets.data_sampler import BatchSampler
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.token_id_converter import TokenIDConverter
+collate_fn = None
+# collate_fn = collate_fn,
+
+jsonl = "/Users/zhifu/funasr_github/test_local/all_task_debug_len.jsonl"
+
+frontend = WavFrontend()
+token_type = 'char'
+bpemodel = None
+delimiter = None
+space_symbol = "<space>"
+non_linguistic_symbols = None
+g2p_type = None
+
+tokenizer = build_tokenizer(
+ token_type=token_type,
+ bpemodel=bpemodel,
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ g2p_type=g2p_type,
+)
+token_list = ""
+unk_symbol = "<unk>"
+
+token_id_converter = TokenIDConverter(
+ token_list=token_list,
+ unk_symbol=unk_symbol,
+)
+
+dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer)
+batch_sampler = BatchSampler(dataset)
+dataloader_tr = torch.utils.data.DataLoader(dataset,
+ collate_fn=dataset.collator,
+ batch_sampler=batch_sampler,
+ shuffle=False,
+ num_workers=0,
+ pin_memory=True)
+
+print(len(dataset))
+for i in range(3):
+ print(i)
+ for data in dataloader_tr:
+ print(len(data), data)
+# data_iter = iter(dataloader_tr)
+# data = next(data_iter)
+pass
diff --git a/funasr/datasets/dataset_jsonl.py b/funasr/datasets/dataset_jsonl.py
index 283fbd9..72d9a99 100644
--- a/funasr/datasets/dataset_jsonl.py
+++ b/funasr/datasets/dataset_jsonl.py
@@ -1,12 +1,41 @@
import torch
import json
import torch.distributed as dist
+import numpy as np
+import kaldiio
+import librosa
-class AudioDatasetJsonl(torch.utils.data.Dataset):
+
+
+def load_audio(audio_path: str, fs: int=16000):
+ audio = None
+ if audio_path.startswith("oss:"):
+ pass
+ elif audio_path.startswith("odps:"):
+ pass
+ else:
+ if ".ark:" in audio_path:
+ audio = kaldiio.load_mat(audio_path)
+ else:
+ audio, fs = librosa.load(audio_path, sr=fs)
+ return audio
+
+def extract_features(data, data_type: str="sound", frontend=None):
+    if data_type == "sound":
+        feat, feats_lens = frontend(data, len(data))
+        feat = feat[0, :, :]
+    else:
+        feat, feats_lens = torch.from_numpy(data).to(torch.float32), torch.tensor([data.shape[0]]).to(torch.int32)
+    return feat, feats_lens
- def __init__(self, path, data_parallel_rank=0, data_parallel_size=1):
+
+
+class IndexedDatasetJsonl(torch.utils.data.Dataset):
+
+ def __init__(self, path):
super().__init__()
- data_parallel_size = dist.get_world_size()
+ # data_parallel_size = dist.get_world_size()
+ data_parallel_size = 1
contents = []
with open(path, encoding='utf-8') as fin:
for line in fin:
@@ -31,7 +60,8 @@
self.contents = []
total_num = len(contents)
num_per_rank = total_num // data_parallel_size
- rank = dist.get_rank()
+ # rank = dist.get_rank()
+ rank = 0
# import ipdb; ipdb.set_trace()
self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
@@ -41,3 +71,54 @@
def __getitem__(self, index):
return self.contents[index]
+
+
+class AudioDataset(torch.utils.data.Dataset):
+ def __init__(self, path, frontend=None, tokenizer=None):
+ super().__init__()
+ self.indexed_dataset = IndexedDatasetJsonl(path)
+        self.frontend = frontend.forward if frontend is not None else None
+ self.fs = 16000 if frontend is None else frontend.fs
+ self.data_type = "sound"
+ self.tokenizer = tokenizer
+ self.int_pad_value = -1
+ self.float_pad_value = 0.0
+
+
+
+
+ def __len__(self):
+ return len(self.indexed_dataset)
+
+ def __getitem__(self, index):
+ item = self.indexed_dataset[index]
+ source = item["source"]
+ data_src = load_audio(source, fs=self.fs)
+ speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
+ target = item["target"]
+ text = self.tokenizer.encode(target)
+ text_lengths = len(text)
+ text, text_lengths = torch.tensor(text, dtype=torch.int64), torch.tensor([text_lengths], dtype=torch.int32)
+ return {"speech": speech,
+ "speech_lengths": speech_lengths,
+ "text": text,
+ "text_lengths": text_lengths,
+ }
+
+
+ def collator(self, samples: list=None):
+
+ outputs = {}
+ for sample in samples:
+ for key in sample.keys():
+ if key not in outputs:
+ outputs[key] = []
+ outputs[key].append(sample[key])
+
+ for key, data_list in outputs.items():
+            if not data_list[0].dtype.is_floating_point:
+ pad_value = self.int_pad_value
+ else:
+ pad_value = self.float_pad_value
+ outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+        return outputs
\ No newline at end of file
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 6c2da2a..134b20a 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -9,7 +9,7 @@
from funasr.datasets.large_datasets.dataset import Dataset
from funasr.iterators.abs_iter_factory import AbsIterFactory
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
def read_symbol_table(symbol_table_file):
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 26e062c..b303418 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -13,9 +13,9 @@
import librosa
import jieba
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.cleaner import TextCleaner
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.tokenizer.token_id_converter import TokenIDConverter
class AbsPreprocessor(ABC):
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
index f0d3c9a..01a8c6f 100644
--- a/funasr/datasets/small_datasets/preprocessor.py
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -11,9 +11,9 @@
import scipy.signal
import librosa
-from funasr.text.build_tokenizer import build_tokenizer
-from funasr.text.cleaner import TextCleaner
-from funasr.text.token_id_converter import TokenIDConverter
+from funasr.tokenizer.build_tokenizer import build_tokenizer
+from funasr.tokenizer.cleaner import TextCleaner
+from funasr.tokenizer.token_id_converter import TokenIDConverter
class AbsPreprocessor(ABC):
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index 94447dc..b7b0889 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -1,7 +1,7 @@
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
-from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
+# from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
@@ -30,8 +30,8 @@
return [encoder, decoder]
elif isinstance(model, Paraformer):
return Paraformer_export(model, **export_config)
- elif isinstance(model, Conformer_export):
- return Conformer_export(model, **export_config)
+ # elif isinstance(model, Conformer_export):
+ # return Conformer_export(model, **export_config)
elif isinstance(model, E2EVadModel):
return E2EVadModel_export(model, **export_config)
elif isinstance(model, PunctuationModel):
diff --git a/funasr/export/models/e2e_asr_conformer.py b/funasr/export/models/e2e_asr_conformer.py
deleted file mode 100644
index 45feda5..0000000
--- a/funasr/export/models/e2e_asr_conformer.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import os
-import logging
-import torch
-import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
-from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
-
-class Conformer(nn.Module):
- """
- export conformer into onnx format
- """
-
- def __init__(
- self,
- model,
- max_seq_len=512,
- feats_dim=560,
- model_name='model',
- **kwargs,
- ):
- super().__init__()
- onnx = False
- if "onnx" in kwargs:
- onnx = kwargs["onnx"]
- if isinstance(model.encoder, ConformerEncoder):
- self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
- elif isinstance(model.decoder, TransformerDecoder):
- self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
-
- self.feats_dim = feats_dim
- self.model_name = model_name
-
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
- def _export_model(self, model, verbose, path):
- dummy_input = model.get_dummy_inputs()
- model_script = model
- model_path = os.path.join(path, f'{model.model_name}.onnx')
- if not os.path.exists(model_path):
- torch.onnx.export(
- model_script,
- dummy_input,
- model_path,
- verbose=verbose,
- opset_version=14,
- input_names=model.get_input_names(),
- output_names=model.get_output_names(),
- dynamic_axes=model.get_dynamic_axes()
- )
-
- def _export_encoder_onnx(self, verbose, path):
- model_encoder = self.encoder
- self._export_model(model_encoder, verbose, path)
-
- def _export_decoder_onnx(self, verbose, path):
- model_decoder = self.decoder
- self._export_model(model_decoder, verbose, path)
-
- def _export_onnx(self, verbose, path):
- self._export_encoder_onnx(verbose, path)
- self._export_decoder_onnx(verbose, path)
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index ca5aed6..f92f322 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -145,9 +145,12 @@
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
- feats_pad = pad_sequence(feats,
- batch_first=True,
- padding_value=0.0)
+ if batch_size == 1:
+ feats_pad = feats[0][None, :, :]
+ else:
+ feats_pad = pad_sequence(feats,
+ batch_first=True,
+ padding_value=0.0)
return feats_pad, feats_lens
def forward_fbank(
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 59d78e9..ce316f7 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -76,7 +76,7 @@
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
diff --git a/funasr/tasks/data2vec.py b/funasr/tasks/data2vec.py
index b11d7de..80368f1 100644
--- a/funasr/tasks/data2vec.py
+++ b/funasr/tasks/data2vec.py
@@ -25,7 +25,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index c0259a8..d5445b2 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -17,7 +17,7 @@
from funasr.models.seq_rnn_lm import SequentialRNNLM
from funasr.models.transformer_lm import TransformerLM
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index de5c897..dd5fe57 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -16,7 +16,7 @@
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
index e7ee5a3..6bf918f 100644
--- a/funasr/tasks/sa_asr.py
+++ b/funasr/tasks/sa_asr.py
@@ -71,7 +71,7 @@
from funasr.models.base_model import FunASRModel
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
diff --git a/funasr/tasks/whisper.py b/funasr/tasks/whisper.py
index 7eef01e..e26227c 100644
--- a/funasr/tasks/whisper.py
+++ b/funasr/tasks/whisper.py
@@ -76,7 +76,7 @@
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
diff --git a/funasr/text/__init__.py b/funasr/tokenizer/__init__.py
similarity index 100%
rename from funasr/text/__init__.py
rename to funasr/tokenizer/__init__.py
diff --git a/funasr/text/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
similarity index 100%
rename from funasr/text/abs_tokenizer.py
rename to funasr/tokenizer/abs_tokenizer.py
diff --git a/funasr/text/build_tokenizer.py b/funasr/tokenizer/build_tokenizer.py
similarity index 85%
rename from funasr/text/build_tokenizer.py
rename to funasr/tokenizer/build_tokenizer.py
index c60a335..9d1cdc3 100644
--- a/funasr/text/build_tokenizer.py
+++ b/funasr/tokenizer/build_tokenizer.py
@@ -3,11 +3,11 @@
from typing import Union
-from funasr.text.abs_tokenizer import AbsTokenizer
-from funasr.text.char_tokenizer import CharTokenizer
-from funasr.text.phoneme_tokenizer import PhonemeTokenizer
-from funasr.text.sentencepiece_tokenizer import SentencepiecesTokenizer
-from funasr.text.word_tokenizer import WordTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.char_tokenizer import CharTokenizer
+from funasr.tokenizer.phoneme_tokenizer import PhonemeTokenizer
+from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
+from funasr.tokenizer.word_tokenizer import WordTokenizer
def build_tokenizer(
diff --git a/funasr/text/char_tokenizer.py b/funasr/tokenizer/char_tokenizer.py
similarity index 97%
rename from funasr/text/char_tokenizer.py
rename to funasr/tokenizer/char_tokenizer.py
index 8d1daf4..6c9a5a5 100644
--- a/funasr/text/char_tokenizer.py
+++ b/funasr/tokenizer/char_tokenizer.py
@@ -5,7 +5,7 @@
import warnings
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
class CharTokenizer(AbsTokenizer):
diff --git a/funasr/text/cleaner.py b/funasr/tokenizer/cleaner.py
similarity index 100%
rename from funasr/text/cleaner.py
rename to funasr/tokenizer/cleaner.py
diff --git a/funasr/text/korean_cleaner.py b/funasr/tokenizer/korean_cleaner.py
similarity index 100%
rename from funasr/text/korean_cleaner.py
rename to funasr/tokenizer/korean_cleaner.py
diff --git a/funasr/text/phoneme_tokenizer.py b/funasr/tokenizer/phoneme_tokenizer.py
similarity index 98%
rename from funasr/text/phoneme_tokenizer.py
rename to funasr/tokenizer/phoneme_tokenizer.py
index ad3d81c..0117c6a 100644
--- a/funasr/text/phoneme_tokenizer.py
+++ b/funasr/tokenizer/phoneme_tokenizer.py
@@ -10,7 +10,7 @@
# import g2p_en
import jamo
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
g2p_choices = [
@@ -107,7 +107,7 @@
List[str]: List of phoneme + prosody symbols.
Examples:
- >>> from funasr.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody
+ >>> from funasr.tokenizer.phoneme_tokenizer import pyopenjtalk_g2p_prosody
>>> pyopenjtalk_g2p_prosody("銇撱倱銇仭銇��")
['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']
diff --git a/funasr/text/sentencepiece_tokenizer.py b/funasr/tokenizer/sentencepiece_tokenizer.py
similarity index 95%
rename from funasr/text/sentencepiece_tokenizer.py
rename to funasr/tokenizer/sentencepiece_tokenizer.py
index e393cee..9a65920 100644
--- a/funasr/text/sentencepiece_tokenizer.py
+++ b/funasr/tokenizer/sentencepiece_tokenizer.py
@@ -5,7 +5,7 @@
import sentencepiece as spm
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
class SentencepiecesTokenizer(AbsTokenizer):
diff --git a/funasr/text/token_id_converter.py b/funasr/tokenizer/token_id_converter.py
similarity index 100%
rename from funasr/text/token_id_converter.py
rename to funasr/tokenizer/token_id_converter.py
diff --git a/funasr/text/word_tokenizer.py b/funasr/tokenizer/word_tokenizer.py
similarity index 96%
rename from funasr/text/word_tokenizer.py
rename to funasr/tokenizer/word_tokenizer.py
index f4d33d5..cbd0673 100644
--- a/funasr/text/word_tokenizer.py
+++ b/funasr/tokenizer/word_tokenizer.py
@@ -5,7 +5,7 @@
import warnings
-from funasr.text.abs_tokenizer import AbsTokenizer
+from funasr.tokenizer.abs_tokenizer import AbsTokenizer
class WordTokenizer(AbsTokenizer):
--
Gitblit v1.9.1