From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Mar 2024 16:05:29 +0800
Subject: [PATCH] train update

---
 funasr/datasets/large_datasets/datapipes/batch.py | 213 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 213 insertions(+), 0 deletions(-)

diff --git a/funasr/datasets/large_datasets/datapipes/batch.py b/funasr/datasets/large_datasets/datapipes/batch.py
new file mode 100644
index 0000000..35e5dba
--- /dev/null
+++ b/funasr/datasets/large_datasets/datapipes/batch.py
@@ -0,0 +1,213 @@
+import random
+
+from itertools import count
+from functools import partial
+from torch.utils.data import IterableDataset
+from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
+
+tiebreaker = count()
+
+
+def _default_len_fn(token):
+    return len(token), next(tiebreaker)
+
+
+def _token_len_fn(token, len_fn):
+    return len_fn(token), next(tiebreaker), token
+
+
+class MaxTokenBucketizerIterDataPipe(IterableDataset):
+
+    def __init__(
+        self,
+        datapipe,
+        batch_size=8000,
+        len_fn=_default_len_fn,
+        buffer_size=10240,
+        sort_size=500,
+        batch_mode="padding",
+    ):
+        assert batch_size > 0, "Batch size is required to be larger than 0!"
+        assert buffer_size >= -1, "Buffer size is required to be at least -1!"
+        assert sort_size > 0, "Sort size is required to be larger than 0!"
+
+        datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
+        self.datapipe = datapipe
+        self.batch_size = batch_size
+        self.buffer_size = buffer_size
+        self.sort_size = sort_size
+        self.batch_mode = batch_mode
+
+    def set_epoch(self, epoch):
+        self.datapipe.set_epoch(epoch)
+
+    def __iter__(self):
+        buffer = []
+        batch = []
+        bucket = []
+        max_lengths = 0
+        min_lengths = 999999
+        batch_lengths = 0
+
+        if self.batch_mode == "clipping":
+            assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 0"
+            for d in self.datapipe:
+                if d[0] > self.batch_size:
+                    continue
+                buffer.append(d)
+                if len(buffer) == self.buffer_size:
+                    random.shuffle(buffer)
+                    for sample in buffer:
+                        bucket.append(sample)
+                        if len(bucket) == self.sort_size:
+                            bucket.sort()
+                            for x in bucket:
+                                length, _, token = x
+                                if length < min_lengths:
+                                    min_lengths = length
+                                batch_lengths = min_lengths * (len(batch) + 1)
+                                if batch_lengths > self.batch_size:
+                                    yield batch
+                                    batch = []
+                                    min_lengths = length
+                                batch.append(token)
+                            bucket = []
+                    buffer = []
+
+            if buffer:
+                random.shuffle(buffer)
+                for sample in buffer:
+                    bucket.append(sample)
+                    if len(bucket) == self.sort_size:
+                        bucket.sort()
+                        for x in bucket:
+                            length, _, token = x
+                            if length < min_lengths:
+                                min_lengths = length
+                            batch_lengths = min_lengths * (len(batch) + 1)
+                            if batch_lengths > self.batch_size:
+                                yield batch
+                                batch = []
+                                min_lengths = length
+                            batch.append(token)
+                        bucket = []
+                buffer = []
+
+            if bucket:
+                bucket.sort()
+                for x in bucket:
+                    length, _, token = x
+                    if length < min_lengths:
+                        min_lengths = length
+                    batch_lengths = min_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        yield batch
+                        batch = []
+                        min_lengths = length
+                    batch.append(token)
+                bucket = []
+
+            if batch:
+                yield batch
+
+        else:
+            if self.buffer_size == -1:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    buffer.append(d)
+                buffer.sort()
+                for sample in buffer:
+                    length, _, token = sample
+                    if length > max_lengths:
+                        max_lengths = length
+                    batch_lengths = max_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        bucket.append(batch)
+                        batch = []
+                        max_lengths = length
+                    batch.append(token)
+                random.shuffle(bucket)
+                if bucket:
+                    for batch_sample in bucket:
+                        yield batch_sample
+                if batch:
+                    yield batch
+
+            elif self.buffer_size == 0:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    length, _, token = d
+                    if length > self.batch_size:
+                        continue
+                    if length > max_lengths:
+                        max_lengths = length
+                    batch_lengths = max_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        yield batch
+                        batch = []
+                        max_lengths = length
+                    batch.append(token)
+                if batch:
+                    yield batch
+
+            else:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    buffer.append(d)
+                    if len(buffer) == self.buffer_size:
+                        random.shuffle(buffer)
+                        for sample in buffer:
+                            bucket.append(sample)
+                            if len(bucket) == self.sort_size:
+                                bucket.sort()
+                                for x in bucket:
+                                    length, _, token = x
+                                    if length > max_lengths:
+                                        max_lengths = length
+                                    batch_lengths = max_lengths * (len(batch) + 1)
+                                    if batch_lengths > self.batch_size:
+                                        yield batch
+                                        batch = []
+                                        max_lengths = length
+                                    batch.append(token)
+                                bucket = []
+                        buffer = []
+
+                if buffer:
+                    random.shuffle(buffer)
+                    for sample in buffer:
+                        bucket.append(sample)
+                        if len(bucket) == self.sort_size:
+                            bucket.sort()
+                            for x in bucket:
+                                length, _, token = x
+                                if length > max_lengths:
+                                    max_lengths = length
+                                batch_lengths = max_lengths * (len(batch) + 1)
+                                if batch_lengths > self.batch_size:
+                                    yield batch
+                                    batch = []
+                                    max_lengths = length
+                                batch.append(token)
+                            bucket = []
+                    buffer = []
+
+                if bucket:
+                    bucket.sort()
+                    for x in bucket:
+                        length, _, token = x
+                        if length > max_lengths:
+                            max_lengths = length
+                        batch_lengths = max_lengths * (len(batch) + 1)
+                        if batch_lengths > self.batch_size:
+                            yield batch
+                            batch = []
+                            max_lengths = length
+                        batch.append(token)
+                    bucket = []
+
+                if batch:
+                    yield batch
--
Gitblit v1.9.1
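
For reference, a minimal usage sketch of the datapipe added by this patch (not part of the patch itself). It assumes the upstream pipe can be a plain iterable of token sequences that `MapperIterDataPipe` will accept, and it passes `len_fn=len` so each element's cost is a scalar token count; the sample data and variable names below are hypothetical.

```python
# Hypothetical usage sketch; whether a plain Python list is accepted as the
# upstream pipe depends on MapperIterDataPipe, which is outside this patch.
from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe

# Toy "utterances": lists of token ids with varying lengths.
dummy_tokens = [[i] * length for i, length in enumerate([12, 7, 25, 3, 18, 9] * 50)]

bucketizer = MaxTokenBucketizerIterDataPipe(
    dummy_tokens,          # upstream datapipe (here: a plain list, for illustration)
    batch_size=64,         # cap on padded tokens per batch: longest length * batch count
    len_fn=len,            # scalar length per sample for the toy data above
    buffer_size=100,       # shuffle window collected before sorting
    sort_size=50,          # sort window used to group samples of similar length
    batch_mode="padding",
)

for batch in bucketizer:
    # Each batch is a list of token sequences whose padded size
    # (longest sequence * number of sequences) stays within batch_size.
    print(len(batch), max(len(x) for x in batch))
```

Under the "padding" mode sketched here, batches are sized by the longest sequence they contain, while the "clipping" mode sizes them by the shortest; samples longer than `batch_size` are silently skipped in both modes.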