| | |
| | | # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | from abc import ABC |
| | | from abc import abstractmethod |
| | | import collections |
| | | import copy |
| | | import functools |
| | | import logging |
| | | import numbers |
| | | import re |
| | | from typing import Any |
| | | from typing import Callable |
| | | from typing import Collection |
| | | from typing import Dict |
| | | from typing import Mapping |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Union, List, Tuple |
| | | |
| | | import h5py |
| | | import humanfriendly |
| | | import kaldiio |
| | | import numpy as np |
| | | import torch |
| | | from torch.utils.data.dataset import Dataset |
| | | from typeguard import check_argument_types |
| | | from typeguard import check_return_type |
| | | |
| | | from funasr.fileio.npy_scp import NpyScpReader |
| | | from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset |
| | | from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset |
| | | from funasr.fileio.read_text import load_num_sequence_text |
| | | from funasr.fileio.read_text import read_2column_text |
| | | from funasr.fileio.sound_scp import SoundScpReader |
| | | from funasr.utils.sized_dict import SizedDict |
| | | |
| | | |
| | | class AdapterForSoundScpReader(collections.abc.Mapping): |
| | | def __init__(self, loader, dtype=None): |
| | | assert check_argument_types() |
| | | self.loader = loader |
| | | self.dtype = dtype |
| | | self.rate = None |
| | |
| | | return array |
| | | |
| | | |
class H5FileWrapper:
    """Read-only, dict-like view over an HDF5 file.

    Keys are whatever names the opened h5py file object exposes; indexing
    with a key returns ``value[()]`` for the stored object, i.e. the whole
    value read out of the file.

    The underlying file handle stays open for the lifetime of this object.
    Call :meth:`close` (or use the wrapper in a ``with`` statement) to
    release it explicitly.
    """

    def __init__(self, path: str):
        self.path = path
        # Opened read-only and kept open so individual key lookups do not
        # have to reopen the file.
        self.h5_file = h5py.File(path, "r")

    def __repr__(self) -> str:
        return str(self.h5_file)

    def __len__(self) -> int:
        return len(self.h5_file)

    def __iter__(self):
        return iter(self.h5_file)

    def __contains__(self, key) -> bool:
        # Delegate to the file object instead of falling back to a linear
        # scan over ``__iter__``; this also keeps membership consistent with
        # ``__getitem__`` for keys the file object resolves itself.
        return key in self.h5_file

    def __getitem__(self, key) -> np.ndarray:
        value = self.h5_file[key]
        # ``[()]`` reads the entire stored value rather than returning a
        # lazy handle into the file.
        return value[()]

    def close(self) -> None:
        """Close the underlying HDF5 file handle."""
        self.h5_file.close()

    def __enter__(self) -> "H5FileWrapper":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always release the handle; exceptions are not suppressed.
        self.close()
| | | |
| | | |
| | | def sound_loader(path, dest_sample_rate=16000, float_dtype=None): |
| | | # The file is as follows: |
| | | # utterance_id_A /some/where/a.wav |
| | |
| | | return AdapterForSoundScpReader(loader, float_dtype) |
| | | |
| | | |
def rand_int_loader(filepath, loader_type):
    """Build a random-integer dataset from a ``rand_int_<low>_<high>`` type.

    Args:
        filepath: Path forwarded unchanged to ``IntRandomGenerateDataset``.
        loader_type: Loader type string such as ``"rand_int_3_10"``. The text
            after the ``rand_int_`` prefix must be exactly two
            underscore-separated integers, forwarded as ``low`` and ``high``.

    Returns:
        The ``IntRandomGenerateDataset`` built from ``filepath``, ``low`` and
        ``high``.

    Raises:
        RuntimeError: If the suffix cannot be parsed into exactly two integers.
            The underlying ``ValueError`` is attached as ``__cause__``.
    """
    # e.g. rand_int_3_10 -> low=3, high=10
    try:
        low, high = map(int, loader_type[len("rand_int_"):].split("_"))
    except ValueError as err:
        # Chain the parse error so the real reason (bad int, wrong count)
        # is preserved in the traceback.
        raise RuntimeError(f"e.g rand_int_3_10: but got {loader_type}") from err
    return IntRandomGenerateDataset(filepath, low, high)
| | | |
| | | |
# Registry of supported loader types.
# Each key is a regular expression that the dataset tests against the
# requested loader type with ``re.match``. Each value is a dict with:
#   func:   callable that builds the loader, invoked as ``func(path, **kwargs)``
#   kwargs: names of the extra keyword arguments filled in for ``func``
#   help:   human-readable description followed by an example file layout
DATA_TYPES = {
    "sound": dict(
        func=sound_loader,
        kwargs=["dest_sample_rate", "float_dtype"],
        help="Audio format types supported by sndfile, e.g. wav, flac, etc."
        "\n\n"
        "   utterance_id_a a.wav\n"
        "   utterance_id_b b.wav\n"
        "   ...",
    ),
    "kaldi_ark": dict(
        func=kaldi_loader,
        kwargs=["max_cache_fd"],
        help="Kaldi-ark file type."
        "\n\n"
        "   utterance_id_A /some/where/a.ark:123\n"
        "   utterance_id_B /some/where/a.ark:456\n"
        "   ...",
    ),
    "npy": dict(
        func=NpyScpReader,
        kwargs=[],
        help="Npy file format."
        "\n\n"
        "   utterance_id_A /some/where/a.npy\n"
        "   utterance_id_B /some/where/b.npy\n"
        "   ...",
    ),
    "text_int": dict(
        func=functools.partial(load_num_sequence_text, loader_type="text_int"),
        kwargs=[],
        help="A text file in which is written a sequence of integer numbers "
        "separated by space."
        "\n\n"
        "   utterance_id_A 12 0 1 3\n"
        "   utterance_id_B 3 3 1\n"
        "   ...",
    ),
    "csv_int": dict(
        func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
        kwargs=[],
        help="A text file in which is written a sequence of integer numbers "
        "separated by comma."
        "\n\n"
        "   utterance_id_A 100,80\n"
        "   utterance_id_B 143,80\n"
        "   ...",
    ),
    "text_float": dict(
        func=functools.partial(load_num_sequence_text, loader_type="text_float"),
        kwargs=[],
        help="A text file in which is written a sequence of float numbers "
        "separated by space."
        "\n\n"
        "   utterance_id_A 12. 3.1 3.4 4.4\n"
        "   utterance_id_B 3. 3.12 1.1\n"
        "   ...",
    ),
    "csv_float": dict(
        func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
        kwargs=[],
        help="A text file in which is written a sequence of float numbers "
        "separated by comma."
        "\n\n"
        "   utterance_id_A 12.,3.1,3.4,4.4\n"
        "   utterance_id_B 3.,3.12,1.1\n"
        "   ...",
    ),
    "text": dict(
        func=read_2column_text,
        kwargs=[],
        help="Return text as is. The text must be converted to ndarray "
        "by 'preprocess'."
        "\n\n"
        "   utterance_id_A hello world\n"
        "   utterance_id_B foo bar\n"
        "   ...",
    ),
    "hdf5": dict(
        func=H5FileWrapper,
        kwargs=[],
        help="A HDF5 file which contains arrays at the first level or the second level."
        "\n\n"
        "   >>> f = h5py.File('file.h5')\n"
        "   >>> array1 = f['utterance_id_A']\n"
        "   >>> array2 = f['utterance_id_B']\n",
    ),
    "rand_float": dict(
        func=FloatRandomGenerateDataset,
        kwargs=[],
        help="Generate random float-ndarray which has the given shapes "
        "in the file."
        "\n\n"
        "   utterance_id_A 3,4\n"
        "   utterance_id_B 10,4\n"
        "   ...",
    ),
    # Pattern key: matches e.g. "rand_int_0_10"; the matched loader type
    # itself is passed to the loader via the "loader_type" kwarg.
    r"rand_int_\d+_\d+": dict(
        func=rand_int_loader,
        kwargs=["loader_type"],
        help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
        "shapes in the path. "
        "Give the lower and upper value by the file type. e.g. "
        "rand_int_0_10 -> Generate integers from 0 to 10."
        "\n\n"
        "   utterance_id_A 3,4\n"
        "   utterance_id_B 10,4\n"
        "   ...",
    ),
}
| | | |
| | | |
class AbsDataset(Dataset, ABC):
    """Abstract dataset interface.

    Subclasses report which data names they provide and return one example
    per identifier. Per the annotation, an example is a pair of an
    identifier and a dict mapping data names to arrays.
    """

    @abstractmethod
    def names(self) -> Tuple[str, ...]:
        """Return every data name this dataset provides."""
        raise NotImplementedError()

    @abstractmethod
    def has_name(self, name) -> bool:
        """Return whether ``name`` is among the provided data names."""
        raise NotImplementedError()

    @abstractmethod
    def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
        """Return the example for ``uid`` as an ``(identifier, data)`` pair."""
        raise NotImplementedError()
