游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/datasets/iterable_dataset.py
@@ -8,17 +8,20 @@
from typing import Iterator
from typing import Tuple
from typing import Union
from typing import List
import kaldiio
import numpy as np
import soundfile
import torch
import torchaudio
import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
from funasr.datasets.dataset import ESPnetDataset
SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
def load_kaldi(input):
    retval = kaldiio.load_mat(input)
@@ -58,9 +61,23 @@
    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
    return array
def load_pcm(input):
    """Read a raw PCM file and decode its contents with ``load_bytes``.

    Args:
        input: Path of the PCM file; it is opened in binary mode.

    Returns:
        Whatever ``load_bytes`` returns for the complete file contents.
    """
    with open(input, "rb") as pcm_file:
        raw_data = pcm_file.read()
    # Decode only after the file handle has been released.
    return load_bytes(raw_data)
def load_wav(input):
    """Load an audio file as a NumPy array, falling back to ``soundfile``.

    ``torchaudio.load`` is tried first; the first element of its result
    (the waveform tensor) is converted to a NumPy array and returned as is.
    If that call raises, the file is read with ``soundfile.read`` as
    float32 instead. In the fallback path a 2-D result is reduced to its
    first column and a leading axis of length 1 is added, so the fallback
    always returns a 2-D array with a single row.

    Args:
        input: Audio source handed to ``torchaudio.load`` and, on failure,
            to ``soundfile.read``.

    Returns:
        np.ndarray: The decoded waveform.
            NOTE(review): the torchaudio path presumably yields
            (channels, samples) with all channels kept, whereas the
            fallback keeps one channel only -- confirm callers accept both.

    Raises:
        Exception: Whatever ``soundfile.read`` raises when the fallback
            also fails.
    """
    try:
        return torchaudio.load(input)[0].numpy()
    except Exception:
        # Deliberate best-effort fallback for inputs torchaudio cannot
        # decode. ``Exception`` (not a bare ``except``) keeps
        # KeyboardInterrupt/SystemExit propagating instead of being
        # swallowed and retried with soundfile.
        waveform, _ = soundfile.read(input, dtype='float32')
        if waveform.ndim == 2:
            # Keep only the first column of multi-column data.
            waveform = waveform[:, 0]
        return np.expand_dims(waveform, axis=0)
