嘉渊
2023-05-28 3c3754dcc7568e76fa7d4b2c4e14849f68cc6ee7
funasr/datasets/large_datasets/dataset.py
@@ -1,20 +1,20 @@
import logging
import os
import random
import numpy
from functools import partial
import torch
import torchaudio
import torch.distributed as dist
import torchaudio
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.tokenize import tokenize
@@ -28,7 +28,8 @@
class AudioDataset(IterableDataset):
    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None,
                 mode="train"):
        self.scp_lists = scp_lists
        self.data_names = data_names
        self.data_types = data_types
@@ -40,6 +41,9 @@
        self.world_size = 1
        self.worker_id = 0
        self.num_workers = 1
        self.speed_perturb = speed_perturb
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))
    def set_epoch(self, epoch):
        """Store the given epoch value on the instance as ``self.epoch``.

        Args:
            epoch: Value to record; it is assigned as-is, with no
                validation or conversion.
        """
        self.epoch = epoch
@@ -101,7 +105,7 @@
                if data_type == "kaldi_ark":
                    ark_reader = ReadHelper('ark:{}'.format(data_file))
                    reader_list.append(ark_reader)
                elif data_type == "text" or data_type == "sound":
                elif data_type == "text" or data_type == "sound" or data_type == 'text_hotword':
                    text_reader = open(data_file, "r")
                    reader_list.append(text_reader)
                elif data_type == "none":
@@ -124,13 +128,32 @@
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
                                                                          new_freq=self.frontend_conf["fs"])(waveform)
                                sampling_rate = self.frontend_conf["fs"]
                                sampling_rate = self.frontend_conf["fs"]
                        waveform = waveform.numpy()
                        mat = waveform[0]
                        if self.speed_perturb is not None:
                            speed = random.choice(self.speed_perturb)
                            if speed != 1.0:
                                mat, _ = torchaudio.sox_effects.apply_effects_tensor(
                                    torch.tensor(mat).view(1, -1), sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]])
                                mat = mat.view(-1).numpy()
                        sample_dict[data_name] = mat
                        sample_dict["sampling_rate"] = sampling_rate
                        if data_name == "speech":
                            sample_dict["key"] = key
                    elif data_type == "text_hotword":
                        text = item
                        segs = text.strip().split()
                        sample_dict[data_name] = segs[1:]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                        sample_dict['hw_tag'] = 1
                    elif data_type == "text_nospace":
                        text = item
                        segs = text.strip().split(maxsplit=1)
                        sample_dict[data_name] = [x for x in segs[1]]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                    else:
                        text = item
                        segs = text.strip().split()
@@ -161,20 +184,46 @@
            bpe_tokenizer,
            conf,
            frontend_conf,
            speed_perturb=None,
            mode="train",
            batch_mode="padding"):
    scp_lists = read_lists(data_list_file)
    shuffle = conf.get('shuffle', True)
    data_names = conf.get("data_names", "speech,text")
    data_types = conf.get("data_types", "kaldi_ark,text")
    dataset = AudioDataset(scp_lists, data_names, data_types, frontend_conf=frontend_conf, shuffle=shuffle, mode=mode)
    pre_hwfile = conf.get("pre_hwlist", None)
    pre_prob = conf.get("pre_prob", 0)  # unused yet
    hw_config = {"sample_rate": conf.get("sample_rate", 0.6),
                 "double_rate": conf.get("double_rate", 0.1),
                 "hotword_min_length": conf.get("hotword_min_length", 2),
                 "hotword_max_length": conf.get("hotword_max_length", 8),
                 "pre_prob": conf.get("pre_prob", 0.0)}
    if pre_hwfile is not None:
        pre_hwlist = []
        with open(pre_hwfile, 'r') as fin:
            for line in fin.readlines():
                pre_hwlist.append(line.strip())
    else:
        pre_hwlist = None
    dataset = AudioDataset(scp_lists,
                           data_names,
                           data_types,
                           frontend_conf=frontend_conf,
                           shuffle=shuffle,
                           speed_perturb=speed_perturb,
                           mode=mode,
                           )
    filter_conf = conf.get('filter_conf', {})
    filter_fn = partial(filter, **filter_conf)
    dataset = FilterIterDataPipe(dataset, fn=filter_fn)
    if "text" in data_names:
        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer}
        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer, 'hw_config': hw_config}
        tokenize_fn = partial(tokenize, **vocab)
        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)