haoneng.lhn
2023-06-26 e677eb4b13b5388f4351a164a991cea950773a72
funasr/datasets/iterable_dataset.py
@@ -14,6 +14,7 @@
import numpy as np
import torch
import torchaudio
import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
@@ -66,8 +67,14 @@
        bytes = f.read()
    return load_bytes(bytes)
def load_wav(input):
    try:
        return torchaudio.load(input)[0].numpy()
    except:
        return np.expand_dims(soundfile.read(input)[0], axis=0)
DATA_TYPES = {
    "sound": lambda x: torchaudio.load(x)[0].numpy(),
    "sound": load_wav,
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,