From c880db53646ab9fd26417f4baf004ab44cc24e1a Mon Sep 17 00:00:00 2001
From: lingji-yidong <75744976+lingji-yidong@users.noreply.github.com>
Date: 星期五, 28 六月 2024 01:28:24 +0800
Subject: [PATCH] Fix: Return tuple ('', []) when char_list is empty to prevent ValueError (#1857)
---
funasr/datasets/large_datasets/datapipes/batch.py | 213 +++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 213 insertions(+), 0 deletions(-)
diff --git a/funasr/datasets/large_datasets/datapipes/batch.py b/funasr/datasets/large_datasets/datapipes/batch.py
new file mode 100644
index 0000000..aeeb451
--- /dev/null
+++ b/funasr/datasets/large_datasets/datapipes/batch.py
@@ -0,0 +1,213 @@
+import random
+
+from itertools import count
+from functools import partial
+from torch.utils.data import IterableDataset
+from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
+
+tiebreaker = count()
+
+
+def _default_len_fn(token):
+ return len(token), next(tiebreaker)
+
+
+def _token_len_fn(token, len_fn):
+ return len_fn(token), next(tiebreaker), token
+
+
+class MaxTokenBucketizerIterDataPipe(IterableDataset):
+
+ def __init__(
+ self,
+ datapipe,
+ batch_size=8000,
+ len_fn=_default_len_fn,
+ buffer_size=10240,
+ sort_size=500,
+ batch_mode="padding",
+ ):
+ assert batch_size > 0, "Batch size is required to be larger than 0!"
+ assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
+ assert sort_size > 0, "Sort size is required to be larger than 0!"
+
+ datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
+ self.datapipe = datapipe
+ self.batch_size = batch_size
+ self.buffer_size = buffer_size
+ self.sort_size = sort_size
+ self.batch_mode = batch_mode
+
+ def set_epoch(self, epoch):
+ self.datapipe.set_epoch(epoch)
+
+ def __iter__(self):
+ buffer = []
+ batch = []
+ bucket = []
+ max_lengths = 0
+ min_lengths = 999999
+ batch_lengths = 0
+
+ if self.batch_mode == "clipping":
+ assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ if len(buffer) == self.buffer_size:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if buffer:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if bucket:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+
+ if batch:
+ yield batch
+
+ else:
+ if self.buffer_size == -1:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ buffer.sort()
+ for sample in buffer:
+ length, _, token = sample
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ bucket.append(batch)
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ random.shuffle(bucket)
+ if bucket:
+ for batch_sample in bucket:
+ yield batch_sample
+ if batch:
+ yield batch
+
+ elif self.buffer_size == 0:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ length, _, token = d
+ if length > self.batch_size:
+ continue
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ if batch:
+ yield batch
+
+ else:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ if len(buffer) == self.buffer_size:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if buffer:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if bucket:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
+
+ if batch:
+ yield batch
--
Gitblit v1.9.1