| | |
| | | from typing import Iterator |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import List |
| | | |
| | | import kaldiio |
| | | import numpy as np |
| | | import torch |
| | | import torchaudio |
| | | # import librosa |
| | | import librosa |
| | | from torch.utils.data.dataset import IterableDataset |
| | | from typeguard import check_argument_types |
| | | import os.path |
| | | |
| | | from funasr.datasets.dataset import ESPnetDataset |
| | |
| | | bytes = f.read() |
| | | return load_bytes(bytes) |
| | | |
def load_wav(input):
    """Load an audio file into a float32 numpy array with a leading channel axis.

    torchaudio is tried first. If it fails to decode the input for any reason,
    librosa is used as a fallback decoder.

    Args:
        input: Audio source passed straight through to ``torchaudio.load`` /
            ``librosa.load`` (typically a file path).

    Returns:
        np.ndarray: From torchaudio, the waveform as returned by the library
        (one row per channel). From the librosa fallback, a single-channel
        array of shape ``(1, samples)``.
    """
    try:
        return torchaudio.load(input)[0].numpy()
    except Exception:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit are
        # not swallowed. Anything torchaudio cannot decode falls through to
        # librosa below.
        # sr=None keeps the file's native sample rate. librosa's default
        # (sr=22050) would silently resample, and the rate is not returned to
        # the caller, so it would disagree with the torchaudio path above.
        waveform, _ = librosa.load(input, sr=None, mono=True, dtype='float32')
        if waveform.ndim == 2:
            # Defensive only: with mono=True librosa returns a 1-D signal.
            # librosa's multichannel layout is (channels, samples), so the
            # first channel is row 0 (the old ``[:, 0]`` took one sample per
            # channel instead).
            waveform = waveform[0]
        return np.expand_dims(waveform, axis=0)
| | | |
| | | DATA_TYPES = { |
| | | "sound": lambda x: torchaudio.load(x)[0].numpy(), |
| | | "sound": load_wav, |
| | | "pcm": load_pcm, |
| | | "kaldi_ark": load_kaldi, |
| | | "bytes": load_bytes, |
| | |
| | | int_dtype: str = "long", |
| | | key_file: str = None, |
| | | ): |
| | | assert check_argument_types() |
| | | if len(path_name_type_list) == 0: |
| | | raise ValueError( |
| | | '1 or more elements are required for "path_name_type_list"' |
| | |
| | | non_iterable_list = [] |
| | | self.path_name_type_list = [] |
| | | |
| | | if not isinstance(path_name_type_list[0], Tuple): |
| | | if not isinstance(path_name_type_list[0], (Tuple, List)): |
| | | path = path_name_type_list[0] |
| | | name = path_name_type_list[1] |
| | | _type = path_name_type_list[2] |
| | |
| | | name = self.path_name_type_list[i][1] |
| | | _type = self.path_name_type_list[i][2] |
| | | if _type == "sound": |
| | | audio_type = os.path.basename(value).split(".")[-1].lower() |
| | | if audio_type not in SUPPORT_AUDIO_TYPE_SETS: |
| | | raise NotImplementedError( |
| | | f'Not supported audio type: {audio_type}') |
| | | if audio_type == "pcm": |
| | | _type = "pcm" |
| | | |
| | | audio_type = os.path.basename(value).lower() |
| | | if audio_type.rfind(".pcm") >= 0: |
| | | _type = "pcm" |
| | | func = DATA_TYPES[_type] |
| | | array = func(value) |
| | | if self.fs is not None and (name == "speech" or name == "ref_speech"): |
| | |
| | | # 2.a. Load data streamingly |
| | | for value, (path, name, _type) in zip(values, self.path_name_type_list): |
| | | if _type == "sound": |
| | | audio_type = os.path.basename(value).split(".")[-1].lower() |
| | | if audio_type not in SUPPORT_AUDIO_TYPE_SETS: |
| | | raise NotImplementedError( |
| | | f'Not supported audio type: {audio_type}') |
| | | if audio_type == "pcm": |
| | | audio_type = os.path.basename(value).lower() |
| | | if audio_type.rfind(".pcm") >= 0: |
| | | _type = "pcm" |
| | | func = DATA_TYPES[_type] |
| | | # Load entry |
| | |
| | | |
| | | if count == 0: |
| | | raise RuntimeError("No iteration") |
| | | |