From 9befa9e508d5ca95cb5faa29cd20d23e04e525c9 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 06 二月 2023 16:42:33 +0800
Subject: [PATCH] update data2vec pretrain: add clipping
---
funasr/datasets/large_datasets/datapipes/batch.py | 218 +++++++++++++++++++++++++++++++++++-------------------
1 files changed, 141 insertions(+), 77 deletions(-)
diff --git a/funasr/datasets/large_datasets/datapipes/batch.py b/funasr/datasets/large_datasets/datapipes/batch.py
index 9c85d5e..c980ae3 100644
--- a/funasr/datasets/large_datasets/datapipes/batch.py
+++ b/funasr/datasets/large_datasets/datapipes/batch.py
@@ -24,7 +24,8 @@
batch_size=8000,
len_fn=_default_len_fn,
buffer_size=10240,
- sort_size=500
+ sort_size=500,
+ batch_mode="padding",
):
assert batch_size > 0, "Batch size is required to be larger than 0!"
assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
@@ -35,6 +36,7 @@
self.batch_size = batch_size
self.buffer_size = buffer_size
self.sort_size = sort_size
+ self.batch_mode = batch_mode
def set_epoch(self, epoch):
self.epoch = epoch
@@ -46,53 +48,134 @@
max_lengths = 0
batch_lengths = 0
- if self.buffer_size == -1:
- for d in self.datapipe:
- if d[0] > self.batch_size:
- continue
- buffer.append(d)
- buffer.sort()
- for sample in buffer:
- length, _, token = sample
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- bucket.append(batch)
- batch = []
- max_lengths = length
- batch.append(token)
- random.shuffle(bucket)
- if bucket:
- for batch_sample in bucket:
- yield batch_sample
- if batch:
- yield batch
-
- elif self.buffer_size == 0:
- for d in self.datapipe:
- if d[0] > self.batch_size:
- continue
- length, _, token = d
- if length > self.batch_size:
- continue
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- yield batch
- batch = []
- max_lengths = length
- batch.append(token)
- if batch:
- yield batch
-
- else:
+ if self.batch_mode == "clipping":
+ assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
for d in self.datapipe:
if d[0] > self.batch_size:
continue
buffer.append(d)
if len(buffer) == self.buffer_size:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if buffer:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if bucket:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+
+ if batch:
+ yield batch
+
+ else:
+ if self.buffer_size == -1:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ buffer.sort()
+ for sample in buffer:
+ length, _, token = sample
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ bucket.append(batch)
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ random.shuffle(bucket)
+ if bucket:
+ for batch_sample in bucket:
+ yield batch_sample
+ if batch:
+ yield batch
+
+ elif self.buffer_size == 0:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ length, _, token = d
+ if length > self.batch_size:
+ continue
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ if batch:
+ yield batch
+
+ else:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ if len(buffer) == self.buffer_size:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if buffer:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
@@ -111,38 +194,19 @@
bucket = []
buffer = []
- if buffer:
- random.shuffle(buffer)
- for sample in buffer:
- bucket.append(sample)
- if len(bucket) == self.sort_size:
- bucket.sort()
- for x in bucket:
- length, _, token = x
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- yield batch
- batch = []
- max_lengths = length
- batch.append(token)
- bucket = []
- buffer = []
+ if bucket:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
- if bucket:
- bucket.sort()
- for x in bucket:
- length, _, token = x
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- yield batch
- batch = []
- max_lengths = length
- batch.append(token)
- bucket = []
-
- if batch:
- yield batch
+ if batch:
+ yield batch
--
Gitblit v1.9.1