speech_asr
2023-03-21 8314c5f17eff473e08dca1729c8a1e62290c7866
update
3 files changed, 53 lines changed

funasr/datasets/large_datasets/build_dataloader.py | 44
funasr/datasets/large_datasets/dataset.py          |  3
funasr/datasets/large_datasets/utils/tokenize.py   |  6
funasr/datasets/large_datasets/build_dataloader.py
@@ -1,10 +1,16 @@
 import logging
+from pathlib import Path
+from typing import Iterable
+from typing import List
+from typing import Union
 import yaml
+import sentencepiece as spm
 from torch.utils.data import DataLoader
 from typeguard import check_argument_types
 
 from funasr.datasets.large_datasets.dataset import Dataset
 from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.text.abs_tokenizer import AbsTokenizer
 
 
 def read_symbol_table(symbol_table_file):
@@ -21,6 +27,7 @@
             symbol_table[char] = i
     return symbol_table
 
+
 def load_seg_dict(seg_dict_file):
     seg_dict = {}
     assert isinstance(seg_dict_file, str)
@@ -33,8 +40,33 @@
             seg_dict[key] = " ".join(value)
     return seg_dict
 
+
+class SentencepiecesTokenizer(AbsTokenizer):
+    def __init__(self, model: Union[Path, str]):
+        assert check_argument_types()
+        self.model = str(model)
+        self.sp = None
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(model="{self.model}")'
+
+    def _build_sentence_piece_processor(self):
+        if self.sp is None:
+            self.sp = spm.SentencePieceProcessor()
+            self.sp.load(self.model)
+
+    def text2tokens(self, line: str) -> List[str]:
+        self._build_sentence_piece_processor()
+        return self.sp.EncodeAsPieces(line)
+
+    def tokens2text(self, tokens: Iterable[str]) -> str:
+        self._build_sentence_piece_processor()
+        return self.sp.DecodePieces(list(tokens))
+
+
 class ArkDataLoader(AbsIterFactory):
-    def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None, mode="train"):
+    def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
+                 bpemodel_file=None, mode="train"):
         symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
         if seg_dict_file is not None:
             seg_dict = load_seg_dict(seg_dict_file)
@@ -48,7 +80,11 @@
         self.frontend_conf = frontend_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
         batch_mode = self.dataset_conf.get("batch_mode", "padding")
-        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict,
+        if bpemodel_file is not None:
+            bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
+        else:
+            bpe_tokenizer = None
+        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
                                self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)
 
     def build_iter(self, epoch, shuffle=True):
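
Note on the lazy construction in the new SentencepiecesTokenizer: self.sp stays None until the first text2tokens or tokens2text call, so a freshly constructed tokenizer holds no loaded model. A minimal usage sketch, where "bpe.model" is a hypothetical path to a trained SentencePiece model, not part of this commit:

    from funasr.datasets.large_datasets.build_dataloader import SentencepiecesTokenizer

    tokenizer = SentencepiecesTokenizer("bpe.model")  # no model loaded yet
    pieces = tokenizer.text2tokens("hello world")     # first call loads the model
    text = tokenizer.tokens2text(pieces)              # round-trip back to a string

One plausible reason for deferring the load is that the tokenizer object stays cheap to pickle when the Dataset is shipped to DataLoader worker processes; the SentencePiece processor is then built once per process on first use.
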
funasr/datasets/large_datasets/dataset.py
@@ -158,6 +158,7 @@
             dict,
             seg_dict,
             punc_dict,
+            bpe_tokenizer,
             conf,
             frontend_conf,
             mode="train",
@@ -173,7 +174,7 @@
     dataset = FilterIterDataPipe(dataset, fn=filter_fn)
 
     if "text" in data_names:
-        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict}
+        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer}
         tokenize_fn = partial(tokenize, **vocab)
         dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
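
For context on the partial() call below the changed line: every key in the vocab dict is splatted into tokenize() as a keyword argument, which is why adding the 'bpe_tokenizer' key here requires the matching bpe_tokenizer=None parameter added to tokenize.py in the same commit. A minimal sketch of that plumbing, using a placeholder tokenize body; only the signature matters:

    from functools import partial

    # Placeholder standing in for funasr's tokenize(); real body omitted.
    def tokenize(data, vocab=None, seg_dict=None, punc_dict=None, bpe_tokenizer=None):
        return data

    kwargs = {"vocab": {}, "seg_dict": None, "punc_dict": None, "bpe_tokenizer": None}
    tokenize_fn = partial(tokenize, **kwargs)  # binds everything except data
    sample = tokenize_fn({"text": "hello"})    # MapperIterDataPipe applies this per sample
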
funasr/datasets/large_datasets/utils/tokenize.py
@@ -28,13 +28,17 @@
 def tokenize(data,
              vocab=None,
              seg_dict=None,
-             punc_dict=None):
+             punc_dict=None,
+             bpe_tokenizer=None):
     assert "text" in data
     assert isinstance(vocab, dict)
     text = data["text"]
     token = []
     vad = -2
 
+    if bpe_tokenizer is not None:
+        text = bpe_tokenizer.text2tokens(text)
+
     if seg_dict is not None:
         assert isinstance(seg_dict, dict)
         txt = forward_segment("".join(text).lower(), seg_dict)
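
Note the type change the new branch introduces: with a tokenizer, text goes from a string to a list of subword pieces before the seg_dict path runs; without one it stays a string. The downstream "".join(text) call is valid either way, since joining a string iterates its characters. A small self-contained sketch, where the underscore-prefixed pieces are illustrative of SentencePiece output rather than taken from a real model:

    # text as read from the sample vs. after a hypothetical BPE pass
    text_plain = "hello world"
    text_bpe = ["▁hello", "▁world"]

    # both forms collapse back to a single string for forward_segment()
    assert "".join(text_plain) == "hello world"
    assert "".join(text_bpe) == "▁hello▁world"
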