| | |
| | | from typing import Iterator |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import List |
| | | |
| | | import kaldiio |
| | | import numpy as np |
| | | import soundfile |
| | | import torch |
| | | import torchaudio |
| | | import soundfile |
| | | from torch.utils.data.dataset import IterableDataset |
| | | from typeguard import check_argument_types |
| | | import os.path |
| | | |
| | | from funasr.datasets.dataset import ESPnetDataset |
| | | |
| | | |
| | | SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm'] |
| | | |
| | | def load_kaldi(input): |
| | | retval = kaldiio.load_mat(input) |
| | |
| | | return array |
| | | |
| | | |
def load_bytes(input):
    """Decode raw signed 16-bit PCM bytes into a normalized float32 waveform.

    Args:
        input: Bytes-like object containing native-endian int16 samples.

    Returns:
        A 1-D ``np.float32`` array with every sample divided by 32768, so
        values lie in ``[-1.0, 1.0)``.

    Raises:
        ValueError: Propagated from ``np.frombuffer`` when the buffer length
            is not a multiple of two bytes.
    """
    samples = np.frombuffer(input, dtype=np.int16)

    info = np.iinfo(samples.dtype)
    # Magnitude of the most negative representable sample (32768 for int16).
    abs_max = 2 ** (info.bits - 1)
    # Evaluates to 0 for a signed dtype; kept so the scaling formula matches
    # the generic integer-to-float conversion it was derived from.
    offset = info.min + abs_max

    # astype() already produces a fresh float32 array, so no further
    # buffer reinterpretation is needed.
    return (samples.astype(np.float32) - offset) / abs_max
| | | |
def load_pcm(input):
    """Read a raw PCM file from disk and decode it with ``load_bytes``.

    Args:
        input: Path of the file to open in binary mode.

    Returns:
        Whatever ``load_bytes`` returns for the file's full contents.
    """
    # Named ``raw`` rather than ``bytes`` so the builtin type is not shadowed.
    with open(input, "rb") as f:
        raw = f.read()
    return load_bytes(raw)
| | | |
def load_wav(input):
    """Load an audio file as a 2-D float array, preferring torchaudio.

    The waveform tensor returned by ``torchaudio.load`` is converted to a
    NumPy array. If that call raises, the file is read with ``soundfile``
    instead.

    Args:
        input: Audio source accepted by ``torchaudio.load`` /
            ``soundfile.read`` (typically a file path).

    Returns:
        On the torchaudio path, the waveform exactly as torchaudio returns it
        (channels-first per torchaudio's documented default). On the fallback
        path, an array of shape ``(1, num_samples)`` holding only the first
        column of a 2-D soundfile result.
    """
    try:
        return torchaudio.load(input)[0].numpy()
    except Exception:
        # Catch only ordinary errors: a bare ``except`` would also swallow
        # KeyboardInterrupt/SystemExit and silently retry with soundfile.
        waveform, _ = soundfile.read(input, dtype="float32")
        if waveform.ndim == 2:
            # NOTE(review): the fallback keeps a single channel, unlike the
            # torchaudio path above — confirm this asymmetry is intended.
            waveform = waveform[:, 0]
        return np.expand_dims(waveform, axis=0)
