游雁
2023-10-10 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3
funasr/datasets/iterable_dataset.py
@@ -8,16 +8,20 @@
from typing import Iterator
from typing import Tuple
from typing import Union
from typing import List
import kaldiio
import numpy as np
import soundfile
import torch
import torchaudio
import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
from funasr.datasets.dataset import ESPnetDataset
SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
def load_kaldi(input):
    retval = kaldiio.load_mat(input)
@@ -42,9 +46,41 @@
    return array
def load_bytes(input):
    middle_data = np.frombuffer(input, dtype=np.int16)
    middle_data = np.asarray(middle_data)
    if middle_data.dtype.kind not in 'iu':
        raise TypeError("'middle_data' must be an array of integers")
    dtype = np.dtype('float32')
    if dtype.kind != 'f':
        raise TypeError("'dtype' must be a floating point type")
    i = np.iinfo(middle_data.dtype)
    abs_max = 2 ** (i.bits - 1)
    offset = i.min + abs_max
    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
    return array
def load_pcm(input):
    with open(input,"rb") as f:
        bytes = f.read()
    return load_bytes(bytes)
def load_wav(input):
    try:
        return torchaudio.load(input)[0].numpy()
    except:
        waveform, _ = soundfile.read(input, dtype='float32')
        if waveform.ndim == 2:
            waveform = waveform[:, 0]
        return np.expand_dims(waveform, axis=0)
