update data2vec pretrain: add clipping
| | |
| | | scheduler: tri_stage |
| | | scheduler_conf: |
| | | phase_ratio: [0.03,0.9,0.07] |
| | | |
| | | # for dataset |
| | | dataset_conf: |
| | | batch_mode: clipping |
| | | data_names: speech,none |
| | | data_types: kaldi_ark,none |
| | | shuffle: true |
| | | shuffle_conf: |
| | | shuffle_size: 12800 |
| | | sort_size: 12800 |
| | | batch_conf: |
| | | batch_type: token |
| | | batch_size: 64000 |
| | | num_workers: 8 |
| | |
| | | |
| | | class ArkDataLoader(AbsIterFactory): |
def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"):
    """Build an ark-format dataset factory.

    Args:
        data_list: data-shard list handed through to ``Dataset``.
        dict_file: token dictionary file; may be ``None`` (e.g. for
            self-supervised pretraining), in which case no symbol table
            is loaded.
        dataset_conf: dict of dataset options; ``batch_mode`` selects
            "padding" (default) or "clipping" batching.
        seg_dict_file: optional segmentation dictionary file.
        mode: dataset mode string, forwarded to ``Dataset``.
    """
    # NOTE(review): the diff kept both the unconditional and the
    # None-guarded read_symbol_table lines; only the guarded form is valid.
    symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
    seg_dict = load_seg_dict(seg_dict_file) if seg_dict_file is not None else None
    self.dataset_conf = dataset_conf
    logging.info("dataloader config: {}".format(self.dataset_conf))
    batch_mode = self.dataset_conf.get("batch_mode", "padding")
    self.dataset = Dataset(data_list, symbol_table, seg_dict,
                           self.dataset_conf, mode=mode, batch_mode=batch_mode)
| | | |
| | | def build_iter(self, epoch, shuffle=True): |
| | | self.dataset.set_epoch(epoch) |
| | |
| | | batch_size=8000, |
| | | len_fn=_default_len_fn, |
| | | buffer_size=10240, |
| | | sort_size=500 |
| | | sort_size=500, |
| | | batch_mode="padding", |
| | | ): |
| | | assert batch_size > 0, "Batch size is required to be larger than 0!" |
| | | assert buffer_size >= -1, "Buffer size is required to be larger than -1!" |
| | |
| | | self.batch_size = batch_size |
| | | self.buffer_size = buffer_size |
| | | self.sort_size = sort_size |
| | | self.batch_mode = batch_mode |
| | | |
def set_epoch(self, epoch):
    """Record the active epoch index on this sampler."""
    self.epoch = epoch
| | |
| | | max_lengths = 0 |
| | | batch_lengths = 0 |
| | | |
| | | if self.batch_mode == "clipping": |
| | | assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1" |
| | | for d in self.datapipe: |
| | | if d[0] > self.batch_size: |
| | | continue |
| | | buffer.append(d) |
| | | if len(buffer) == self.buffer_size: |
| | | random.shuffle(buffer) |
| | | for sample in buffer: |
| | | bucket.append(sample) |
| | | if len(bucket) == self.sort_size: |
| | | bucket.sort() |
| | | for x in bucket: |
| | | length, _, token = x |
| | | if length < min_lengths: |
| | | min_lengths = length |
| | | batch_lengths = min_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | min_lengths = length |
| | | batch.append(token) |
| | | bucket = [] |
| | | buffer = [] |
| | | |
| | | if buffer: |
| | | random.shuffle(buffer) |
| | | for sample in buffer: |
| | | bucket.append(sample) |
| | | if len(bucket) == self.sort_size: |
| | | bucket.sort() |
| | | for x in bucket: |
| | | length, _, token = x |
| | | if length < min_lengths: |
| | | min_lengths = length |
| | | batch_lengths = min_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | min_lengths = length |
| | | batch.append(token) |
| | | bucket = [] |
| | | buffer = [] |
| | | |
| | | if bucket: |
| | | bucket.sort() |
| | | for x in bucket: |
| | | length, _, token = x |
| | | if length < min_lengths: |
| | | min_lengths = length |
| | | batch_lengths = min_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | min_lengths = length |
| | | batch.append(token) |
| | | bucket = [] |
| | | |
| | | if batch: |
| | | yield batch |
| | | |
| | | else: |
| | | if self.buffer_size == -1: |
| | | for d in self.datapipe: |
| | | if d[0] > self.batch_size: |
| | |
| | | from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe |
| | | from funasr.datasets.large_datasets.utils.filter import filter |
| | | from funasr.datasets.large_datasets.utils.padding import padding |
| | | from funasr.datasets.large_datasets.utils.clipping import clipping |
| | | from funasr.datasets.large_datasets.utils.tokenize import tokenize |
| | | |
| | | |
| | |
| | | dict, |
| | | seg_dict, |
| | | conf, |
| | | mode="train"): |
| | | mode="train", |
| | | batch_mode="padding"): |
| | | scp_lists = read_lists(data_list_file) |
| | | shuffle = conf.get('shuffle', True) |
| | | data_names = conf.get("data_names", "speech,text") |
| | |
| | | batch_size=batch_size, |
| | | len_fn=len_fn, |
| | | buffer_size=buffer_size, |
| | | sort_size=sort_size) |
| | | sort_size=sort_size, |
| | | batch_mode=batch_mode) |
| | | |
| | | dataset = MapperIterDataPipe(dataset, fn=padding) |
| | | dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping) |
| | | |
| | | return dataset |
| New file |
| | |
| | | import numpy as np |
| | | import torch |
| | | |
| | | from funasr.datasets.collate_fn import crop_to_max_size |
| | | |
| | | |
def clipping(data):
    """Collate a list of samples into a dense batch by clipping to the shortest length.

    Unlike padding, every per-sample array is cropped (via ``crop_to_max_size``)
    down to the minimum sequence length in the batch, so the result contains no
    padding frames.

    Args:
        data: list of dicts, each holding a "key" entry plus one or more named
            numpy arrays. Each array is assumed to be 2-D ``(length, feat_dim)``
            with a common feat_dim across the batch — TODO confirm with callers.

    Returns:
        Tuple ``(keys, batch)``: ``keys`` is the list of sample keys; ``batch``
        maps each data name to a ``(batch, min_len, feat_dim)`` tensor and
        ``<name>_lengths`` to the (uniform) clipped lengths.
    """
    assert isinstance(data, list)
    assert "key" in data[0]

    keys = [x["key"] for x in data]

    batch = {}
    for data_name in data[0].keys():
        if data_name == "key":
            continue
        # Integer inputs become int64 tensors; everything else float32.
        if data[0][data_name].dtype.kind == "i":
            tensor_type = torch.int64
        else:
            tensor_type = torch.float32

        tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
        # Plain Python ints: min() over a torch tensor would yield a 0-dim
        # tensor, which then leaks into new_zeros() as a size argument.
        tensor_lengths = [len(d[data_name]) for d in data]

        # Every sample is cropped to the shortest sequence in the batch.
        length_clip = min(tensor_lengths)
        tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
        for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
            diff = length - length_clip
            assert diff >= 0
            if diff == 0:
                tensor_clip[i] = tensor
            else:
                # Crop the longer sample down to length_clip frames
                # (window choice delegated to crop_to_max_size).
                tensor_clip[i] = crop_to_max_size(tensor, length_clip)

        batch[data_name] = tensor_clip
        # After clipping, all samples share the same length; no need to
        # re-read each tensor's shape.
        batch[data_name + "_lengths"] = torch.full((len(tensor_list),), length_clip, dtype=torch.long)

    return keys, batch