游雁
2023-10-10 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3
funasr/datasets/iterable_dataset.py
@@ -8,13 +8,14 @@
from typing import Iterator
from typing import Tuple
from typing import Union
from typing import List
import kaldiio
import numpy as np
import torch
import torchaudio
import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
from funasr.datasets.dataset import ESPnetDataset
@@ -65,8 +66,17 @@
        bytes = f.read()
    return load_bytes(bytes)
def load_wav(input):
    """Load an audio file as a 2-D float NumPy array.

    ``torchaudio.load`` is tried first; the first element of its result
    (the waveform tensor) is converted to a NumPy array and returned
    unchanged. If that call raises, the file is read with ``soundfile``
    as float32 instead.

    Args:
        input: Audio source handed unchanged to ``torchaudio.load`` and,
            on failure, to ``soundfile.read``.

    Returns:
        numpy.ndarray: The waveform from ``torchaudio.load``, or, on the
        fallback path, an array with a leading axis of length 1 that
        holds a single channel.

    Raises:
        Exception: Whatever ``soundfile.read`` raises when the fallback
            cannot read the input either.
    """
    try:
        return torchaudio.load(input)[0].numpy()
    except Exception:
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate instead of
        # silently triggering the fallback decoder.
        waveform, _ = soundfile.read(input, dtype='float32')
        if waveform.ndim == 2:
            # Keep only column 0 of a 2-D result.
            # NOTE(review): this drops any further channels, unlike the
            # torchaudio path above -- confirm this is intended.
            waveform = waveform[:, 0]
        # Add a leading axis so the result is 2-D.
        return np.expand_dims(waveform, axis=0)
DATA_TYPES = {
    "sound": lambda x: torchaudio.load(x)[0][0].numpy(),
    "sound": load_wav,
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
@@ -106,10 +116,10 @@
            ] = None,
            float_dtype: str = "float32",
            fs: dict = None,
            mc: bool = False,
            int_dtype: str = "long",
            key_file: str = None,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
@@ -122,12 +132,13 @@
        self.int_dtype = int_dtype
        self.key_file = key_file
        self.fs = fs
        self.mc = mc
        self.debug_info = {}
        non_iterable_list = []
        self.path_name_type_list = []
        if not isinstance(path_name_type_list[0], Tuple):
        if not isinstance(path_name_type_list[0], (Tuple, List)):
            path = path_name_type_list[0]
            name = path_name_type_list[1]
            _type = path_name_type_list[2]
@@ -192,6 +203,7 @@
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                       new_freq=model_fs)(array)
                        array = array.squeeze(0).numpy()
                data[name] = array
                if self.preprocess is not None:
@@ -224,13 +236,9 @@
                name = self.path_name_type_list[i][1]
                _type = self.path_name_type_list[i][2]
                if _type == "sound":
                    audio_type = os.path.basename(value).split(".")[1].lower()
                    if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                        raise NotImplementedError(
                            f'Not supported audio type: {audio_type}')
                    if audio_type == "pcm":
                        _type = "pcm"
                   audio_type = os.path.basename(value).lower()
                   if audio_type.rfind(".pcm") >= 0:
                       _type = "pcm"
                func = DATA_TYPES[_type]
                array = func(value)
                if self.fs is not None and (name == "speech" or name == "ref_speech"):
@@ -238,11 +246,17 @@
                    model_fs = self.fs["model_fs"]
                    if audio_fs is not None and model_fs is not None:
                        array = torch.from_numpy(array)
                        array = array.unsqueeze(0)
                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                               new_freq=model_fs)(array)
                        array = array.squeeze(0).numpy()
                data[name] = array
                        array = array.numpy()
                if _type == "sound":
                    if self.mc:
                        data[name] = array.transpose((1, 0))
                    else:
                        data[name] = array[0]
                else:
                    data[name] = array
                if self.preprocess is not None:
                    data = self.preprocess(uid, data)
@@ -326,11 +340,8 @@
                # 2.a. Load data streamingly
                for value, (path, name, _type) in zip(values, self.path_name_type_list):
                    if _type == "sound":
                        audio_type = os.path.basename(value).split(".")[1].lower()
                        if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                            raise NotImplementedError(
                                f'Not supported audio type: {audio_type}')
                        if audio_type == "pcm":
                        audio_type = os.path.basename(value).lower()
                        if audio_type.rfind(".pcm") >= 0:
                            _type = "pcm"
                    func = DATA_TYPES[_type]
                    # Load entry
@@ -340,11 +351,16 @@
                        model_fs = self.fs["model_fs"]
                        if audio_fs is not None and model_fs is not None:
                            array = torch.from_numpy(array)
                            array = array.unsqueeze(0)
                            array = torchaudio.transforms.Resample(orig_freq=audio_fs,
                                                                   new_freq=model_fs)(array)
                            array = array.squeeze(0).numpy()
                    data[name] = array
                            array = array.numpy()
                    if _type == "sound":
                        if self.mc:
                            data[name] = array.transpose((1, 0))
                        else:
                            data[name] = array[0]
                    else:
                        data[name] = array
                if self.non_iterable_dataset is not None:
                    # 2.b. Load data from non-iterable dataset
                    _, from_non_iterable = self.non_iterable_dataset[uid]
@@ -377,3 +393,4 @@
        if count == 0:
            raise RuntimeError("No iteration")