游雁
2023-02-13 fcc9c89eaba9a4e36c54958aeedeec7ab3756cd7
funasr/datasets/iterable_dataset.py
@@ -11,13 +11,16 @@
import kaldiio
import numpy as np
import soundfile
import torch
import torchaudio
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
from funasr.datasets.dataset import ESPnetDataset
SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
def load_kaldi(input):
    retval = kaldiio.load_mat(input)
@@ -42,9 +45,32 @@
    return array
def load_bytes(input):
    """Decode raw signed 16-bit PCM bytes into a normalized float32 waveform.

    Args:
        input: Bytes-like object interpreted as ``np.int16`` samples in the
            machine's native byte order.

    Returns:
        A 1-D ``np.float32`` array in which every sample has been divided by
        ``2 ** 15``, so values fall in ``[-1.0, 1.0)``.

    Raises:
        ValueError: Propagated from ``np.frombuffer`` when the buffer size is
            not a multiple of the 2-byte sample size.
    """
    samples = np.frombuffer(input, dtype=np.int16)
    info = np.iinfo(samples.dtype)
    # Half of the integer range (2 ** 15 for int16) is the scaling factor.
    abs_max = 2 ** (info.bits - 1)
    # Evaluates to 0 for a signed dtype; kept so the formula mirrors the
    # generic integer-to-float normalization.
    offset = info.min + abs_max
    # The arithmetic already yields a fresh float32 array, so no further
    # buffer reinterpretation is needed.
    return (samples.astype(np.float32) - offset) / abs_max
def load_pcm(input):
    """Read a file in binary mode and decode its contents with ``load_bytes``.

    Args:
        input: Path of the file to read; the whole file is loaded at once.

    Returns:
        The value returned by ``load_bytes`` for the file's raw contents.
    """
    with open(input, "rb") as pcm_file:
        raw = pcm_file.read()
    # Decode only after the file handle has been released.
    return load_bytes(raw)
