jmwang66
2023-01-16 12a7adfdf3dd4f80b5d3a51cfc4eecc84eaa7c64
funasr/datasets/iterable_dataset.py
@@ -15,6 +15,7 @@
import torch
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
from funasr.datasets.dataset import ESPnetDataset
@@ -42,9 +43,27 @@
    return array
def load_bytes(input):
    """Convert a buffer of 16-bit signed integer samples to float32.

    The buffer is interpreted as a flat sequence of ``np.int16`` values
    (platform-native byte order) and every sample is divided by ``2 ** 15``,
    so the result lies in ``[-1.0, 1.0)``.

    Args:
        input: Any object exposing the buffer protocol (``bytes``,
            ``bytearray``, ``memoryview``, ...) holding the raw samples.

    Returns:
        A new 1-D ``np.float32`` array with one element per 16-bit sample.

    Raises:
        ValueError: Propagated from ``np.frombuffer`` when the buffer size
            is not a multiple of two bytes.
    """
    samples = np.frombuffer(input, dtype=np.int16)
    # Magnitude of the most negative representable sample (32768 for int16).
    # For a signed type ``iinfo.min + abs_max`` is zero, so no offset has to
    # be subtracted before scaling.
    abs_max = 2 ** (np.iinfo(samples.dtype).bits - 1)
    # ``astype`` already allocates a fresh float32 array; dividing by a
    # Python int keeps the float32 dtype.
    return samples.astype(np.float32) / abs_max
DATA_TYPES = {
    "sound": lambda x: soundfile.read(x)[0],
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
    "waveform": lambda x: x,
    "npy": np.load,
    "text_int": lambda x: np.loadtxt(
        StringIO(x), ndmin=1, dtype=np.long, delimiter=" "
@@ -74,7 +93,7 @@
    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
            path_name_type_list: Collection[Tuple[any, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
@@ -99,9 +118,17 @@
        non_iterable_list = []
        self.path_name_type_list = []
        if not isinstance(path_name_type_list[0], Tuple):
            path = path_name_type_list[0]
            name = path_name_type_list[1]
            _type = path_name_type_list[2]
            self.debug_info[name] = path, _type
            if _type not in DATA_TYPES:
                non_iterable_list.append((path, name, _type))
            else:
                self.path_name_type_list.append((path, name, _type))
        else:
        for path, name, _type in path_name_type_list:
            if name in self.debug_info:
                raise RuntimeError(f'"{name}" is duplicated for data-key')
            self.debug_info[name] = path, _type
            if _type not in DATA_TYPES:
                non_iterable_list.append((path, name, _type))
@@ -119,9 +146,6 @@
        else:
            self.non_iterable_dataset = None
        if Path(Path(path_name_type_list[0][0]).parent, "utt2category").exists():
            self.apply_utt2category = True
        else:
            self.apply_utt2category = False
    def has_name(self, name) -> bool:
@@ -139,6 +163,70 @@
        return _mes
    def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
        count = 0
        if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
            data = {}
            value = self.path_name_type_list[0][0]
            uid = 'utt_id'
            name = self.path_name_type_list[0][1]
            _type = self.path_name_type_list[0][2]
            func = DATA_TYPES[_type]
            array = func(value)
            data[name] = array
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            for name in data:
                count += 1
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.')
                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value
            yield uid, data
        elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
            data = {}
            value = self.path_name_type_list[0][0]
            uid = os.path.basename(self.path_name_type_list[0][0]).split(".")[0]
            name = self.path_name_type_list[0][1]
            _type = self.path_name_type_list[0][2]
            func = DATA_TYPES[_type]
            array = func(value)
            data[name] = array
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            for name in data:
                count += 1
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.')
                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value
            yield uid, data
        else:
        if self.key_file is not None:
            uid_iter = (
                line.rstrip().split(maxsplit=1)[0]
@@ -157,7 +245,6 @@
        worker_info = torch.utils.data.get_worker_info()
        linenum = 0
        count = 0
        for count, uid in enumerate(uid_iter, 1):
            # If num_workers>=1, split keys
            if worker_info is not None: