wucong.lyb
2023-06-29 d0718e276250e2f6e32eb3404d838dc3b1f1de7f
funasr/datasets/iterable_dataset.py
@@ -8,11 +8,13 @@
from typing import Iterator
from typing import Tuple
from typing import Union
from typing import List
import kaldiio
import numpy as np
import torch
import torchaudio
import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
@@ -65,8 +67,17 @@
        bytes = f.read()
    return load_bytes(bytes)
def load_wav(input):
    try:
        return torchaudio.load(input)[0].numpy()
    except:
        waveform, _ = soundfile.read(input, dtype='float32')
        if waveform.ndim == 2:
            waveform = waveform[:, 0]
        return np.expand_dims(waveform, axis=0)
DATA_TYPES = {
    "sound": lambda x: torchaudio.load(x)[0].numpy(),
    "sound": load_wav,
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
@@ -129,7 +140,7 @@
        non_iterable_list = []
        self.path_name_type_list = []
        if not isinstance(path_name_type_list[0], Tuple):
        if not isinstance(path_name_type_list[0], (Tuple, List)):
            path = path_name_type_list[0]
            name = path_name_type_list[1]
            _type = path_name_type_list[2]
@@ -227,13 +238,9 @@
                name = self.path_name_type_list[i][1]
                _type = self.path_name_type_list[i][2]
                if _type == "sound":
                    audio_type = os.path.basename(value).split(".")[-1].lower()
                    if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                        raise NotImplementedError(
                            f'Not supported audio type: {audio_type}')
                    if audio_type == "pcm":
                        _type = "pcm"
                   audio_type = os.path.basename(value).lower()
                   if audio_type.rfind(".pcm") >= 0:
                       _type = "pcm"
                func = DATA_TYPES[_type]
                array = func(value)
                if self.fs is not None and (name == "speech" or name == "ref_speech"):
@@ -335,11 +342,8 @@
                # 2.a. Load data streamingly
                for value, (path, name, _type) in zip(values, self.path_name_type_list):
                    if _type == "sound":
                        audio_type = os.path.basename(value).split(".")[-1].lower()
                        if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                            raise NotImplementedError(
                                f'Not supported audio type: {audio_type}')
                        if audio_type == "pcm":
                        audio_type = os.path.basename(value).lower()
                        if audio_type.rfind(".pcm") >= 0:
                            _type = "pcm"
                    func = DATA_TYPES[_type]
                    # Load entry
@@ -391,3 +395,4 @@
        if count == 0:
            raise RuntimeError("No iteration")