DATA_TYPES = {
    "sound": lambda x: soundfile.read(x)[0],
    "sound": lambda x: torchaudio.load(x)[0][0].numpy(),
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
    "waveform": lambda x: x,
    "npy": np.load,
    "text_int": lambda x: np.loadtxt(
        StringIO(x), ndmin=1, dtype=np.long, delimiter=" "
@@ -73,14 +99,15 @@
    """
    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        key_file: str = None,
            self,
            path_name_type_list: Collection[Tuple[any, str, str]],
            preprocess: Callable[
                [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
            ] = None,
            float_dtype: str = "float32",
            fs: dict = None,
            int_dtype: str = "long",
            key_file: str = None,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
@@ -94,19 +121,28 @@
        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.key_file = key_file
        self.fs = fs
        self.debug_info = {}
        non_iterable_list = []
        self.path_name_type_list = []
        for path, name, _type in path_name_type_list:
            if name in self.debug_info:
                raise RuntimeError(f'"{name}" is duplicated for data-key')
        if not isinstance(path_name_type_list[0], Tuple):
            path = path_name_type_list[0]
            name = path_name_type_list[1]
            _type = path_name_type_list[2]
            self.debug_info[name] = path, _type
            if _type not in DATA_TYPES:
                non_iterable_list.append((path, name, _type))
            else:
                self.path_name_type_list.append((path, name, _type))
        else:
            for path, name, _type in path_name_type_list:
                self.debug_info[name] = path, _type
                if _type not in DATA_TYPES:
                    non_iterable_list.append((path, name, _type))
                else:
                    self.path_name_type_list.append((path, name, _type))
        if len(non_iterable_list) != 0:
            # Some types don't support iterable mode
@@ -119,10 +155,7 @@
        else:
            self.non_iterable_dataset = None
        if Path(Path(path_name_type_list[0][0]).parent, "utt2category").exists():
            self.apply_utt2category = True
        else:
            self.apply_utt2category = False
        self.apply_utt2category = False
    def has_name(self, name) -> bool:
        return name in self.debug_info
@@ -139,99 +172,204 @@
        return _mes
    def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
        if self.key_file is not None:
            uid_iter = (
                line.rstrip().split(maxsplit=1)[0]
                for line in open(self.key_file, encoding="utf-8")
            )
        elif len(self.path_name_type_list) != 0:
            uid_iter = (
                line.rstrip().split(maxsplit=1)[0]
                for line in open(self.path_name_type_list[0][0], encoding="utf-8")
            )
        else:
            uid_iter = iter(self.non_iterable_dataset)
        files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
        worker_info = torch.utils.data.get_worker_info()
        linenum = 0
        count = 0
        for count, uid in enumerate(uid_iter, 1):
            # If num_workers>=1, split keys
            if worker_info is not None:
                if (count - 1) % worker_info.num_workers != worker_info.id:
                    continue
            # 1. Read a line from each file
            while True:
                keys = []
                values = []
                for f in files:
                    linenum += 1
                    try:
                        line = next(f)
                    except StopIteration:
                        raise RuntimeError(f"{uid} is not found in the files")
                    sps = line.rstrip().split(maxsplit=1)
                    if len(sps) != 2:
                        raise RuntimeError(
                            f"This line doesn't include a space:"
                            f" {f}:L{linenum}: {line})"
                        )
                    key, value = sps
                    keys.append(key)
                    values.append(value)
                for k_idx, k in enumerate(keys):
                    if k != keys[0]:
                        raise RuntimeError(
                            f"Keys are mismatched. Text files (idx={k_idx}) is "
                            f"not sorted or not having same keys at L{linenum}"
                        )
                # If the key is matched, break the loop
                if len(keys) == 0 or keys[0] == uid:
                    break
            # 2. Load the entry from each line and create a dict
        if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
            data = {}
            # 2.a. Load data streamingly
            for value, (path, name, _type) in zip(values, self.path_name_type_list):
                func = DATA_TYPES[_type]
                # Load entry
                array = func(value)
                data[name] = array
            if self.non_iterable_dataset is not None:
                # 2.b. Load data from non-iterable dataset
                _, from_non_iterable = self.non_iterable_dataset[uid]
                data.update(from_non_iterable)
            value = self.path_name_type_list[0][0]
            uid = 'utt_id'
            name = self.path_name_type_list[0][1]
            _type = self.path_name_type_list[0][2]
            func = DATA_TYPES[_type]
            array = func(value)
            if self.fs is not None and name == "speech":
                audio_fs = self.fs["audio_fs"]
                model_fs = self.fs["model_fs"]
                if audio_fs is not None and model_fs is not None:
                    array = torch.from_numpy(array)
                    array = array.unsqueeze(0)
                    array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                   new_freq=model_fs)(array)
                    array = array.squeeze(0).numpy()
            data[name] = array
            # 3. [Option] Apply preprocessing
            #   e.g. funasr.train.preprocessor:CommonPreprocessor
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            # 4. Force data-precision
            for name in data:
                count += 1
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f"All values must be converted to np.ndarray object "
                        f'by preprocessing, but "{name}" is still {type(value)}.'
                    )
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.')
                # Cast to desired type
                if value.dtype.kind == "f":
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == "i":
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(f"Not supported dtype: {value.dtype}")
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value
            yield uid, data
        elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
            data = {}
            value = self.path_name_type_list[0][0]
            uid = os.path.basename(self.path_name_type_list[0][0]).split(".")[0]
            name = self.path_name_type_list[0][1]
            _type = self.path_name_type_list[0][2]
            if _type == "sound":
                audio_type = os.path.basename(value).split(".")[1].lower()
                if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                    raise NotImplementedError(
                        f'Not supported audio type: {audio_type}')
                if audio_type == "pcm":
                    _type = "pcm"
            func = DATA_TYPES[_type]
            array = func(value)
            if self.fs is not None and name == "speech":
                audio_fs = self.fs["audio_fs"]
                model_fs = self.fs["model_fs"]
                if audio_fs is not None and model_fs is not None:
                    array = torch.from_numpy(array)
                    array = array.unsqueeze(0)
                    array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                           new_freq=model_fs)(array)
                    array = array.squeeze(0).numpy()
            data[name] = array
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            for name in data:
                count += 1
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.')
                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value
            yield uid, data
        else:
            if self.key_file is not None:
                uid_iter = (
                    line.rstrip().split(maxsplit=1)[0]
                    for line in open(self.key_file, encoding="utf-8")
                )
            elif len(self.path_name_type_list) != 0:
                uid_iter = (
                    line.rstrip().split(maxsplit=1)[0]
                    for line in open(self.path_name_type_list[0][0], encoding="utf-8")
                )
            else:
                uid_iter = iter(self.non_iterable_dataset)
            files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
            worker_info = torch.utils.data.get_worker_info()
            linenum = 0
            for count, uid in enumerate(uid_iter, 1):
                # If num_workers>=1, split keys
                if worker_info is not None:
                    if (count - 1) % worker_info.num_workers != worker_info.id:
                        continue
                # 1. Read a line from each file
                while True:
                    keys = []
                    values = []
                    for f in files:
                        linenum += 1
                        try:
                            line = next(f)
                        except StopIteration:
                            raise RuntimeError(f"{uid} is not found in the files")
                        sps = line.rstrip().split(maxsplit=1)
                        if len(sps) != 2:
                            raise RuntimeError(
                                f"This line doesn't include a space:"
                                f" {f}:L{linenum}: {line})"
                            )
                        key, value = sps
                        keys.append(key)
                        values.append(value)
                    for k_idx, k in enumerate(keys):
                        if k != keys[0]:
                            raise RuntimeError(
                                f"Keys are mismatched. Text files (idx={k_idx}) is "
                                f"not sorted or not having same keys at L{linenum}"
                            )
                    # If the key is matched, break the loop
                    if len(keys) == 0 or keys[0] == uid:
                        break
                # 2. Load the entry from each line and create a dict
                data = {}
                # 2.a. Load data streamingly
                for value, (path, name, _type) in zip(values, self.path_name_type_list):
                    if _type == "sound":
                        audio_type = os.path.basename(value).split(".")[1].lower()
                        if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                            raise NotImplementedError(
                                f'Not supported audio type: {audio_type}')
                        if audio_type == "pcm":
                            _type = "pcm"
                    func = DATA_TYPES[_type]
                    # Load entry
                    array = func(value)
                    if self.fs is not None and name == "speech":
                        audio_fs = self.fs["audio_fs"]
                        model_fs = self.fs["model_fs"]
                        if audio_fs is not None and model_fs is not None:
                            array = torch.from_numpy(array)
                            array = array.unsqueeze(0)
                            array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                                   new_freq=model_fs)(array)
                            array = array.squeeze(0).numpy()
                    data[name] = array
                if self.non_iterable_dataset is not None:
                    # 2.b. Load data from non-iterable dataset
                    _, from_non_iterable = self.non_iterable_dataset[uid]
                    data.update(from_non_iterable)
                # 3. [Option] Apply preprocessing
                #   e.g. funasr.train.preprocessor:CommonPreprocessor
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)
                # 4. Force data-precision
                for name in data:
                    value = data[name]
                    if not isinstance(value, np.ndarray):
                        raise RuntimeError(
                            f"All values must be converted to np.ndarray object "
                            f'by preprocessing, but "{name}" is still {type(value)}.'
                        )
                    # Cast to desired type
                    if value.dtype.kind == "f":
                        value = value.astype(self.float_dtype)
                    elif value.dtype.kind == "i":
                        value = value.astype(self.int_dtype)
                    else:
                        raise NotImplementedError(f"Not supported dtype: {value.dtype}")
                    data[name] = value
                yield uid, data
        if count == 0:
            raise RuntimeError("No iteration")