游雁
2023-05-25 b18f7d121f2f17df8bf2d0c2bbb223bc5ddbcc0f
funasr/datasets/large_datasets/dataset.py
@@ -1,20 +1,20 @@
import logging
import os
import random
import numpy
from functools import partial
import torch
import torchaudio
import torch.distributed as dist
import torchaudio
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.tokenize import tokenize
@@ -28,7 +28,8 @@
class AudioDataset(IterableDataset):
    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None,
                 mode="train"):
        self.scp_lists = scp_lists
        self.data_names = data_names
        self.data_types = data_types
@@ -40,6 +41,9 @@
        self.world_size = 1
        self.worker_id = 0
        self.num_workers = 1
        self.speed_perturb = speed_perturb
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))
    def set_epoch(self, epoch):
        self.epoch = epoch
@@ -101,7 +105,7 @@
                if data_type == "kaldi_ark":
                    ark_reader = ReadHelper('ark:{}'.format(data_file))
                    reader_list.append(ark_reader)
                elif data_type == "text" or data_type == "sound":
                elif data_type == "text" or data_type == "sound" or data_type == 'text_hotword':
                    text_reader = open(data_file, "r")
                    reader_list.append(text_reader)
                elif data_type == "none":
@@ -124,9 +128,15 @@
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
                                                                          new_freq=self.frontend_conf["fs"])(waveform)
                                sampling_rate = self.frontend_conf["fs"]
                                sampling_rate = self.frontend_conf["fs"]
                        waveform = waveform.numpy()
                        mat = waveform[0]
                        if self.speed_perturb is not None:
                            speed = random.choice(self.speed_perturb)
                            if speed != 1.0:
                                mat, _ = torchaudio.sox_effects.apply_effects_tensor(
                                    torch.tensor(mat).view(1, -1), sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]])
                                mat = mat.view(-1).numpy()
                        sample_dict[data_name] = mat
                        sample_dict["sampling_rate"] = sampling_rate
                        if data_name == "speech":
@@ -168,6 +178,7 @@
            bpe_tokenizer,
            conf,
            frontend_conf,
            speed_perturb=None,
            mode="train",
            batch_mode="padding"):
    scp_lists = read_lists(data_list_file)
@@ -196,7 +207,8 @@
                           data_names, 
                           data_types, 
                           frontend_conf=frontend_conf, 
                           shuffle=shuffle,
                           shuffle=shuffle,
                           speed_perturb=speed_perturb,
                           mode=mode, 
                           )