雾聪
2023-06-28 54931dd4e1a099d7d6f144c4e12e5453deb3aa26
funasr/datasets/large_datasets/dataset.py
@@ -6,6 +6,8 @@
import torch
import torch.distributed as dist
import torchaudio
import numpy as np
import soundfile
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@@ -123,7 +125,14 @@
                            sample_dict["key"] = key
                    elif data_type == "sound":
                        key, path = item.strip().split()
                        waveform, sampling_rate = torchaudio.load(path)
                        try:
                            waveform, sampling_rate = torchaudio.load(path)
                        except:
                            waveform, sampling_rate = soundfile.read(path, dtype='float32')
                            if waveform.ndim == 2:
                                waveform = waveform[:, 0]
                            waveform = np.expand_dims(waveform, axis=0)
                            waveform = torch.tensor(waveform)
                        if self.frontend_conf is not None:
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
@@ -148,6 +157,12 @@
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                        sample_dict['hw_tag'] = 1
                    elif data_type == "text_nospace":
                        text = item
                        segs = text.strip().split(maxsplit=1)
                        sample_dict[data_name] = [x for x in segs[1]]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                    else:
                        text = item
                        segs = text.strip().split()