| | |
| | | import logging |
| | | import os |
| | | import random |
| | | import numpy |
| | | from functools import partial |
| | | |
| | | import torch |
| | | import torchaudio |
| | | import torch.distributed as dist |
| | | import torchaudio |
| | | from kaldiio import ReadHelper |
| | | from torch.utils.data import IterableDataset |
| | | |
| | | from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe |
| | | from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe |
| | | from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe |
| | | from funasr.datasets.large_datasets.utils.clipping import clipping |
| | | from funasr.datasets.large_datasets.utils.filter import filter |
| | | from funasr.datasets.large_datasets.utils.padding import padding |
| | | from funasr.datasets.large_datasets.utils.clipping import clipping |
| | | from funasr.datasets.large_datasets.utils.tokenize import tokenize |
| | | |
| | | |
| | |
| | | |
| | | |
| | | class AudioDataset(IterableDataset): |
| | | def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"): |
| | | def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None, |
| | | mode="train"): |
| | | self.scp_lists = scp_lists |
| | | self.data_names = data_names |
| | | self.data_types = data_types |
| | |
| | | self.world_size = 1 |
| | | self.worker_id = 0 |
| | | self.num_workers = 1 |
| | | self.speed_perturb = speed_perturb |
| | | if self.speed_perturb is not None: |
| | | logging.info("Using speed_perturb: {}".format(speed_perturb)) |
| | | |
| | | def set_epoch(self, epoch): |
| | | self.epoch = epoch |
| | |
| | | if sampling_rate != self.frontend_conf["fs"]: |
| | | waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, |
| | | new_freq=self.frontend_conf["fs"])(waveform) |
| | | sampling_rate = self.frontend_conf["fs"] |
| | | sampling_rate = self.frontend_conf["fs"] |
| | | waveform = waveform.numpy() |
| | | mat = waveform[0] |
| | | if self.speed_perturb is not None: |
| | | speed = random.choice(self.speed_perturb) |
| | | if speed != 1.0: |
| | | mat, _ = torchaudio.sox_effects.apply_effects_tensor( |
| | | mat, sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]]) |
| | | sample_dict[data_name] = mat |
| | | sample_dict["sampling_rate"] = sampling_rate |
| | | if data_name == "speech": |
| | |
| | | bpe_tokenizer, |
| | | conf, |
| | | frontend_conf, |
| | | speed_perturb=None, |
| | | mode="train", |
| | | batch_mode="padding"): |
| | | scp_lists = read_lists(data_list_file) |
| | | shuffle = conf.get('shuffle', True) |
| | | data_names = conf.get("data_names", "speech,text") |
| | | data_types = conf.get("data_types", "kaldi_ark,text") |
| | | dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode) |
| | | dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, |
| | | speed_perturb=speed_perturb, mode=mode) |
| | | |
| | | filter_conf = conf.get('filter_conf', {}) |
| | | filter_fn = partial(filter, **filter_conf) |