DATA_TYPES = {
    "sound": lambda x: soundfile.read(x)[0],
    "sound": load_wav,
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
    "waveform": lambda x: x,
@@ -98,10 +115,11 @@
                [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
            ] = None,
            float_dtype: str = "float32",
            fs: dict = None,
            mc: bool = False,
            int_dtype: str = "long",
            key_file: str = None,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
@@ -113,12 +131,14 @@
        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.key_file = key_file
        self.fs = fs
        self.mc = mc
        self.debug_info = {}
        non_iterable_list = []
        self.path_name_type_list = []
        if not isinstance(path_name_type_list[0], Tuple):
        if not isinstance(path_name_type_list[0], (Tuple, List)):
            path = path_name_type_list[0]
            name = path_name_type_list[1]
            _type = path_name_type_list[2]
@@ -165,64 +185,97 @@
    def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
        count = 0
        if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
            linenum = len(self.path_name_type_list)
            data = {}
            value = self.path_name_type_list[0][0]
            uid = 'utt_id'
            name = self.path_name_type_list[0][1]
            _type = self.path_name_type_list[0][2]
            func = DATA_TYPES[_type]
            array = func(value)
            data[name] = array
            for i in range(linenum):
                value = self.path_name_type_list[i][0]
                uid = 'utt_id'
                name = self.path_name_type_list[i][1]
                _type = self.path_name_type_list[i][2]
                func = DATA_TYPES[_type]
                array = func(value)
                if self.fs is not None and (name == "speech" or name == "ref_speech"):
                    audio_fs = self.fs["audio_fs"]
                    model_fs = self.fs["model_fs"]
                    if audio_fs is not None and model_fs is not None:
                        array = torch.from_numpy(array)
                        array = array.unsqueeze(0)
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                       new_freq=model_fs)(array)
                        array = array.squeeze(0).numpy()
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            for name in data:
                count += 1
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.')
                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value
                data[name] = array
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)
                for name in data:
                    count += 1
                    value = data[name]
                    if not isinstance(value, np.ndarray):
                        raise RuntimeError(
                            f'All values must be converted to np.ndarray object '
                            f'by preprocessing, but "{name}" is still {type(value)}.')
                    # Cast to desired type
                    if value.dtype.kind == 'f':
                        value = value.astype(self.float_dtype)
                    elif value.dtype.kind == 'i':
                        value = value.astype(self.int_dtype)
                    else:
                        raise NotImplementedError(
                            f'Not supported dtype: {value.dtype}')
                    data[name] = value
            yield uid, data
        elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
            linenum = len(self.path_name_type_list)
            data = {}
            value = self.path_name_type_list[0][0]
            uid = os.path.basename(self.path_name_type_list[0][0]).split(".")[0]
            name = self.path_name_type_list[0][1]
            _type = self.path_name_type_list[0][2]
            func = DATA_TYPES[_type]
            array = func(value)
            data[name] = array
            if self.preprocess is not None:
                data = self.preprocess(uid, data)
            for name in data:
                count += 1
                value = data[name]
                if not isinstance(value, np.ndarray):
                    raise RuntimeError(
                        f'All values must be converted to np.ndarray object '
                        f'by preprocessing, but "{name}" is still {type(value)}.')
                # Cast to desired type
                if value.dtype.kind == 'f':
                    value = value.astype(self.float_dtype)
                elif value.dtype.kind == 'i':
                    value = value.astype(self.int_dtype)
            for i in range(linenum):
                value = self.path_name_type_list[i][0]
                uid = os.path.basename(self.path_name_type_list[i][0]).split(".")[0]
                name = self.path_name_type_list[i][1]
                _type = self.path_name_type_list[i][2]
                if _type == "sound":
                   audio_type = os.path.basename(value).lower()
                   if audio_type.rfind(".pcm") >= 0:
                       _type = "pcm"
                func = DATA_TYPES[_type]
                array = func(value)
                if self.fs is not None and (name == "speech" or name == "ref_speech"):
                    audio_fs = self.fs["audio_fs"]
                    model_fs = self.fs["model_fs"]
                    if audio_fs is not None and model_fs is not None:
                        array = torch.from_numpy(array)
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                               new_freq=model_fs)(array)
                        array = array.numpy()
                if _type == "sound":
                    if self.mc:
                        data[name] = array.transpose((1, 0))
                    else:
                        data[name] = array[0]
                else:
                    raise NotImplementedError(
                        f'Not supported dtype: {value.dtype}')
                data[name] = value
                    data[name] = array
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)
                for name in data:
                    count += 1
                    value = data[name]
                    if not isinstance(value, np.ndarray):
                        raise RuntimeError(
                            f'All values must be converted to np.ndarray object '
                            f'by preprocessing, but "{name}" is still {type(value)}.')
                    # Cast to desired type
                    if value.dtype.kind == 'f':
                        value = value.astype(self.float_dtype)
                    elif value.dtype.kind == 'i':
                        value = value.astype(self.int_dtype)
                    else:
                        raise NotImplementedError(
                            f'Not supported dtype: {value.dtype}')
                    data[name] = value
            yield uid, data
@@ -286,10 +339,28 @@
                data = {}
                # 2.a. Load data streamingly
                for value, (path, name, _type) in zip(values, self.path_name_type_list):
                    if _type == "sound":
                        audio_type = os.path.basename(value).lower()
                        if audio_type.rfind(".pcm") >= 0:
                            _type = "pcm"
                    func = DATA_TYPES[_type]
                    # Load entry
                    array = func(value)
                    data[name] = array
                    if self.fs is not None and name == "speech":
                        audio_fs = self.fs["audio_fs"]
                        model_fs = self.fs["model_fs"]
                        if audio_fs is not None and model_fs is not None:
                            array = torch.from_numpy(array)
                            array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                                   new_freq=model_fs)(array)
                            array = array.numpy()
                    if _type == "sound":
                        if self.mc:
                            data[name] = array.transpose((1, 0))
                        else:
                            data[name] = array[0]
                    else:
                        data[name] = array
                if self.non_iterable_dataset is not None:
                    # 2.b. Load data from non-iterable dataset
                    _, from_non_iterable = self.non_iterable_dataset[uid]
@@ -322,3 +393,4 @@
        if count == 0:
            raise RuntimeError("No iteration")