speech_asr
2023-04-17 9f90bad3f58c86143e630a9d11d8434adaa62904
update
2个文件已修改
2个文件已添加
1个文件已删除
1435 ■■■■■ 已修改文件
funasr/bin/train.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/iterable_dataset_modelscope.py 349 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/small_datasets/build_loader.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/small_datasets/dataset.py 243 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/small_datasets/preprocessor.py 826 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train.py
@@ -25,6 +25,7 @@
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
    # ddp related
    parser.add_argument(
funasr/datasets/iterable_dataset_modelscope.py
File was deleted
funasr/datasets/small_datasets/build_loader.py
New file
@@ -0,0 +1,16 @@
import torch
from funasr.datasets.small_datasets.dataset import ESPnetDataset
from funasr.datasets.small_datasets.build_preprocess import build_preprocess
def build_dataloader(args):
    """Build the small-dataset ``ESPnetDataset`` from parsed training arguments.

    Args:
        args: Namespace-like object. This function reads ``args.frontend_conf``
            (``None`` or a mapping that may contain ``"fs"``) and
            ``args.train_dtype``.

    NOTE(review): ``iter_options`` is not defined anywhere in the visible part
    of this module, and the dataset is built but neither returned nor wrapped
    in a DataLoader -- this looks unfinished; confirm the intended wiring.
    """
    # Use the frontend sample rate when configured, otherwise fall back to
    # 16 kHz. Previously ``dest_sample_rate`` stayed unbound whenever
    # ``args.frontend_conf`` was None, which raised UnboundLocalError below.
    frontend_conf = args.frontend_conf
    if frontend_conf is not None and "fs" in frontend_conf:
        dest_sample_rate = frontend_conf["fs"]
    else:
        dest_sample_rate = 16000
    preprocess_fn = build_preprocess()
    dataset = ESPnetDataset(
        iter_options.data_path_and_name_and_type,
        float_dtype=args.train_dtype,
        preprocess=preprocess_fn,
        max_cache_size=iter_options.max_cache_size,
        max_cache_fd=iter_options.max_cache_fd,
        dest_sample_rate=dest_sample_rate,
    )
funasr/datasets/small_datasets/dataset.py
@@ -1,15 +1,10 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
from abc import ABC
from abc import abstractmethod
import collections
import copy
import functools
import logging
import numbers
import re
from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
@@ -17,7 +12,6 @@
from typing import Tuple
from typing import Union
import h5py
import humanfriendly
import kaldiio
import numpy as np
@@ -27,10 +21,6 @@
from typeguard import check_return_type
from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset
from funasr.fileio.read_text import load_num_sequence_text
from funasr.fileio.read_text import read_2column_text
from funasr.fileio.sound_scp import SoundScpReader
from funasr.utils.sized_dict import SizedDict
@@ -88,25 +78,6 @@
        return array
class H5FileWrapper:
    """Thin mapping-style wrapper around an HDF5 file opened for reading.

    Length, iteration and ``repr`` are delegated to the underlying file
    object; item access returns the stored value read out with ``[()]``.
    """

    def __init__(self, path: str):
        """Remember ``path`` and open it with ``h5py`` in read mode."""
        self.path = path
        self.h5_file = h5py.File(path, "r")

    def __repr__(self) -> str:
        """Return the string form of the underlying file object."""
        handle = self.h5_file
        return str(handle)

    def __len__(self) -> int:
        """Return the number of top-level entries in the file."""
        handle = self.h5_file
        return len(handle)

    def __iter__(self):
        """Iterate over the top-level keys of the file."""
        handle = self.h5_file
        return iter(handle)

    def __getitem__(self, key) -> np.ndarray:
        """Look up ``key`` and return its full contents via ``[()]``."""
        return self.h5_file[key][()]
def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
    # The file is as follows:
    #   utterance_id_A /some/where/a.wav
@@ -127,156 +98,22 @@
    return AdapterForSoundScpReader(loader, float_dtype)
def rand_int_loader(filepath, loader_type):
    """Create an ``IntRandomGenerateDataset`` from a ``rand_int_<low>_<high>`` type.

    Args:
        filepath: Passed through unchanged to ``IntRandomGenerateDataset``.
        loader_type: Type string such as ``"rand_int_3_10"``; whatever follows
            the ``"rand_int_"``-length prefix must be exactly two integers
            joined by ``"_"``.

    Raises:
        RuntimeError: If the suffix cannot be parsed into two integers.
    """
    # Drop the leading "rand_int_" and keep only the "<low>_<high>" part.
    bounds_text = loader_type[len("rand_int_") :]
    try:
        low, high = [int(field) for field in bounds_text.split("_")]
    except ValueError:
        raise RuntimeError(f"e.g rand_int_3_10: but got {loader_type}")
    return IntRandomGenerateDataset(filepath, low, high)
# Registry of supported loader types.
#
# Each entry maps a loader-type pattern to:
#   func   -- the callable used to build the loader for a file path,
#   kwargs -- names of extra keyword arguments that callable accepts,
#   help   -- a human-readable description with an example of the file layout.
#
# NOTE(review): the keys appear to be used as regular expressions (see the
# "rand_int_\\d+_\\d+" entry), presumably matched against the requested
# loader type by the dataset's loader builder -- confirm against the caller.
DATA_TYPES = {
    "sound": dict(
        func=sound_loader,
        kwargs=["dest_sample_rate","float_dtype"],
        help="Audio format types which supported by sndfile wav, flac, etc."
        "\n\n"
        "   utterance_id_a a.wav\n"
        "   utterance_id_b b.wav\n"
        "   ...",
    ),
    "kaldi_ark": dict(
        func=kaldi_loader,
        kwargs=["max_cache_fd"],
        help="Kaldi-ark file type."
        "\n\n"
        "   utterance_id_A /some/where/a.ark:123\n"
        "   utterance_id_B /some/where/a.ark:456\n"
        "   ...",
    ),
    "npy": dict(
        func=NpyScpReader,
        kwargs=[],
        help="Npy file format."
        "\n\n"
        "   utterance_id_A /some/where/a.npy\n"
        "   utterance_id_B /some/where/b.npy\n"
        "   ...",
    ),
    # The four numeric text formats share one reader, specialised by loader_type.
    "text_int": dict(
        func=functools.partial(load_num_sequence_text, loader_type="text_int"),
        kwargs=[],
        help="A text file in which is written a sequence of interger numbers "
        "separated by space."
        "\n\n"
        "   utterance_id_A 12 0 1 3\n"
        "   utterance_id_B 3 3 1\n"
        "   ...",
    ),
    "csv_int": dict(
        func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
        kwargs=[],
        help="A text file in which is written a sequence of interger numbers "
        "separated by comma."
        "\n\n"
        "   utterance_id_A 100,80\n"
        "   utterance_id_B 143,80\n"
        "   ...",
    ),
    "text_float": dict(
        func=functools.partial(load_num_sequence_text, loader_type="text_float"),
        kwargs=[],
        help="A text file in which is written a sequence of float numbers "
        "separated by space."
        "\n\n"
        "   utterance_id_A 12. 3.1 3.4 4.4\n"
        "   utterance_id_B 3. 3.12 1.1\n"
        "   ...",
    ),
    "csv_float": dict(
        func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
        kwargs=[],
        help="A text file in which is written a sequence of float numbers "
        "separated by comma."
        "\n\n"
        "   utterance_id_A 12.,3.1,3.4,4.4\n"
        "   utterance_id_B 3.,3.12,1.1\n"
        "   ...",
    ),
    "text": dict(
        func=read_2column_text,
        kwargs=[],
        help="Return text as is. The text must be converted to ndarray "
        "by 'preprocess'."
        "\n\n"
        "   utterance_id_A hello world\n"
        "   utterance_id_B foo bar\n"
        "   ...",
    ),
    "hdf5": dict(
        func=H5FileWrapper,
        kwargs=[],
        help="A HDF5 file which contains arrays at the first level or the second level."
        "   >>> f = h5py.File('file.h5')\n"
        "   >>> array1 = f['utterance_id_A']\n"
        "   >>> array2 = f['utterance_id_B']\n",
    ),
    "rand_float": dict(
        func=FloatRandomGenerateDataset,
        kwargs=[],
        help="Generate random float-ndarray which has the given shapes "
        "in the file."
        "\n\n"
        "   utterance_id_A 3,4\n"
        "   utterance_id_B 10,4\n"
        "   ...",
    ),
    # The bounds are encoded in the type name itself, so the loader also needs
    # the concrete loader_type string to parse them.
    "rand_int_\\d+_\\d+": dict(
        func=rand_int_loader,
        kwargs=["loader_type"],
        help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
        "shapes in the path. "
        "Give the lower and upper value by the file type. e.g. "
        "rand_int_0_10 -> Generate integers from 0 to 10."
        "\n\n"
        "   utterance_id_A 3,4\n"
        "   utterance_id_B 10,4\n"
        "   ...",
    ),
}
class AbsDataset(Dataset, ABC):
    """Abstract dataset interface; subclasses must implement every method below."""

    @abstractmethod
    def has_name(self, name) -> bool:
        """Return whether ``name`` is available in this dataset (a bool per the annotation)."""
        raise NotImplementedError
    @abstractmethod
    def names(self) -> Tuple[str, ...]:
        """Return the names provided by this dataset as a tuple of strings."""
        raise NotImplementedError
    @abstractmethod
    def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
        """Return a 2-tuple for ``uid`` whose second element maps str keys to ndarrays."""
        raise NotImplementedError
class ESPnetDataset(AbsDataset):
class ESPnetDataset(Dataset):
    """
        Pytorch Dataset class for FunASR, simplied from ESPnet
        Pytorch Dataset class for FunASR, modified from ESPnet
    """
    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        max_cache_size: Union[float, int, str] = 0.0,
        max_cache_fd: int = 0,
        dest_sample_rate: int = 16000,
            self,
            path_name_type_list: Collection[Tuple[str, str, str]],
            preprocess: Callable[
                [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
            ] = None,
            float_dtype: str = "float32",
            int_dtype: str = "long",
            max_cache_size: Union[float, int, str] = 0.0,
            max_cache_fd: int = 0,
            dest_sample_rate: int = 16000,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
@@ -304,8 +141,6 @@
            if len(self.loader_dict[name]) == 0:
                raise RuntimeError(f"{path} has no samples")
            # TODO(kamo): Should check consistency of each utt-keys?
        if isinstance(max_cache_size, str):
            max_cache_size = humanfriendly.parse_size(max_cache_size)
        self.max_cache_size = max_cache_size
@@ -315,43 +150,35 @@
            self.cache = None
    def _build_loader(
        self, path: str, loader_type: str
            self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
        """Helper function to instantiate Loader.
        Args:
            path:  The file path
            loader_type:  loader_type. sound, npy, text_int, text_float, etc
            loader_type:  loader_type. sound, npy, text, etc
        """
        for key, dic in DATA_TYPES.items():
            # e.g. loader_type="sound"
            # -> return DATA_TYPES["sound"]["func"](path)
            if re.match(key, loader_type):
                kwargs = {}
                for key2 in dic["kwargs"]:
                    if key2 == "loader_type":
                        kwargs["loader_type"] = loader_type
                    elif key2 == "dest_sample_rate" and loader_type=="sound":
                        kwargs["dest_sample_rate"] = self.dest_sample_rate
                    elif key2 == "float_dtype":
                        kwargs["float_dtype"] = self.float_dtype
                    elif key2 == "int_dtype":
                        kwargs["int_dtype"] = self.int_dtype
                    elif key2 == "max_cache_fd":
                        kwargs["max_cache_fd"] = self.max_cache_fd
        if loader_type == "sound":
            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False)
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "kaldi_ark":
            loader = kaldiio.load_scp(path, max_cache_fd=self.max_cache_fd)
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "npy":
            return NpyScpReader()
        elif loader_type == "text":
            text_loader = {}
            with open(path, "r", encoding="utf-8") as f:
                for linenum, line in enumerate(f, 1):
                    sps = line.rstrip().split(maxsplit=1)
                    if len(sps) == 1:
                        k, v = sps[0], ""
                    else:
                        raise RuntimeError(f"Not implemented keyword argument: {key2}")
                func = dic["func"]
                try:
                    return func(path, **kwargs)
                except Exception:
                    if hasattr(func, "__name__"):
                        name = func.__name__
                    else:
                        name = str(func)
                    logging.error(f"An error happened with {name}({path})")
                    raise
                        k, v = sps
                    if k in text_loader:
                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
                    text_loader[k] = v
            return text_loader
        else:
            raise RuntimeError(f"Not supported: loader_type={loader_type}")
@@ -392,7 +219,7 @@
                if isinstance(value, (list, tuple)):
                    value = np.array(value)
                if not isinstance(
                    value, (np.ndarray, torch.Tensor, str, numbers.Number)
                        value, (np.ndarray, torch.Tensor, str, numbers.Number)
                ):
                    raise TypeError(
                        f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
funasr/datasets/small_datasets/preprocessor.py
New file
@@ -0,0 +1,826 @@
from abc import ABC
from abc import abstractmethod
from pathlib import Path
from typing import Collection
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union
import numpy as np
import scipy.signal
import soundfile
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.cleaner import TextCleaner
from funasr.text.token_id_converter import TokenIDConverter
class AbsPreprocessor(ABC):
    """Abstract base class for callable per-utterance preprocessors.

    Concrete subclasses implement ``__call__`` to turn the raw ``data``
    dictionary of one utterance into a dictionary of arrays.
    """

    def __init__(self, train: bool):
        # Stored so subclasses can switch behaviour between train and eval.
        self.train = train

    @abstractmethod
    def __call__(
        self,
        uid: str,
        data: Dict[str, Union[str, np.ndarray]],
    ) -> Dict[str, np.ndarray]:
        """Process the ``data`` of utterance ``uid``; must be overridden."""
        raise NotImplementedError
def forward_segment(text, dic):
    """Segment ``text`` with forward maximum matching against ``dic``.

    At each position the longest multi-character prefix of the remaining
    text that is contained in ``dic`` becomes the next segment. When no such
    prefix exists, the single character at that position is emitted, whether
    or not it is in ``dic``.

    Args:
        text: String to segment.
        dic: Container of known words; only ``in`` membership is used.

    Returns:
        List of segments whose concatenation equals ``text``.
    """
    word_list = []
    text_len = len(text)
    i = 0
    while i < text_len:
        # Fallback when nothing longer matches: the single character at ``i``.
        longest_word = text[i]
        # Try candidates from longest to shortest so the first hit is the
        # answer; this avoids slicing and testing every shorter prefix.
        for j in range(text_len, i + 1, -1):
            word = text[i:j]
            if word in dic:
                longest_word = word
                break
        word_list.append(longest_word)
        i += len(longest_word)
    return word_list
def seg_tokenize(txt, seg_dict):
    """Map each item of ``txt`` through ``seg_dict`` and split into tokens.

    Items found in ``seg_dict`` are replaced by their mapped string, all other
    items by ``"<unk>"``. The replacements are concatenated with spaces and
    the result is split on whitespace.

    Args:
        txt: Iterable of lookup keys (e.g. a list of words).
        seg_dict: Mapping from key to a whitespace-separated token string.

    Returns:
        Flat list of token strings.
    """
    pieces = []
    for word in txt:
        mapped = seg_dict[word] if word in seg_dict else "<unk>"
        pieces.append(mapped + " ")
    return "".join(pieces).strip().split()
def seg_tokenize_wo_pattern(txt, seg_dict):
    """Replace every item of ``txt`` via ``seg_dict`` and return the tokens.

    An item present in ``seg_dict`` contributes its mapped string; any other
    item contributes ``"<unk>"``. The space-joined replacements are then split
    on whitespace.

    Args:
        txt: Iterable of lookup keys.
        seg_dict: Mapping from key to a whitespace-separated token string.

    Returns:
        Flat list of token strings.
    """
    replaced = (
        seg_dict[word] if word in seg_dict else "<unk>"
        for word in txt
    )
    return " ".join(replaced).split()
def framing(
        x,
        frame_length: int = 512,
        frame_shift: int = 256,
        centered: bool = True,
        padded: bool = True,
):
    """Slice the last axis of ``x`` into (possibly overlapping) frames.

    Args:
        x: Input array; framing is applied along its last axis.
        frame_length: Number of samples per frame.
        frame_shift: Hop between the starts of consecutive frames.
        centered: If True, zero-pad ``frame_length // 2`` samples on both
            sides of the last axis first.
        padded: If True, zero-pad the end of the last axis so the frames
            cover the signal with an integer number of segments.

    Returns:
        Array of shape ``x.shape[:-1] + (n_frames, frame_length)``. Except in
        the ``frame_length == frame_shift == 1`` case it is a strided view
        created with ``as_strided``, so frames may share memory.

    Raises:
        ValueError: On empty input, non-positive ``frame_length`` or
            ``frame_shift``, or ``frame_length`` longer than the input.
    """
    if x.size == 0:
        raise ValueError("Input array size is zero")
    if frame_length < 1:
        raise ValueError("frame_length must be a positive integer")
    if frame_length > x.shape[-1]:
        raise ValueError("frame_length is greater than input length")
    if frame_shift <= 0:
        raise ValueError("frame_shift must be greater than 0")

    # Leading axes are never padded; only the last axis receives padding.
    untouched_axes = [(0, 0)] * (x.ndim - 1)

    if centered:
        half = frame_length // 2
        x = np.pad(
            x, untouched_axes + [(half, half)], mode="constant", constant_values=0
        )

    if padded:
        # Pad to an integer number of windowed segments, i.e. make
        # x.shape[-1] == frame_length + (nseg - 1) * frame_shift.
        nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
        x = np.pad(
            x, untouched_axes + [(0, nadd)], mode="constant", constant_values=0
        )

    # One-sample frames with unit hop only need a trailing axis.
    if frame_length == 1 and frame_shift == 1:
        return x[..., None]

    n_frames = (x.shape[-1] - frame_length) // frame_shift + 1
    sample_stride = x.strides[-1]
    return np.lib.stride_tricks.as_strided(
        x,
        shape=x.shape[:-1] + (n_frames, frame_length),
        strides=x.strides[:-1] + (frame_shift * sample_stride, sample_stride),
    )
def detect_non_silence(
        x: np.ndarray,
        threshold: float = 0.01,
        frame_length: int = 1024,
        frame_shift: int = 512,
        window: str = "boxcar",
) -> np.ndarray:
    """Power based voice activity detection.

    A frame counts as non-silent when its mean power, relative to the mean
    power over all frames of the same channel, exceeds ``threshold``.

    Args:
        x: (..., Time) signal; detection runs along the last axis.
        threshold: Relative power above which a frame is non-silent.
        frame_length: Frame size in samples.
        frame_shift: Hop size in samples.
        window: Window name passed to ``scipy.signal.get_window``.

    Returns:
        Boolean array with the same shape as ``x``. It is all True when the
        input is shorter than one frame or has zero mean power everywhere.

    >>> x = np.random.randn(1000)
    >>> detect = detect_non_silence(x)
    >>> assert x.shape == detect.shape
    >>> assert detect.dtype == bool
    """
    # NOTE: ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` yields
    # the same dtype on every NumPy version.
    if x.shape[-1] < frame_length:
        return np.full(x.shape, fill_value=True, dtype=bool)

    if x.dtype.kind == "i":
        x = x.astype(np.float64)
    # framed_w: (C, T, F)
    framed_w = framing(
        x,
        frame_length=frame_length,
        frame_shift=frame_shift,
        centered=False,
        padded=True,
    )
    # Multiply out of place: ``framing`` returns a strided view whose frames
    # overlap in memory, so an in-place ``*=`` would apply a non-rectangular
    # window to the shared samples more than once.
    framed_w = framed_w * scipy.signal.get_window(window, frame_length).astype(
        framed_w.dtype
    )
    # power: (C, T)
    power = (framed_w ** 2).mean(axis=-1)
    # mean_power: (C, 1)
    mean_power = np.mean(power, axis=-1, keepdims=True)
    if np.all(mean_power == 0):
        return np.full(x.shape, fill_value=True, dtype=bool)
    # detect_frames: (C, T)
    detect_frames = power / mean_power > threshold
    # detects: (C, T, F) -- repeat each frame decision over one hop of samples
    detects = np.broadcast_to(
        detect_frames[..., None], detect_frames.shape + (frame_shift,)
    )
    # detects: (C, TF)
    detects = detects.reshape(*detect_frames.shape[:-1], -1)
    # detects: (C, TF) -- extend the last decision to cover the original length
    return np.pad(
        detects,
        [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
        mode="edge",
    )
class CommonPreprocessor(AbsPreprocessor):
    """Per-utterance preprocessor: speech augmentation plus text tokenization.

    In training mode the speech entry may be convolved with a randomly chosen
    RIR and mixed with a randomly chosen noise file. The text entry is
    cleaned, tokenized and converted to an int64 id array whenever a tokenizer
    is configured (``token_type`` is not None).
    """

    def __init__(
            self,
            train: bool,
            token_type: str = None,
            token_list: Union[Path, str, Iterable[str]] = None,
            bpemodel: Union[Path, str, Iterable[str]] = None,
            text_cleaner: Collection[str] = None,
            g2p_type: str = None,
            unk_symbol: str = "<unk>",
            space_symbol: str = "<space>",
            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
            delimiter: str = None,
            rir_scp: str = None,
            rir_apply_prob: float = 1.0,
            noise_scp: str = None,
            noise_apply_prob: float = 1.0,
            noise_db_range: str = "3_10",
            speech_volume_normalize: float = None,
            speech_name: str = "speech",
            text_name: str = "text",
            split_with_space: bool = False,
            seg_dict_file: str = None,
    ):
        """Configure tokenization and (training-only) augmentation.

        Args:
            train: Whether augmentation may be applied.
            token_type: Tokenizer type; None disables text processing.
            token_list: Token list for id conversion; required when
                ``token_type`` is given.
            bpemodel, delimiter, space_symbol, non_linguistic_symbols,
            g2p_type: Forwarded to ``build_tokenizer``.
            text_cleaner: Forwarded to ``TextCleaner``.
            unk_symbol: Forwarded to ``TokenIDConverter``.
            rir_scp: scp file listing RIR audio paths (used only if ``train``).
            rir_apply_prob: Probability of applying an RIR.
            noise_scp: scp file listing noise audio paths (used only if
                ``train``).
            noise_apply_prob: Probability of adding noise.
            noise_db_range: ``"low_high"`` or a single value, in dB; the
                sampled value ``d`` scales the noise by ``10 ** (-d / 20)``
                relative to the speech power.
            speech_volume_normalize: If set, speech is rescaled so that its
                peak absolute value equals this number.
            speech_name: Key of the speech entry in ``data``.
            text_name: Key of the text entry in ``data``.
            split_with_space: Split text on single spaces instead of using
                the tokenizer.
            seg_dict_file: Optional file with ``key tok1 tok2 ...`` lines used
                to segment space-split text.

        Raises:
            ValueError: If ``token_type`` is set without ``token_list``, or
                ``noise_db_range`` has more than two fields.
        """
        super().__init__(train)
        self.train = train
        self.speech_name = speech_name
        self.text_name = text_name
        self.speech_volume_normalize = speech_volume_normalize
        self.rir_apply_prob = rir_apply_prob
        self.noise_apply_prob = noise_apply_prob
        self.split_with_space = split_with_space
        self.seg_dict = None
        if seg_dict_file is not None:
            self.seg_dict = {}
            # Read as UTF-8 explicitly (like the scp files below) so the
            # result does not depend on the platform's locale encoding.
            with open(seg_dict_file, "r", encoding="utf-8") as f:
                for line in f:
                    s = line.strip().split()
                    if not s:
                        # Blank lines used to crash with IndexError on s[0].
                        continue
                    self.seg_dict[s[0]] = " ".join(s[1:])

        if token_type is not None:
            if token_list is None:
                raise ValueError("token_list is required if token_type is not None")
            self.text_cleaner = TextCleaner(text_cleaner)

            self.tokenizer = build_tokenizer(
                token_type=token_type,
                bpemodel=bpemodel,
                delimiter=delimiter,
                space_symbol=space_symbol,
                non_linguistic_symbols=non_linguistic_symbols,
                g2p_type=g2p_type,
            )
            self.token_id_converter = TokenIDConverter(
                token_list=token_list,
                unk_symbol=unk_symbol,
            )
        else:
            self.text_cleaner = None
            self.tokenizer = None
            self.token_id_converter = None

        if train and rir_scp is not None:
            self.rirs = self._read_scp_paths(rir_scp)
        else:
            self.rirs = None

        if train and noise_scp is not None:
            self.noises = self._read_scp_paths(noise_scp)
            sps = noise_db_range.split("_")
            if len(sps) == 1:
                # A single value means a fixed level. The old code tried to
                # unpack one float into two names and raised TypeError.
                self.noise_db_low = self.noise_db_high = float(sps[0])
            elif len(sps) == 2:
                self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
            else:
                # The f-prefix was missing, so the placeholder was printed
                # literally instead of the offending value.
                raise ValueError(
                    f"Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
                )
        else:
            self.noises = None

    @staticmethod
    def _read_scp_paths(scp_path: str) -> List[str]:
        """Return the path column of an scp file with ``path`` or ``id path`` lines."""
        paths = []
        with open(scp_path, "r", encoding="utf-8") as f:
            for line in f:
                sps = line.strip().split(None, 1)
                if not sps:
                    # Skip blank lines instead of failing with IndexError.
                    continue
                paths.append(sps[0] if len(sps) == 1 else sps[1])
        return paths

    def _speech_process(
            self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, Union[str, np.ndarray]]:
        """Apply RIR/noise augmentation and volume normalization to the speech.

        ``data`` is updated in place and also returned. Nothing happens when
        ``self.speech_name`` is not a key of ``data``.
        """
        assert check_argument_types()
        if self.speech_name in data:
            if self.train and (self.rirs is not None or self.noises is not None):
                speech = data[self.speech_name]
                nsamples = len(speech)

                # speech: (Nmic, Time)
                if speech.ndim == 1:
                    speech = speech[None, :]
                else:
                    speech = speech.T
                # Calc power on the non-silence region
                power = (speech[detect_non_silence(speech)] ** 2).mean()

                # 1. Convolve RIR
                if self.rirs is not None and self.rir_apply_prob >= np.random.random():
                    rir_path = np.random.choice(self.rirs)
                    if rir_path is not None:
                        rir, _ = soundfile.read(
                            rir_path, dtype=np.float64, always_2d=True
                        )

                        # rir: (Nmic, Time)
                        rir = rir.T

                        # speech: (Nmic, Time)
                        # Truncating the "full" convolution keeps the signal
                        # length unchanged.
                        speech = scipy.signal.convolve(speech, rir, mode="full")[
                            :, : speech.shape[1]
                        ]
                        # Restore the mean power of the original signal
                        power2 = (speech[detect_non_silence(speech)] ** 2).mean()
                        speech = np.sqrt(power / max(power2, 1e-10)) * speech

                # 2. Add Noise
                if (
                        self.noises is not None
                        and self.noise_apply_prob >= np.random.random()
                ):
                    noise_path = np.random.choice(self.noises)
                    if noise_path is not None:
                        noise_db = np.random.uniform(
                            self.noise_db_low, self.noise_db_high
                        )
                        with soundfile.SoundFile(noise_path) as f:
                            if f.frames == nsamples:
                                noise = f.read(dtype=np.float64, always_2d=True)
                            elif f.frames < nsamples:
                                # Noise is shorter: place it at a random offset
                                # and fill the rest by wrapping it around.
                                offset = np.random.randint(0, nsamples - f.frames)
                                # noise: (Time, Nmic)
                                noise = f.read(dtype=np.float64, always_2d=True)
                                # Repeat noise
                                noise = np.pad(
                                    noise,
                                    [(offset, nsamples - f.frames - offset), (0, 0)],
                                    mode="wrap",
                                )
                            else:
                                # Noise is longer: read a random excerpt.
                                offset = np.random.randint(0, f.frames - nsamples)
                                f.seek(offset)
                                # noise: (Time, Nmic)
                                noise = f.read(
                                    nsamples, dtype=np.float64, always_2d=True
                                )
                                if len(noise) != nsamples:
                                    raise RuntimeError(f"Something wrong: {noise_path}")
                        # noise: (Nmic, Time)
                        noise = noise.T

                        noise_power = (noise ** 2).mean()
                        scale = (
                                10 ** (-noise_db / 20)
                                * np.sqrt(power)
                                / np.sqrt(max(noise_power, 1e-10))
                        )
                        speech = speech + scale * noise

                speech = speech.T
                # Avoid clipping: rescale only if the peak exceeds 1.0
                ma = np.max(np.abs(speech))
                if ma > 1.0:
                    speech /= ma
                data[self.speech_name] = speech

            if self.speech_volume_normalize is not None:
                speech = data[self.speech_name]
                ma = np.max(np.abs(speech))
                # All-zero audio has no peak to normalise; dividing by zero
                # would turn every sample into NaN, so leave it untouched.
                if ma > 0:
                    data[self.speech_name] = speech * self.speech_volume_normalize / ma
        assert check_return_type(data)
        return data

    def _text_process(
            self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Replace the text entry of ``data`` with an int64 array of token ids.

        Nothing is changed when the text key is missing or no tokenizer is
        configured.
        """
        if self.text_name in data and self.tokenizer is not None:
            text = data[self.text_name]
            text = self.text_cleaner(text)
            if self.split_with_space:
                tokens = text.strip().split(" ")
                if self.seg_dict is not None:
                    # Re-segment the space-free text with the dictionary, then
                    # map each segment to its token sequence.
                    tokens = forward_segment("".join(tokens), self.seg_dict)
                    tokens = seg_tokenize(tokens, self.seg_dict)
            else:
                tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data

    def __call__(
            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Run speech processing followed by text processing on ``data``."""
        assert check_argument_types()

        data = self._speech_process(data)
        data = self._text_process(data)
        return data
## FIXME
class LMPreprocessor(CommonPreprocessor):
    """``CommonPreprocessor`` variant with a different text segmentation step.

    Construction is identical to the parent class. Only ``_text_process``
    differs: space-split words are mapped through ``seg_dict`` directly with
    ``seg_tokenize_wo_pattern``.
    """

    def __init__(
            self,
            train: bool,
            token_type: str = None,
            token_list: Union[Path, str, Iterable[str]] = None,
            bpemodel: Union[Path, str, Iterable[str]] = None,
            text_cleaner: Collection[str] = None,
            g2p_type: str = None,
            unk_symbol: str = "<unk>",
            space_symbol: str = "<space>",
            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
            delimiter: str = None,
            rir_scp: str = None,
            rir_apply_prob: float = 1.0,
            noise_scp: str = None,
            noise_apply_prob: float = 1.0,
            noise_db_range: str = "3_10",
            speech_volume_normalize: float = None,
            speech_name: str = "speech",
            text_name: str = "text",
            split_with_space: bool = False,
            seg_dict_file: str = None,
    ):
        """Forward every argument unchanged to ``CommonPreprocessor``."""
        super().__init__(
            train=train,
            token_type=token_type,
            token_list=token_list,
            bpemodel=bpemodel,
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
            speech_name=speech_name,
            text_name=text_name,
            split_with_space=split_with_space,
            seg_dict_file=seg_dict_file,
        )

    def _text_process(
            self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Replace the text entry of ``data`` with an int64 array of token ids.

        The entry is left untouched when the text key is absent or no
        tokenizer is configured.
        """
        if self.text_name in data and self.tokenizer is not None:
            cleaned = self.text_cleaner(data[self.text_name])
            if not self.split_with_space:
                tokens = self.tokenizer.text2tokens(cleaned)
            else:
                tokens = cleaned.strip().split(" ")
                if self.seg_dict is not None:
                    tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
            token_ids = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(token_ids, dtype=np.int64)
        assert check_return_type(data)
        return data
class CommonPreprocessor_multi(AbsPreprocessor):
    """Preprocessor that tokenizes several text entries with one tokenizer.

    Every key listed in ``text_name`` that is present in the data is cleaned,
    tokenized and converted to an int64 id array. The speech entry is passed
    through unchanged.
    """

    def __init__(
            self,
            train: bool,
            token_type: str = None,
            token_list: Union[Path, str, Iterable[str]] = None,
            bpemodel: Union[Path, str, Iterable[str]] = None,
            text_cleaner: Collection[str] = None,
            g2p_type: str = None,
            unk_symbol: str = "<unk>",
            space_symbol: str = "<space>",
            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
            delimiter: str = None,
            speech_name: str = "speech",
            text_name: List[str] = None,
    ):
        """Configure the shared tokenizer.

        Args:
            train: Stored on the instance; not otherwise used here.
            token_type: Tokenizer type; None disables text processing.
            token_list: Token list for id conversion; required when
                ``token_type`` is given.
            bpemodel, delimiter, space_symbol, non_linguistic_symbols,
            g2p_type: Forwarded to ``build_tokenizer``.
            text_cleaner: Forwarded to ``TextCleaner``.
            unk_symbol: Forwarded to ``TokenIDConverter``.
            speech_name: Key of the speech entry in ``data``.
            text_name: Keys of the text entries to tokenize; defaults to
                ``["text"]``.

        Raises:
            ValueError: If ``token_type`` is set without ``token_list``.
        """
        super().__init__(train)
        self.train = train
        self.speech_name = speech_name
        # Build the default per instance: a list literal in the signature
        # would be one object shared (and mutable) across all instances.
        self.text_name = ["text"] if text_name is None else text_name

        if token_type is not None:
            if token_list is None:
                raise ValueError("token_list is required if token_type is not None")
            self.text_cleaner = TextCleaner(text_cleaner)

            self.tokenizer = build_tokenizer(
                token_type=token_type,
                bpemodel=bpemodel,
                delimiter=delimiter,
                space_symbol=space_symbol,
                non_linguistic_symbols=non_linguistic_symbols,
                g2p_type=g2p_type,
            )
            self.token_id_converter = TokenIDConverter(
                token_list=token_list,
                unk_symbol=unk_symbol,
            )
        else:
            self.text_cleaner = None
            self.tokenizer = None
            self.token_id_converter = None

    def _text_process(
            self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Replace each configured text entry with an int64 array of token ids."""
        # Without a tokenizer no entry can be converted, so skip the loop.
        if self.tokenizer is not None:
            for text_n in self.text_name:
                if text_n in data:
                    text = self.text_cleaner(data[text_n])
                    tokens = self.tokenizer.text2tokens(text)
                    text_ints = self.token_id_converter.tokens2ids(tokens)
                    data[text_n] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data

    def __call__(
            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Tokenize the text entries of ``data``; speech is left as is."""
        assert check_argument_types()

        # Speech is currently passed through untouched. Candidates for future
        # processing: STFT, Fbank, CMVN, data augmentation.
        data = self._text_process(data)
        return data
class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
    """Preprocessor that keeps one tokenizer / id-converter pair per text field.

    ``token_type``, ``token_list``, ``bpemodel`` and ``text_name`` are parallel
    lists: position ``i`` of each one describes how the data entry named
    ``text_name[i]`` is tokenized. A ``None`` token type disables tokenization
    for that position.
    """

    def __init__(
            self,
            train: bool,
            token_type: List[str] = [None],
            token_list: List[Union[Path, str, Iterable[str]]] = [None],
            bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
            text_cleaner: Collection[str] = None,
            g2p_type: str = None,
            unk_symbol: str = "<unk>",
            space_symbol: str = "<space>",
            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
            delimiter: str = None,
            rir_scp: str = None,
            rir_apply_prob: float = 1.0,
            noise_scp: str = None,
            noise_apply_prob: float = 1.0,
            noise_db_range: str = "3_10",
            speech_volume_normalize: float = None,
            speech_name: str = "speech",
            text_name: List[str] = ["text"],
    ):
        """Initialise the parent from the first list entries, then build one
        tokenizer and one id converter per position.

        Raises:
            AssertionError: if the four parallel lists differ in length.
            ValueError: if a position has a token type but no token list.
        """
        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
        # The parent only understands a single tokenizer, so it receives the
        # first entry of each list; its tokenizer attributes are replaced below.
        super().__init__(
            train=train,
            token_type=token_type[0],
            token_list=token_list[0],
            bpemodel=bpemodel[0],
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name[0],
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
        )

        assert (
                len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
        self.num_tokenizer = len(token_type)

        tokenizers = []
        converters = []
        for idx, kind in enumerate(token_type):
            if kind is None:
                # Tokenization is disabled for this position.
                tokenizers.append(None)
                converters.append(None)
                continue
            if token_list[idx] is None:
                raise ValueError("token_list is required if token_type is not None")
            tokenizers.append(
                build_tokenizer(
                    token_type=kind,
                    bpemodel=bpemodel[idx],
                    delimiter=delimiter,
                    space_symbol=space_symbol,
                    non_linguistic_symbols=non_linguistic_symbols,
                    g2p_type=g2p_type,
                )
            )
            converters.append(
                TokenIDConverter(
                    token_list=token_list[idx],
                    unk_symbol=unk_symbol,
                )
            )
        self.tokenizer = tokenizers
        self.token_id_converter = converters

        self.text_cleaner = TextCleaner(text_cleaner)
        self.text_name = text_name  # override the text_name from CommonPreprocessor

    def _text_process(
            self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Tokenize every configured text field with its own tokenizer.

        For each position whose field is present in ``data`` and whose
        tokenizer is not ``None``, the entry is cleaned, tokenized, converted
        to ids and overwritten in place with an ``np.int64`` array.

        Returns:
            The same ``data`` object.
        """
        for idx in range(self.num_tokenizer):
            field = self.text_name[idx]
            if field not in data or self.tokenizer[idx] is None:
                continue
            cleaned = self.text_cleaner(data[field])
            pieces = self.tokenizer[idx].text2tokens(cleaned)
            token_ids = self.token_id_converter[idx].tokens2ids(pieces)
            data[field] = np.array(token_ids, dtype=np.int64)
        assert check_return_type(data)
        return data
class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
    """Preprocessor for code-mixed text that also keeps the word-split text.

    The text field is split into words before the parent's speech and text
    processing run, and the split word list is stored back into the sample
    under ``split_text_name``.
    """

    def __init__(
            self,
            train: bool,
            token_type: str = None,
            token_list: Union[Path, str, Iterable[str]] = None,
            bpemodel: Union[Path, str, Iterable[str]] = None,
            text_cleaner: Collection[str] = None,
            g2p_type: str = None,
            unk_symbol: str = "<unk>",
            space_symbol: str = "<space>",
            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
            delimiter: str = None,
            rir_scp: str = None,
            rir_apply_prob: float = 1.0,
            noise_scp: str = None,
            noise_apply_prob: float = 1.0,
            noise_db_range: str = "3_10",
            speech_volume_normalize: float = None,
            speech_name: str = "speech",
            text_name: str = "text",
            split_text_name: str = "split_text",
            split_with_space: bool = False,
            seg_dict_file: str = None,
    ):
        """Forward the configuration to the parent and remember the output key.

        ``token_type`` is accepted for interface compatibility but ignored:
        the parent is always configured with ``token_type="word"``.
        """
        super().__init__(
            train=train,
            # Force to use word.
            token_type="word",
            token_list=token_list,
            bpemodel=bpemodel,
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name,
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
            split_with_space=split_with_space,
            seg_dict_file=seg_dict_file,
        )
        # The data field name for split text.
        self.split_text_name = split_text_name

    @classmethod
    def split_words(cls, text: str):
        """Split ``text`` into words.

        The text is first split on whitespace. Inside each piece, consecutive
        characters whose default encoding is a single byte are grouped into
        one word, while every other character becomes a word of its own.

        Returns:
            The list of words in their original order.
        """
        words = []
        for piece in text.split():
            # ``piece`` contains no whitespace; collect the current ASCII run.
            ascii_run = []
            for ch in piece:
                if len(ch.encode()) == 1:
                    # Single-byte (ASCII) character: extend the current run.
                    ascii_run.append(ch)
                    continue
                # Multi-byte character (e.g. Chinese): close the pending run
                # and emit the character as a standalone word.
                if ascii_run:
                    words.append("".join(ascii_run))
                    ascii_run = []
                words.append(ch)
            if ascii_run:
                words.append("".join(ascii_run))
        return words

    def __call__(
            self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
    ) -> Dict[str, Union[list, np.ndarray]]:
        """Split the text into words, run speech/text processing, attach the split.

        A string text entry is split with ``split_words``; any other value is
        used as the word sequence as-is. The text entry is replaced by the
        space-joined words before ``_speech_process`` and ``_text_process``
        run, and the word sequence is stored under ``self.split_text_name``.
        """
        assert check_argument_types()
        raw_text = data[self.text_name]
        split_text = self.split_words(raw_text) if isinstance(raw_text, str) else raw_text
        data[self.text_name] = " ".join(split_text)
        data = self._text_process(self._speech_process(data))
        data[self.split_text_name] = split_text
        return data

    def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
        """Remove the split-text entry from ``data`` and return it."""
        return data.pop(self.split_text_name)
class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
    """Multi-tokenizer preprocessor that also extracts a trailing ``vad:`` token.

    ``token_type``, ``token_list``, ``bpemodel`` and ``text_name`` are parallel
    lists: position ``i`` of each one describes how the data entry named
    ``text_name[i]`` is tokenized. While tokenizing, a last token containing
    ``"vad:"`` is removed from the token sequence and its numeric payload is
    stored in the sample under ``vad_name``.
    """

    def __init__(
            self,
            train: bool,
            token_type: List[str] = [None],
            token_list: List[Union[Path, str, Iterable[str]]] = [None],
            bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
            text_cleaner: Collection[str] = None,
            g2p_type: str = None,
            unk_symbol: str = "<unk>",
            space_symbol: str = "<space>",
            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
            delimiter: str = None,
            rir_scp: str = None,
            rir_apply_prob: float = 1.0,
            noise_scp: str = None,
            noise_apply_prob: float = 1.0,
            noise_db_range: str = "3_10",
            speech_volume_normalize: float = None,
            speech_name: str = "speech",
            text_name: List[str] = ["text"],
            vad_name: str = "vad_indexes",
    ):
        """Initialise the parent from the first list entries, then build one
        tokenizer and one id converter per position.

        Raises:
            AssertionError: if the four parallel lists differ in length.
            ValueError: if a position has a token type but no token list.
        """
        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
        # The parent only understands a single tokenizer, so it receives the
        # first entry of each list; its tokenizer attributes are replaced below.
        super().__init__(
            train=train,
            token_type=token_type[0],
            token_list=token_list[0],
            bpemodel=bpemodel[0],
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name[0],
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
        )

        assert (
                len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
        self.num_tokenizer = len(token_type)
        self.tokenizer = []
        self.token_id_converter = []

        for i in range(self.num_tokenizer):
            if token_type[i] is not None:
                if token_list[i] is None:
                    raise ValueError("token_list is required if token_type is not None")

                self.tokenizer.append(
                    build_tokenizer(
                        token_type=token_type[i],
                        bpemodel=bpemodel[i],
                        delimiter=delimiter,
                        space_symbol=space_symbol,
                        non_linguistic_symbols=non_linguistic_symbols,
                        g2p_type=g2p_type,
                    )
                )
                self.token_id_converter.append(
                    TokenIDConverter(
                        token_list=token_list[i],
                        unk_symbol=unk_symbol,
                    )
                )
            else:
                # Tokenization is disabled for this position.
                self.tokenizer.append(None)
                self.token_id_converter.append(None)

        self.text_cleaner = TextCleaner(text_cleaner)
        self.text_name = text_name  # override the text_name from CommonPreprocessor
        self.vad_name = vad_name

    def _text_process(
            self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Tokenize every configured text field and extract the ``vad:`` marker.

        For each position whose field is present in ``data`` and whose
        tokenizer is not ``None``, the entry is cleaned, tokenized, converted
        to ids and overwritten in place with an ``np.int64`` array. If the
        last token contains ``"vad:"``, that token is dropped and the text
        after its first four characters is parsed as an integer (``-1`` when
        it is empty) and stored as a one-element ``np.int64`` array under
        ``self.vad_name``.

        Args:
            data: Sample dictionary; it is modified in place.

        Returns:
            The same ``data`` object.

        Raises:
            ValueError: if the ``vad:`` payload is not a valid integer.
        """
        for i in range(self.num_tokenizer):
            text_name = self.text_name[i]
            if text_name in data and self.tokenizer[i] is not None:
                text = data[text_name]
                text = self.text_cleaner(text)
                tokens = self.tokenizer[i].text2tokens(text)
                # Guard against an empty token sequence: ``tokens[-1]`` would
                # raise IndexError for text that yields no tokens.
                if tokens and "vad:" in tokens[-1]:
                    # NOTE(review): the slice assumes "vad:" is a prefix of the
                    # token -- confirm against the data format.
                    vad = tokens[-1][4:]
                    tokens = tokens[:-1]
                    vad = int(vad) if len(vad) > 0 else -1
                    data[self.vad_name] = np.array([vad], dtype=np.int64)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        # Callers reassign their sample from this return value, so the
        # processed dictionary must be returned rather than falling off the end.
        return data
def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Cut ``words`` into consecutive chunks of at most ``word_limit`` items.

    Args:
        words: Sequence to split.
        word_limit: Maximum chunk length; must be greater than 1.

    Returns:
        ``[words]`` (the input object itself, unsliced) when it already fits
        within ``word_limit``; otherwise a list of slices in order, where only
        the last one may be shorter than ``word_limit``.
    """
    assert word_limit > 1
    total = len(words)
    if total <= word_limit:
        return [words]
    # Stepping by ``word_limit`` yields the full chunks plus any shorter tail.
    return [words[start:start + word_limit] for start in range(0, total, word_limit)]
def build_preprocess(args):
    """Select the preprocessing function for the task named in ``args``.

    Args:
        args: Object exposing a ``task_name`` attribute.

    Returns:
        ``None`` for the ``"asr"`` task; no preprocessor is constructed yet.

    Raises:
        ValueError: if ``args.task_name`` is anything other than ``"asr"``.
    """
    if args.task_name != "asr":
        raise ValueError(f"Not supported task={args.task_name}")
    # Nothing is built for "asr" at this point.
    return None