| | | |
| | | |
| | | class ESPnetDataset(AbsDataset): |
| | | class ESPnetDataset(Dataset): |
| | | """ |
| | | Pytorch Dataset class for FunASR, simplied from ESPnet |
| | | Pytorch Dataset class for FunASR, modified from ESPnet |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | path_name_type_list: Collection[Tuple[str, str, str]], |
| | | preprocess: Callable[ |
| | | [str, Dict[str, np.ndarray]], Dict[str, np.ndarray] |
| | | ] = None, |
| | | float_dtype: str = "float32", |
| | | int_dtype: str = "long", |
| | | max_cache_size: Union[float, int, str] = 0.0, |
| | | max_cache_fd: int = 0, |
| | | dest_sample_rate: int = 16000, |
| | | self, |
| | | path_name_type_list: Collection[Tuple[str, str, str]], |
| | | preprocess: Callable[ |
| | | [str, Dict[str, np.ndarray]], Dict[str, np.ndarray] |
| | | ] = None, |
| | | float_dtype: str = "float32", |
| | | int_dtype: str = "long", |
| | | dest_sample_rate: int = 16000, |
| | | speed_perturb: Union[list, tuple] = None, |
| | | mode: str = "train", |
| | | ): |
| | | assert check_argument_types() |
| | | if len(path_name_type_list) == 0: |
| | | raise ValueError( |
| | | '1 or more elements are required for "path_name_type_list"' |
| | |
| | | |
| | | self.float_dtype = float_dtype |
| | | self.int_dtype = int_dtype |
| | | self.max_cache_fd = max_cache_fd |
| | | self.dest_sample_rate = dest_sample_rate |
| | | self.speed_perturb = speed_perturb |
| | | self.mode = mode |
| | | if self.speed_perturb is not None: |
| | | logging.info("Using speed_perturb: {}".format(speed_perturb)) |
| | | |
| | | self.loader_dict = {} |
| | | self.debug_info = {} |
| | |
| | | if len(self.loader_dict[name]) == 0: |
| | | raise RuntimeError(f"{path} has no samples") |
| | | |
| | | # TODO(kamo): Should check consistency of each utt-keys? |
| | | |
| | | if isinstance(max_cache_size, str): |
| | | max_cache_size = humanfriendly.parse_size(max_cache_size) |
| | | self.max_cache_size = max_cache_size |
| | | if max_cache_size > 0: |
| | | self.cache = SizedDict(shared=True) |
| | | else: |
| | | self.cache = None |
| | | |
| | | def _build_loader( |
| | | self, path: str, loader_type: str |
| | | ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]: |
| | | self, path: str, loader_type: str |
| | | ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, List[int], numbers.Number]]: |
| | | """Helper function to instantiate Loader. |
| | | |
| | | Args: |
| | | path: The file path |
| | | loader_type: loader_type. sound, npy, text_int, text_float, etc |
| | | loader_type: loader_type. sound, npy, text, etc |
| | | """ |
| | | for key, dic in DATA_TYPES.items(): |
| | | # e.g. loader_type="sound" |
| | | # -> return DATA_TYPES["sound"]["func"](path) |
| | | if re.match(key, loader_type): |
| | | kwargs = {} |
| | | for key2 in dic["kwargs"]: |
| | | if key2 == "loader_type": |
| | | kwargs["loader_type"] = loader_type |
| | | elif key2 == "dest_sample_rate" and loader_type=="sound": |
| | | kwargs["dest_sample_rate"] = self.dest_sample_rate |
| | | elif key2 == "float_dtype": |
| | | kwargs["float_dtype"] = self.float_dtype |
| | | elif key2 == "int_dtype": |
| | | kwargs["int_dtype"] = self.int_dtype |
| | | elif key2 == "max_cache_fd": |
| | | kwargs["max_cache_fd"] = self.max_cache_fd |
| | | if loader_type == "sound": |
| | | speed_perturb = self.speed_perturb if self.mode == "train" else None |
| | | loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False, |
| | | speed_perturb=speed_perturb) |
| | | return AdapterForSoundScpReader(loader, self.float_dtype) |
| | | elif loader_type == "kaldi_ark": |
| | | loader = kaldiio.load_scp(path) |
| | | return AdapterForSoundScpReader(loader, self.float_dtype) |
| | | elif loader_type == "npy": |
| | | return NpyScpReader(path) |
| | | elif loader_type == "text": |
| | | text_loader = {} |
| | | with open(path, "r", encoding="utf-8") as f: |
| | | for linenum, line in enumerate(f, 1): |
| | | sps = line.rstrip().split(maxsplit=1) |
| | | if len(sps) == 1: |
| | | k, v = sps[0], "" |
| | | else: |
| | | raise RuntimeError(f"Not implemented keyword argument: {key2}") |
| | | |
| | | func = dic["func"] |
| | | try: |
| | | return func(path, **kwargs) |
| | | except Exception: |
| | | if hasattr(func, "__name__"): |
| | | name = func.__name__ |
| | | k, v = sps |
| | | if k in text_loader: |
| | | raise RuntimeError(f"{k} is duplicated ({path}:{linenum})") |
| | | text_loader[k] = v |
| | | return text_loader |
| | | elif loader_type == "text_int": |
| | | text_int_loader = {} |
| | | with open(path, "r", encoding="utf-8") as f: |
| | | for linenum, line in enumerate(f, 1): |
| | | sps = line.rstrip().split(maxsplit=1) |
| | | if len(sps) == 1: |
| | | k, v = sps[0], "" |
| | | else: |
| | | name = str(func) |
| | | logging.error(f"An error happened with {name}({path})") |
| | | raise |
| | | k, v = sps |
| | | if k in text_int_loader: |
| | | raise RuntimeError(f"{k} is duplicated ({path}:{linenum})") |
| | | text_int_loader[k] = [int(i) for i in v.split()] |
| | | return text_int_loader |
| | | else: |
| | | raise RuntimeError(f"Not supported: loader_type={loader_type}") |
| | | |
| | |
| | | return _mes |
| | | |
| | | def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]: |
| | | assert check_argument_types() |
| | | |
| | | # Change integer-id to string-id |
| | | if isinstance(uid, int): |
| | | d = next(iter(self.loader_dict.values())) |
| | | uid = list(d)[uid] |
| | | |
| | | if self.cache is not None and uid in self.cache: |
| | | data = self.cache[uid] |
| | | return uid, data |
| | | |
| | | data = {} |
| | | # 1. Load data from each loaders |
| | |
| | | if isinstance(value, (list, tuple)): |
| | | value = np.array(value) |
| | | if not isinstance( |
| | | value, (np.ndarray, torch.Tensor, str, numbers.Number) |
| | | value, (np.ndarray, torch.Tensor, str, numbers.Number) |
| | | ): |
| | | raise TypeError( |
| | | f"Must be ndarray, torch.Tensor, str or Number: {type(value)}" |
| | |
| | | raise NotImplementedError(f"Not supported dtype: {value.dtype}") |
| | | data[name] = value |
| | | |
| | | if self.cache is not None and self.cache.size < self.max_cache_size: |
| | | self.cache[uid] = data |
| | | |
| | | retval = uid, data |
| | | assert check_return_type(retval) |
| | | return retval |