| | | |
| | | DATA_TYPES = { |
| | | "sound": lambda x: soundfile.read(x)[0], |
| | | "sound": load_wav, |
| | | "pcm": load_pcm, |
| | | "kaldi_ark": load_kaldi, |
| | | "bytes": load_bytes, |
| | | "waveform": lambda x: x, |
| | | "npy": np.load, |
| | | "text_int": lambda x: np.loadtxt( |
| | | StringIO(x), ndmin=1, dtype=np.long, delimiter=" " |
| | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | path_name_type_list: Collection[Tuple[str, str, str]], |
| | | preprocess: Callable[ |
| | | [str, Dict[str, np.ndarray]], Dict[str, np.ndarray] |
| | | ] = None, |
| | | float_dtype: str = "float32", |
| | | int_dtype: str = "long", |
| | | key_file: str = None, |
| | | self, |
| | | path_name_type_list: Collection[Tuple[any, str, str]], |
| | | preprocess: Callable[ |
| | | [str, Dict[str, np.ndarray]], Dict[str, np.ndarray] |
| | | ] = None, |
| | | float_dtype: str = "float32", |
| | | fs: dict = None, |
| | | mc: bool = False, |
| | | int_dtype: str = "long", |
| | | key_file: str = None, |
| | | ): |
| | | assert check_argument_types() |
| | | if len(path_name_type_list) == 0: |
| | | raise ValueError( |
| | | '1 or more elements are required for "path_name_type_list"' |
| | |
| | | self.float_dtype = float_dtype |
| | | self.int_dtype = int_dtype |
| | | self.key_file = key_file |
| | | self.fs = fs |
| | | self.mc = mc |
| | | |
| | | self.debug_info = {} |
| | | non_iterable_list = [] |
| | | self.path_name_type_list = [] |
| | | |
| | | for path, name, _type in path_name_type_list: |
| | | if name in self.debug_info: |
| | | raise RuntimeError(f'"{name}" is duplicated for data-key') |
| | | if not isinstance(path_name_type_list[0], (Tuple, List)): |
| | | path = path_name_type_list[0] |
| | | name = path_name_type_list[1] |
| | | _type = path_name_type_list[2] |
| | | self.debug_info[name] = path, _type |
| | | if _type not in DATA_TYPES: |
| | | non_iterable_list.append((path, name, _type)) |
| | | else: |
| | | self.path_name_type_list.append((path, name, _type)) |
| | | else: |
| | | for path, name, _type in path_name_type_list: |
| | | self.debug_info[name] = path, _type |
| | | if _type not in DATA_TYPES: |
| | | non_iterable_list.append((path, name, _type)) |
| | | else: |
| | | self.path_name_type_list.append((path, name, _type)) |
| | | |
| | | if len(non_iterable_list) != 0: |
| | | # Some types doesn't support iterable mode |
| | |
| | | else: |
| | | self.non_iterable_dataset = None |
| | | |
| | | if Path(Path(path_name_type_list[0][0]).parent, "utt2category").exists(): |
| | | self.apply_utt2category = True |
| | | else: |
| | | self.apply_utt2category = False |
| | | self.apply_utt2category = False |
| | | |
| | | def has_name(self, name) -> bool: |
| | | return name in self.debug_info |
| | |
| | | return _mes |
| | | |
| | | def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]: |
| | | if self.key_file is not None: |
| | | uid_iter = ( |
| | | line.rstrip().split(maxsplit=1)[0] |
| | | for line in open(self.key_file, encoding="utf-8") |
| | | ) |
| | | elif len(self.path_name_type_list) != 0: |
| | | uid_iter = ( |
| | | line.rstrip().split(maxsplit=1)[0] |
| | | for line in open(self.path_name_type_list[0][0], encoding="utf-8") |
| | | ) |
| | | else: |
| | | uid_iter = iter(self.non_iterable_dataset) |
| | | |
| | | files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list] |
| | | |
| | | worker_info = torch.utils.data.get_worker_info() |
| | | |
| | | linenum = 0 |
| | | count = 0 |
| | | for count, uid in enumerate(uid_iter, 1): |
| | | # If num_workers>=1, split keys |
| | | if worker_info is not None: |
| | | if (count - 1) % worker_info.num_workers != worker_info.id: |
| | | continue |
| | | |
| | | # 1. Read a line from each file |
| | | while True: |
| | | keys = [] |
| | | values = [] |
| | | for f in files: |
| | | linenum += 1 |
| | | try: |
| | | line = next(f) |
| | | except StopIteration: |
| | | raise RuntimeError(f"{uid} is not found in the files") |
| | | sps = line.rstrip().split(maxsplit=1) |
| | | if len(sps) != 2: |
| | | raise RuntimeError( |
| | | f"This line doesn't include a space:" |
| | | f" {f}:L{linenum}: {line})" |
| | | ) |
| | | key, value = sps |
| | | keys.append(key) |
| | | values.append(value) |
| | | |
| | | for k_idx, k in enumerate(keys): |
| | | if k != keys[0]: |
| | | raise RuntimeError( |
| | | f"Keys are mismatched. Text files (idx={k_idx}) is " |
| | | f"not sorted or not having same keys at L{linenum}" |
| | | ) |
| | | |
| | | # If the key is matched, break the loop |
| | | if len(keys) == 0 or keys[0] == uid: |
| | | break |
| | | |
| | | # 2. Load the entry from each line and create a dict |
| | | if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"): |
| | | linenum = len(self.path_name_type_list) |
| | | data = {} |
| | | # 2.a. Load data streamingly |
| | | for value, (path, name, _type) in zip(values, self.path_name_type_list): |
| | | for i in range(linenum): |
| | | value = self.path_name_type_list[i][0] |
| | | uid = 'utt_id' |
| | | name = self.path_name_type_list[i][1] |
| | | _type = self.path_name_type_list[i][2] |
| | | func = DATA_TYPES[_type] |
| | | # Load entry |
| | | array = func(value) |
| | | if self.fs is not None and (name == "speech" or name == "ref_speech"): |
| | | audio_fs = self.fs["audio_fs"] |
| | | model_fs = self.fs["model_fs"] |
| | | if audio_fs is not None and model_fs is not None: |
| | | array = torch.from_numpy(array) |
| | | array = array.unsqueeze(0) |
| | | array = torchaudio.transforms.Resample(orig_freq=audio_fs, |
| | | new_freq=model_fs)(array) |
| | | array = array.squeeze(0).numpy() |
| | | |
| | | data[name] = array |
| | | if self.non_iterable_dataset is not None: |
| | | # 2.b. Load data from non-iterable dataset |
| | | _, from_non_iterable = self.non_iterable_dataset[uid] |
| | | data.update(from_non_iterable) |
| | | |
| | | # 3. [Option] Apply preprocessing |
| | | # e.g. funasr.train.preprocessor:CommonPreprocessor |
| | | if self.preprocess is not None: |
| | | data = self.preprocess(uid, data) |
| | | |
| | | # 4. Force data-precision |
| | | for name in data: |
| | | value = data[name] |
| | | if not isinstance(value, np.ndarray): |
| | | raise RuntimeError( |
| | | f"All values must be converted to np.ndarray object " |
| | | f'by preprocessing, but "{name}" is still {type(value)}.' |
| | | ) |
| | | |
| | | # Cast to desired type |
| | | if value.dtype.kind == "f": |
| | | value = value.astype(self.float_dtype) |
| | | elif value.dtype.kind == "i": |
| | | value = value.astype(self.int_dtype) |
| | | else: |
| | | raise NotImplementedError(f"Not supported dtype: {value.dtype}") |
| | | data[name] = value |
| | | if self.preprocess is not None: |
| | | data = self.preprocess(uid, data) |
| | | for name in data: |
| | | count += 1 |
| | | value = data[name] |
| | | if not isinstance(value, np.ndarray): |
| | | raise RuntimeError( |
| | | f'All values must be converted to np.ndarray object ' |
| | | f'by preprocessing, but "{name}" is still {type(value)}.') |
| | | # Cast to desired type |
| | | if value.dtype.kind == 'f': |
| | | value = value.astype(self.float_dtype) |
| | | elif value.dtype.kind == 'i': |
| | | value = value.astype(self.int_dtype) |
| | | else: |
| | | raise NotImplementedError( |
| | | f'Not supported dtype: {value.dtype}') |
| | | data[name] = value |
| | | |
| | | yield uid, data |
| | | |
| | | elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"): |
| | | linenum = len(self.path_name_type_list) |
| | | data = {} |
| | | for i in range(linenum): |
| | | value = self.path_name_type_list[i][0] |
| | | uid = os.path.basename(self.path_name_type_list[i][0]).split(".")[0] |
| | | name = self.path_name_type_list[i][1] |
| | | _type = self.path_name_type_list[i][2] |
| | | if _type == "sound": |
| | | audio_type = os.path.basename(value).lower() |
| | | if audio_type.rfind(".pcm") >= 0: |
| | | _type = "pcm" |
| | | func = DATA_TYPES[_type] |
| | | array = func(value) |
| | | if self.fs is not None and (name == "speech" or name == "ref_speech"): |
| | | audio_fs = self.fs["audio_fs"] |
| | | model_fs = self.fs["model_fs"] |
| | | if audio_fs is not None and model_fs is not None: |
| | | array = torch.from_numpy(array) |
| | | array = torchaudio.transforms.Resample(orig_freq=audio_fs, |
| | | new_freq=model_fs)(array) |
| | | array = array.numpy() |
| | | |
| | | if _type == "sound": |
| | | if self.mc: |
| | | data[name] = array.transpose((1, 0)) |
| | | else: |
| | | data[name] = array[0] |
| | | else: |
| | | data[name] = array |
| | | |
| | | if self.preprocess is not None: |
| | | data = self.preprocess(uid, data) |
| | | for name in data: |
| | | count += 1 |
| | | value = data[name] |
| | | if not isinstance(value, np.ndarray): |
| | | raise RuntimeError( |
| | | f'All values must be converted to np.ndarray object ' |
| | | f'by preprocessing, but "{name}" is still {type(value)}.') |
| | | # Cast to desired type |
| | | if value.dtype.kind == 'f': |
| | | value = value.astype(self.float_dtype) |
| | | elif value.dtype.kind == 'i': |
| | | value = value.astype(self.int_dtype) |
| | | else: |
| | | raise NotImplementedError( |
| | | f'Not supported dtype: {value.dtype}') |
| | | data[name] = value |
| | | |
| | | yield uid, data |
| | | |
| | | else: |
| | | if self.key_file is not None: |
| | | uid_iter = ( |
| | | line.rstrip().split(maxsplit=1)[0] |
| | | for line in open(self.key_file, encoding="utf-8") |
| | | ) |
| | | elif len(self.path_name_type_list) != 0: |
| | | uid_iter = ( |
| | | line.rstrip().split(maxsplit=1)[0] |
| | | for line in open(self.path_name_type_list[0][0], encoding="utf-8") |
| | | ) |
| | | else: |
| | | uid_iter = iter(self.non_iterable_dataset) |
| | | |
| | | files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list] |
| | | |
| | | worker_info = torch.utils.data.get_worker_info() |
| | | |
| | | linenum = 0 |
| | | for count, uid in enumerate(uid_iter, 1): |
| | | # If num_workers>=1, split keys |
| | | if worker_info is not None: |
| | | if (count - 1) % worker_info.num_workers != worker_info.id: |
| | | continue |
| | | |
| | | # 1. Read a line from each file |
| | | while True: |
| | | keys = [] |
| | | values = [] |
| | | for f in files: |
| | | linenum += 1 |
| | | try: |
| | | line = next(f) |
| | | except StopIteration: |
| | | raise RuntimeError(f"{uid} is not found in the files") |
| | | sps = line.rstrip().split(maxsplit=1) |
| | | if len(sps) != 2: |
| | | raise RuntimeError( |
| | | f"This line doesn't include a space:" |
| | | f" {f}:L{linenum}: {line})" |
| | | ) |
| | | key, value = sps |
| | | keys.append(key) |
| | | values.append(value) |
| | | |
| | | for k_idx, k in enumerate(keys): |
| | | if k != keys[0]: |
| | | raise RuntimeError( |
| | | f"Keys are mismatched. Text files (idx={k_idx}) is " |
| | | f"not sorted or not having same keys at L{linenum}" |
| | | ) |
| | | |
| | | # If the key is matched, break the loop |
| | | if len(keys) == 0 or keys[0] == uid: |
| | | break |
| | | |
| | | # 2. Load the entry from each line and create a dict |
| | | data = {} |
| | | # 2.a. Load data streamingly |
| | | for value, (path, name, _type) in zip(values, self.path_name_type_list): |
| | | if _type == "sound": |
| | | audio_type = os.path.basename(value).lower() |
| | | if audio_type.rfind(".pcm") >= 0: |
| | | _type = "pcm" |
| | | func = DATA_TYPES[_type] |
| | | # Load entry |
| | | array = func(value) |
| | | if self.fs is not None and name == "speech": |
| | | audio_fs = self.fs["audio_fs"] |
| | | model_fs = self.fs["model_fs"] |
| | | if audio_fs is not None and model_fs is not None: |
| | | array = torch.from_numpy(array) |
| | | array = torchaudio.transforms.Resample(orig_freq=audio_fs, |
| | | new_freq=model_fs)(array) |
| | | array = array.numpy() |
| | | if _type == "sound": |
| | | if self.mc: |
| | | data[name] = array.transpose((1, 0)) |
| | | else: |
| | | data[name] = array[0] |
| | | else: |
| | | data[name] = array |
| | | if self.non_iterable_dataset is not None: |
| | | # 2.b. Load data from non-iterable dataset |
| | | _, from_non_iterable = self.non_iterable_dataset[uid] |
| | | data.update(from_non_iterable) |
| | | |
| | | # 3. [Option] Apply preprocessing |
| | | # e.g. funasr.train.preprocessor:CommonPreprocessor |
| | | if self.preprocess is not None: |
| | | data = self.preprocess(uid, data) |
| | | |
| | | # 4. Force data-precision |
| | | for name in data: |
| | | value = data[name] |
| | | if not isinstance(value, np.ndarray): |
| | | raise RuntimeError( |
| | | f"All values must be converted to np.ndarray object " |
| | | f'by preprocessing, but "{name}" is still {type(value)}.' |
| | | ) |
| | | |
| | | # Cast to desired type |
| | | if value.dtype.kind == "f": |
| | | value = value.astype(self.float_dtype) |
| | | elif value.dtype.kind == "i": |
| | | value = value.astype(self.int_dtype) |
| | | else: |
| | | raise NotImplementedError(f"Not supported dtype: {value.dtype}") |
| | | data[name] = value |
| | | |
| | | yield uid, data |
| | | |
| | | if count == 0: |
| | | raise RuntimeError("No iteration") |
| | | |