DATA_TYPES = {
    "sound": lambda x: soundfile.read(x)[0],
    "sound": load_wav,
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
    "waveform": lambda x: x,
    "npy": np.load,
    "text_int": lambda x: np.loadtxt(
        StringIO(x), ndmin=1, dtype=np.long, delimiter=" "
@@ -73,16 +109,17 @@
    """
    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        key_file: str = None,
            self,
            path_name_type_list: Collection[Tuple[any, str, str]],
            preprocess: Callable[
                [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
            ] = None,
            float_dtype: str = "float32",
            fs: dict = None,
            mc: bool = False,
            int_dtype: str = "long",
            key_file: str = None,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
@@ -94,19 +131,29 @@
        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.key_file = key_file
        self.fs = fs
        self.mc = mc
        self.debug_info = {}
        non_iterable_list = []
        self.path_name_type_list = []
        for path, name, _type in path_name_type_list:
            if name in self.debug_info:
                raise RuntimeError(f'"{name}" is duplicated for data-key')
        if not isinstance(path_name_type_list[0], (Tuple, List)):
            path = path_name_type_list[0]
            name = path_name_type_list[1]
            _type = path_name_type_list[2]
            self.debug_info[name] = path, _type
            if _type not in DATA_TYPES:
                non_iterable_list.append((path, name, _type))
            else:
                self.path_name_type_list.append((path, name, _type))
        else:
            for path, name, _type in path_name_type_list:
                self.debug_info[name] = path, _type
                if _type not in DATA_TYPES:
                    non_iterable_list.append((path, name, _type))
                else:
                    self.path_name_type_list.append((path, name, _type))
        if len(non_iterable_list) != 0:
            # Some types doesn't support iterable mode
@@ -119,10 +166,7 @@
        else:
            self.non_iterable_dataset = None
        if Path(Path(path_name_type_list[0][0]).parent, "utt2category").exists():
            self.apply_utt2category = True
        else:
            self.apply_utt2category = False
        self.apply_utt2category = False
    def has_name(self, name) -> bool:
        return name in self.debug_info
@@ -139,99 +183,214 @@
        return _mes
    def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
        if self.key_file is not None:
            uid_iter = (
                line.rstrip().split(maxsplit=1)[0]
                for line in open(self.key_file, encoding="utf-8")
            )
        elif len(self.path_name_type_list) != 0:
            uid_iter = (
                line.rstrip().split(maxsplit=1)[0]
                for line in open(self.path_name_type_list[0][0], encoding="utf-8")
            )
        else:
            uid_iter = iter(self.non_iterable_dataset)
        files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
        worker_info = torch.utils.data.get_worker_info()
        linenum = 0
        count = 0
        for count, uid in enumerate(uid_iter, 1):
            # If num_workers>=1, split keys
            if worker_info is not None:
                if (count - 1) % worker_info.num_workers != worker_info.id:
                    continue
            # 1. Read a line from each file
            while True:
                keys = []
                values = []
                for f in files:
                    linenum += 1
                    try:
                        line = next(f)
                    except StopIteration:
                        raise RuntimeError(f"{uid} is not found in the files")
                    sps = line.rstrip().split(maxsplit=1)
                    if len(sps) != 2:
                        raise RuntimeError(
                            f"This line doesn't include a space:"
                            f" {f}:L{linenum}: {line})"
                        )
                    key, value = sps
                    keys.append(key)
                    values.append(value)
                for k_idx, k in enumerate(keys):
                    if k != keys[0]:
                        raise RuntimeError(
                            f"Keys are mismatched. Text files (idx={k_idx}) is "
                            f"not sorted or not having same keys at L{linenum}"
                        )
                # If the key is matched, break the loop
                if len(keys) == 0 or keys[0] == uid:
                    break
            # 2. Load the entry from each line and create a dict
        if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
            linenum = len(self.path_name_type_list)
            data = {}
            # 2.a. Load data streamingly
            for value, (path, name, _type) in zip(values, self.path_name_type_list):
            for i in range(linenum):
                value = self.path_name_type_list[i][0]
                uid = 'utt_id'
                name = self.path_name_type_list[i][1]
                _type = self.path_name_type_list[i][2]
                func = DATA_TYPES[_type]
                # Load entry
                array = func(value)
                if self.fs is not None and (name == "speech" or name == "ref_speech"):
                    audio_fs = self.fs["audio_fs"]
                    model_fs = self.fs["model_fs"]
                    if audio_fs is not None and model_fs is not None:
                        array = torch.from_numpy(array)
                        array = array.unsqueeze(0)
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                       new_freq=model_fs)(array)
                        array = array.squeeze(0).numpy()
                data[name] = array
            if self.non_iterable_dataset is not None:
                # 2.b. Load data from non-iterable dataset
                _, from_non_iterable = self.non_iterable_dataset[uid]
                data.update(from_non_iterable)
            # 3. [Option] Apply preprocessing
            #   e.g. funasr.train.preprocessor:CommonPreprocessor
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            # 4. Force data-precision
            for name in data:
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f"All values must be converted to np.ndarray object "
                        f'by preprocessing, but "{name}" is still {type(value)}.'
                    )
                # Cast to desired type
                if value.dtype.kind == "f":
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == "i":
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(f"Not supported dtype: {value.dtype}")
                data[name] = value
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)
                for name in data:
                    count += 1
                    value = data[name]
                    if not isinstance(value, np.ndarray):
                        raise RuntimeError(
                            f'All values must be converted to np.ndarray object '
                            f'by preprocessing, but "{name}" is still {type(value)}.')
                    # Cast to desired type
                    if value.dtype.kind == 'f':
                        value = value.astype(self.float_dtype)
                    elif value.dtype.kind == 'i':
                        value = value.astype(self.int_dtype)
                    else:
                        raise NotImplementedError(
                            f'Not supported dtype: {value.dtype}')
                    data[name] = value
            yield uid, data
        elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
            linenum = len(self.path_name_type_list)
            data = {}
            for i in range(linenum):
                value = self.path_name_type_list[i][0]
                uid = os.path.basename(self.path_name_type_list[i][0]).split(".")[0]
                name = self.path_name_type_list[i][1]
                _type = self.path_name_type_list[i][2]
                if _type == "sound":
                   audio_type = os.path.basename(value).lower()
                   if audio_type.rfind(".pcm") >= 0:
                       _type = "pcm"
                func = DATA_TYPES[_type]
                array = func(value)
                if self.fs is not None and (name == "speech" or name == "ref_speech"):
                    audio_fs = self.fs["audio_fs"]
                    model_fs = self.fs["model_fs"]
                    if audio_fs is not None and model_fs is not None:
                        array = torch.from_numpy(array)
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                               new_freq=model_fs)(array)
                        array = array.numpy()
                if _type == "sound":
                    if self.mc:
                        data[name] = array.transpose((1, 0))
                    else:
                        data[name] = array[0]
                else:
                    data[name] = array
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)
                for name in data:
                    count += 1
                    value = data[name]
                    if not isinstance(value, np.ndarray):
                        raise RuntimeError(
                            f'All values must be converted to np.ndarray object '
                            f'by preprocessing, but "{name}" is still {type(value)}.')
                    # Cast to desired type
                    if value.dtype.kind == 'f':
                        value = value.astype(self.float_dtype)
                    elif value.dtype.kind == 'i':
                        value = value.astype(self.int_dtype)
                    else:
                        raise NotImplementedError(
                            f'Not supported dtype: {value.dtype}')
                    data[name] = value
            yield uid, data
        else:
            if self.key_file is not None:
                uid_iter = (
                    line.rstrip().split(maxsplit=1)[0]
                    for line in open(self.key_file, encoding="utf-8")
                )
            elif len(self.path_name_type_list) != 0:
                uid_iter = (
                    line.rstrip().split(maxsplit=1)[0]
                    for line in open(self.path_name_type_list[0][0], encoding="utf-8")
                )
            else:
                uid_iter = iter(self.non_iterable_dataset)
            files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
            worker_info = torch.utils.data.get_worker_info()
            linenum = 0
            for count, uid in enumerate(uid_iter, 1):
                # If num_workers>=1, split keys
                if worker_info is not None:
                    if (count - 1) % worker_info.num_workers != worker_info.id:
                        continue
                # 1. Read a line from each file
                while True:
                    keys = []
                    values = []
                    for f in files:
                        linenum += 1
                        try:
                            line = next(f)
                        except StopIteration:
                            raise RuntimeError(f"{uid} is not found in the files")
                        sps = line.rstrip().split(maxsplit=1)
                        if len(sps) != 2:
                            raise RuntimeError(
                                f"This line doesn't include a space:"
                                f" {f}:L{linenum}: {line})"
                            )
                        key, value = sps
                        keys.append(key)
                        values.append(value)
                    for k_idx, k in enumerate(keys):
                        if k != keys[0]:
                            raise RuntimeError(
                                f"Keys are mismatched. Text files (idx={k_idx}) is "
                                f"not sorted or not having same keys at L{linenum}"
                            )
                    # If the key is matched, break the loop
                    if len(keys) == 0 or keys[0] == uid:
                        break
                # 2. Load the entry from each line and create a dict
                data = {}
                # 2.a. Load data streamingly
                for value, (path, name, _type) in zip(values, self.path_name_type_list):
                    if _type == "sound":
                        audio_type = os.path.basename(value).lower()
                        if audio_type.rfind(".pcm") >= 0:
                            _type = "pcm"
                    func = DATA_TYPES[_type]
                    # Load entry
                    array = func(value)
                    if self.fs is not None and name == "speech":
                        audio_fs = self.fs["audio_fs"]
                        model_fs = self.fs["model_fs"]
                        if audio_fs is not None and model_fs is not None:
                            array = torch.from_numpy(array)
                            array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                                   new_freq=model_fs)(array)
                            array = array.numpy()
                    if _type == "sound":
                        if self.mc:
                            data[name] = array.transpose((1, 0))
                        else:
                            data[name] = array[0]
                    else:
                        data[name] = array
                if self.non_iterable_dataset is not None:
                    # 2.b. Load data from non-iterable dataset
                    _, from_non_iterable = self.non_iterable_dataset[uid]
                    data.update(from_non_iterable)
                # 3. [Option] Apply preprocessing
                #   e.g. funasr.train.preprocessor:CommonPreprocessor
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)
                # 4. Force data-precision
                for name in data:
                    value = data[name]
                    if not isinstance(value, np.ndarray):
                        raise RuntimeError(
                            f"All values must be converted to np.ndarray object "
                            f'by preprocessing, but "{name}" is still {type(value)}.'
                        )
                    # Cast to desired type
                    if value.dtype.kind == "f":
                        value = value.astype(self.float_dtype)
                    elif value.dtype.kind == "i":
                        value = value.astype(self.int_dtype)
                    else:
                        raise NotImplementedError(f"Not supported dtype: {value.dtype}")
                    data[name] = value
                yield uid, data
        if count == 0:
            raise RuntimeError("No iteration")