jmwang66
2023-05-16 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906
funasr/datasets/large_datasets/dataset.py
@@ -1,20 +1,20 @@
import logging
import os
import random
import numpy
from functools import partial
import torch
import torchaudio
import torch.distributed as dist
import torchaudio
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.tokenize import tokenize
@@ -28,7 +28,8 @@
class AudioDataset(IterableDataset):
    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None,
                 mode="train"):
        self.scp_lists = scp_lists
        self.data_names = data_names
        self.data_types = data_types
@@ -40,6 +41,9 @@
        self.world_size = 1
        self.worker_id = 0
        self.num_workers = 1
        self.speed_perturb = speed_perturb
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))
    def set_epoch(self, epoch) -> None:
        """Store the given epoch value on this dataset instance.

        Args:
            epoch: Value to record as ``self.epoch``. It is stored as-is,
                with no validation or conversion.

        NOTE(review): presumably the stored epoch is read elsewhere in the
        class (e.g. to vary shuffling per epoch) -- that code is not visible
        in this excerpt, confirm against the full file.
        """
        self.epoch = epoch
@@ -124,9 +128,15 @@
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
                                                                          new_freq=self.frontend_conf["fs"])(waveform)
                                sampling_rate = self.frontend_conf["fs"]
                                sampling_rate = self.frontend_conf["fs"]
                        waveform = waveform.numpy()
                        mat = waveform[0]
                        if self.speed_perturb is not None:
                            speed = random.choice(self.speed_perturb)
                            if speed != 1.0:
                                mat, _ = torchaudio.sox_effects.apply_effects_tensor(
                                    torch.tensor(mat).view(1, -1), sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]])
                                mat = mat.view(-1).numpy()
                        sample_dict[data_name] = mat
                        sample_dict["sampling_rate"] = sampling_rate
                        if data_name == "speech":
@@ -168,6 +178,7 @@
            bpe_tokenizer,
            conf,
            frontend_conf,
            speed_perturb=None,
            mode="train",
            batch_mode="padding"):
    scp_lists = read_lists(data_list_file)
@@ -196,7 +207,8 @@
                           data_names, 
                           data_types, 
                           frontend_conf=frontend_conf, 
                           shuffle=shuffle,
                           shuffle=shuffle,
                           speed_perturb=speed_perturb,
                           mode=mode, 
                           )