zhifu gao
2023-11-23 7dadb793e639d2b7f918f2f915e928a63e016ea5
Dev gzf funasr2 (#1111)

* update funasr.text -> funasr.tokenizer fix bug export
18个文件已修改
1个文件已添加
10个文件已重命名
1个文件已删除
342 ■■■■■ 已修改文件
funasr/bin/asr_infer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/build_trainer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tokenize_text.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/tp_infer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/data_sampler.py 54 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/dataloader_fn.py 53 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/dataset_jsonl.py 89 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/build_dataloader.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/preprocessor.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/small_datasets/preprocessor.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/__init__.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/e2e_asr_conformer.py 69 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/wav_frontend.py 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/data2vec.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/lm.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/punctuation.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/sa_asr.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/whisper.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/abs_tokenizer.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/build_tokenizer.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/char_tokenizer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/cleaner.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/korean_cleaner.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/phoneme_tokenizer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/sentencepiece_tokenizer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/token_id_converter.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/word_tokenizer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_infer.py
@@ -34,8 +34,8 @@
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.build_utils.build_asr_model import frontend_choices
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
funasr/bin/build_trainer.py
@@ -18,7 +18,7 @@
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer as build_trainer_modelscope
from funasr.modules.lora.utils import mark_only_lora_as_trainable
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
funasr/bin/tokenize_text.py
@@ -9,9 +9,9 @@
from funasr.utils.cli_utils import get_commandline_args
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.cleaner import TextCleaner
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
funasr/bin/tp_infer.py
@@ -11,7 +11,7 @@
import torch
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
from funasr.tokenizer.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
funasr/bin/train.py
@@ -17,7 +17,7 @@
from funasr.build_utils.build_optimizer import build_optimizer
from funasr.build_utils.build_scheduler import build_scheduler
from funasr.build_utils.build_trainer import build_trainer
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
funasr/datasets/data_sampler.py
@@ -1,29 +1,42 @@
import torch
import numpy as np
class BatchSampler(torch.utils.data.BatchSampler):
    
    def __init__(self, dataset=None, args=None, drop_last=True, ):
    def __init__(self, dataset, batch_size_type: str="example", batch_size: int=14, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
        
        self.drop_last = drop_last
        self.pre_idx = -1
        self.dataset = dataset
        self.batch_size_type = args.batch_size_type
        self.batch_size = args.batch_size
        self.sort_size = args.sort_size
        self.max_length_token = args.max_length_token
        self.total_samples = len(dataset)
        # self.batch_size_type = args.batch_size_type
        # self.batch_size = args.batch_size
        # self.sort_size = args.sort_size
        # self.max_length_token = args.max_length_token
        self.batch_size_type = batch_size_type
        self.batch_size = batch_size
        self.sort_size = sort_size
        self.max_length_token = kwargs.get("max_length_token", 5000)
        self.shuffle_idx = np.arange(self.total_samples)
        self.shuffle = shuffle
    
    def __len__(self):
        return self.total_samples
    def __iter__(self):
        print("in sampler")
        if self.shuffle:
            np.random.shuffle(self.shuffle_idx)
        batch = []
        max_token = 0
        num_sample = 0
        iter_num = (self.total_samples-1) // self.sort_size + 1
        print("iter_num: ", iter_num)
        for iter in range(self.pre_idx + 1, iter_num):
            datalen_with_index = []
            for i in range(self.sort_size):
@@ -31,30 +44,31 @@
                if idx >= self.total_samples:
                    continue
                if self.batch_size_type == "example":
                    sample_len_cur = 1
                else:
                    idx_map = self.dataset.shuffle_idx[idx]
                    # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
                    sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
                                     self.dataset.indexed_dataset[idx_map]["target_len"]
                idx_map = self.shuffle_idx[idx]
                # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
                sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \
                                 self.dataset.indexed_dataset[idx_map]["target_len"]
                datalen_with_index.append([idx, sample_len_cur])
            
            datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
            for item in datalen_with_index_sort:
                idx, sample_len_cur = item
                if sample_len_cur > self.max_length_token:
                idx, sample_len_cur_raw = item
                if sample_len_cur_raw > self.max_length_token:
                    continue
                max_token_cur = max(max_token, sample_len_cur)
                max_token_padding = (1 + num_sample) * max_token_cur
                max_token_cur = max(max_token, sample_len_cur_raw)
                max_token_padding = 1 + num_sample
                if self.batch_size_type == 'token':
                    max_token_padding *= max_token_cur
                if max_token_padding <= self.batch_size:
                    batch.append(idx)
                    max_token = max_token_cur
                    num_sample += 1
                else:
                    yield batch
                    max_token = sample_len_cur
                    num_sample = 1
                    batch = [idx]
                    max_token = sample_len_cur_raw
                    num_sample = 1
        
funasr/datasets/dataloader_fn.py
New file
@@ -0,0 +1,53 @@
# Ad-hoc smoke-test script for the jsonl dataloader pipeline:
# builds tokenizer -> dataset -> batch sampler -> DataLoader and iterates it.
# NOTE(review): uses a hard-coded developer-local jsonl path; debugging only.
import torch
from funasr.datasets.dataset_jsonl import AudioDataset
from funasr.datasets.data_sampler import BatchSampler
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.token_id_converter import TokenIDConverter
collate_fn = None
# collate_fn = collate_fn,
# Debug input file (developer machine path).
jsonl = "/Users/zhifu/funasr_github/test_local/all_task_debug_len.jsonl"
frontend = WavFrontend()
# Character tokenizer, no BPE model / delimiter / g2p.
token_type = 'char'
bpemodel = None
delimiter = None
space_symbol = "<space>"
non_linguistic_symbols = None
g2p_type = None
tokenizer = build_tokenizer(
    token_type=token_type,
    bpemodel=bpemodel,
    delimiter=delimiter,
    space_symbol=space_symbol,
    non_linguistic_symbols=non_linguistic_symbols,
    g2p_type=g2p_type,
)
# NOTE(review): token_list is empty here -- TokenIDConverter presumably
# needs a real vocabulary; confirm before using beyond this smoke test.
token_list = ""
unk_symbol = "<unk>"
token_id_converter = TokenIDConverter(
    token_list=token_list,
    unk_symbol=unk_symbol,
)
dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer)
batch_sampler = BatchSampler(dataset)
# shuffle stays False: ordering is owned entirely by batch_sampler.
dataloader_tr = torch.utils.data.DataLoader(dataset,
                           collate_fn=dataset.collator,
                           batch_sampler=batch_sampler,
                           shuffle=False,
                           num_workers=0,
                           pin_memory=True)
print(len(dataset))
# Iterate a few "epochs" to exercise sampler re-iteration.
for i in range(3):
    print(i)
    for data in dataloader_tr:
        print(len(data), data)
# data_iter = iter(dataloader_tr)
# data = next(data_iter)
pass
funasr/datasets/dataset_jsonl.py
@@ -1,12 +1,41 @@
import torch
import json
import torch.distributed as dist
import numpy as np
import kaldiio
import librosa
class AudioDatasetJsonl(torch.utils.data.Dataset):
def load_audio(audio_path: str, fs: int = 16000):
    """Load a waveform or feature matrix identified by ``audio_path``.

    Supported sources:
      * Kaldi ark random-access entries (path contains ``.ark:``), read
        with ``kaldiio.load_mat`` and returned as stored (no resampling);
      * regular audio files, decoded with ``librosa.load`` and resampled
        to ``fs`` Hz.

    Args:
        audio_path: local file path, Kaldi ark specifier, or remote URI.
        fs: target sampling rate for librosa decoding.

    Returns:
        np.ndarray: decoded audio samples or feature matrix.

    Raises:
        NotImplementedError: for the unimplemented "oss:" / "odps:" remote
            schemes. (Previously these silently returned None, which only
            surfaced later as an opaque failure in feature extraction.)
    """
    if audio_path.startswith(("oss:", "odps:")):
        # Fail fast with a clear message instead of returning None.
        raise NotImplementedError(f"unsupported audio source: {audio_path}")
    if ".ark:" in audio_path:
        # Kaldi ark random-access entry, e.g. "feats.ark:42".
        return kaldiio.load_mat(audio_path)
    audio, _ = librosa.load(audio_path, sr=fs)
    return audio
def extract_features(data, date_type: str="sound", frontend=None):
    """Turn raw input into model-ready features.

    When ``date_type`` is "sound" the waveform is pushed through the
    frontend and the batch dimension is dropped; otherwise ``data`` is
    assumed to be a precomputed numpy feature matrix and is wrapped into
    tensors. (NOTE(review): ``date_type`` looks like a typo for
    ``data_type`` but is kept for interface compatibility.)

    Returns:
        (feat, feats_lens): feature tensor and its length tensor.
    """
    if date_type != "sound":
        # Precomputed features: just convert to the expected dtypes.
        feat = torch.from_numpy(data).to(torch.float32)
        feats_lens = torch.tensor([data.shape[0]]).to(torch.int32)
        return feat, feats_lens
    # Raw waveform path: run the frontend, then drop the batch dimension.
    feat, feats_lens = frontend(data, len(data))
    return feat[0, :, :], feats_lens
    
    def __init__(self, path, data_parallel_rank=0, data_parallel_size=1):
class IndexedDatasetJsonl(torch.utils.data.Dataset):
    def __init__(self, path):
        super().__init__()
        data_parallel_size = dist.get_world_size()
        # data_parallel_size = dist.get_world_size()
        data_parallel_size = 1
        contents = []
        with open(path, encoding='utf-8') as fin:
            for line in fin:
@@ -31,7 +60,8 @@
        self.contents = []
        total_num = len(contents)
        num_per_rank = total_num // data_parallel_size
        rank = dist.get_rank()
        # rank = dist.get_rank()
        rank = 0
        # import ipdb; ipdb.set_trace()
        self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
@@ -41,3 +71,54 @@
    
    def __getitem__(self, index):
        return self.contents[index]
class AudioDataset(torch.utils.data.Dataset):
    """Map-style dataset reading (audio, transcript) pairs from a jsonl index.

    Each item is a dict of torch tensors: frontend speech features plus
    tokenized text, ready for padding by :meth:`collator`.
    """

    def __init__(self, path, frontend=None, tokenizer=None):
        """
        Args:
            path: jsonl index file consumed by IndexedDatasetJsonl.
            frontend: feature extractor exposing ``forward`` and ``fs``;
                may be None (features are then expected precomputed).
            tokenizer: object whose ``encode`` maps text to token ids.
        """
        super().__init__()
        self.indexed_dataset = IndexedDatasetJsonl(path)
        # Guard frontend=None: the fs fallback below already anticipated
        # it, but the original unconditionally dereferenced .forward.
        self.frontend = None if frontend is None else frontend.forward
        self.fs = 16000 if frontend is None else frontend.fs
        self.data_type = "sound"
        self.tokenizer = tokenizer
        # Padding values used by collator: -1 for integer (token) tensors,
        # 0.0 for floating (feature) tensors.
        self.int_pad_value = -1
        self.float_pad_value = 0.0

    def __len__(self):
        """Number of (audio, text) entries in the index."""
        return len(self.indexed_dataset)

    def __getitem__(self, index):
        """Load and featurize one sample; returns a dict of tensors."""
        item = self.indexed_dataset[index]
        source = item["source"]
        data_src = load_audio(source, fs=self.fs)
        speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
        target = item["target"]
        # NOTE(review): assumes tokenizer.encode returns a sequence of
        # integer token ids -- confirm against the tokenizer in use.
        text = self.tokenizer.encode(target)
        text_lengths = len(text)
        text = torch.tensor(text, dtype=torch.int64)
        text_lengths = torch.tensor([text_lengths], dtype=torch.int32)
        return {"speech": speech,
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                }

    def collator(self, samples: list = None):
        """Pad a list of sample dicts into one batch dict.

        Integer tensors are padded with ``self.int_pad_value``, floating
        tensors with ``self.float_pad_value``.

        Returns:
            dict[str, torch.Tensor]: batch-first padded tensors.
        """
        outputs = {}
        for sample in samples:
            for key, value in sample.items():
                outputs.setdefault(key, []).append(value)
        for key, data_list in outputs.items():
            # torch dtypes have no numpy-style ``.kind`` attribute; the
            # original ``dtype.kind`` check raised AttributeError here.
            if data_list[0].dtype.is_floating_point:
                pad_value = self.float_pad_value
            else:
                pad_value = self.int_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(
                data_list, batch_first=True, padding_value=pad_value)
        # Return the padded batch; the original returned the raw input
        # ``samples`` list, silently discarding all the padding above.
        return outputs
funasr/datasets/large_datasets/build_dataloader.py
@@ -9,7 +9,7 @@
from funasr.datasets.large_datasets.dataset import Dataset
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.text.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
def read_symbol_table(symbol_table_file):
funasr/datasets/preprocessor.py
@@ -13,9 +13,9 @@
import librosa
import jieba
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
from funasr.text.token_id_converter import TokenIDConverter
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.cleaner import TextCleaner
from funasr.tokenizer.token_id_converter import TokenIDConverter
class AbsPreprocessor(ABC):
funasr/datasets/small_datasets/preprocessor.py
@@ -11,9 +11,9 @@
import scipy.signal
import librosa
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
from funasr.text.token_id_converter import TokenIDConverter
from funasr.tokenizer.build_tokenizer import build_tokenizer
from funasr.tokenizer.cleaner import TextCleaner
from funasr.tokenizer.token_id_converter import TokenIDConverter
class AbsPreprocessor(ABC):
funasr/export/models/__init__.py
@@ -1,7 +1,7 @@
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
# from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
@@ -30,8 +30,8 @@
        return [encoder, decoder]
    elif isinstance(model, Paraformer):
        return Paraformer_export(model, **export_config)
    elif isinstance(model, Conformer_export):
        return Conformer_export(model, **export_config)
    # elif isinstance(model, Conformer_export):
    #     return Conformer_export(model, **export_config)
    elif isinstance(model, E2EVadModel):
        return E2EVadModel_export(model, **export_config)
    elif isinstance(model, PunctuationModel):
funasr/export/models/e2e_asr_conformer.py
File was deleted
funasr/models/frontend/wav_frontend.py
@@ -145,9 +145,12 @@
            feats_lens.append(feat_length)
        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        if batch_size == 1:
            feats_pad = feats[0][None, :, :]
        else:
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
        return feats_pad, feats_lens
    def forward_fbank(
funasr/tasks/asr.py
@@ -76,7 +76,7 @@
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
funasr/tasks/data2vec.py
@@ -25,7 +25,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
funasr/tasks/lm.py
@@ -17,7 +17,7 @@
from funasr.models.seq_rnn_lm import SequentialRNNLM
from funasr.models.transformer_lm import TransformerLM
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
funasr/tasks/punctuation.py
@@ -16,7 +16,7 @@
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
funasr/tasks/sa_asr.py
@@ -71,7 +71,7 @@
from funasr.models.base_model import FunASRModel
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
funasr/tasks/whisper.py
@@ -76,7 +76,7 @@
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
funasr/tokenizer/__init__.py
funasr/tokenizer/abs_tokenizer.py
funasr/tokenizer/build_tokenizer.py
File was renamed from funasr/text/build_tokenizer.py
@@ -3,11 +3,11 @@
from typing import Union
from funasr.text.abs_tokenizer import AbsTokenizer
from funasr.text.char_tokenizer import CharTokenizer
from funasr.text.phoneme_tokenizer import PhonemeTokenizer
from funasr.text.sentencepiece_tokenizer import SentencepiecesTokenizer
from funasr.text.word_tokenizer import WordTokenizer
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.char_tokenizer import CharTokenizer
from funasr.tokenizer.phoneme_tokenizer import PhonemeTokenizer
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
from funasr.tokenizer.word_tokenizer import WordTokenizer
def build_tokenizer(
funasr/tokenizer/char_tokenizer.py
File was renamed from funasr/text/char_tokenizer.py
@@ -5,7 +5,7 @@
import warnings
from funasr.text.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
class CharTokenizer(AbsTokenizer):
funasr/tokenizer/cleaner.py
funasr/tokenizer/korean_cleaner.py
funasr/tokenizer/phoneme_tokenizer.py
File was renamed from funasr/text/phoneme_tokenizer.py
@@ -10,7 +10,7 @@
# import g2p_en
import jamo
from funasr.text.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
g2p_choices = [
@@ -107,7 +107,7 @@
        List[str]: List of phoneme + prosody symbols.
    Examples:
        >>> from funasr.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody
        >>> from funasr.tokenizer.phoneme_tokenizer import pyopenjtalk_g2p_prosody
        >>> pyopenjtalk_g2p_prosody("こんにちは。")
        ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']
funasr/tokenizer/sentencepiece_tokenizer.py
File was renamed from funasr/text/sentencepiece_tokenizer.py
@@ -5,7 +5,7 @@
import sentencepiece as spm
from funasr.text.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
class SentencepiecesTokenizer(AbsTokenizer):
funasr/tokenizer/token_id_converter.py
funasr/tokenizer/word_tokenizer.py
File was renamed from funasr/text/word_tokenizer.py
@@ -5,7 +5,7 @@
import warnings
from funasr.text.abs_tokenizer import AbsTokenizer
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
class WordTokenizer(AbsTokenizer):