From 9befa9e508d5ca95cb5faa29cd20d23e04e525c9 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 06 二月 2023 16:42:33 +0800
Subject: [PATCH] update data2vec pretrain: add clipping

---
 funasr/datasets/large_datasets/datapipes/batch.py |  218 +++++++++++++++++++++++++++++++++++-------------------
 1 files changed, 141 insertions(+), 77 deletions(-)

diff --git a/funasr/datasets/large_datasets/datapipes/batch.py b/funasr/datasets/large_datasets/datapipes/batch.py
index 9c85d5e..c980ae3 100644
--- a/funasr/datasets/large_datasets/datapipes/batch.py
+++ b/funasr/datasets/large_datasets/datapipes/batch.py
@@ -24,7 +24,8 @@
             batch_size=8000,
             len_fn=_default_len_fn,
             buffer_size=10240,
-            sort_size=500
+            sort_size=500,
+            batch_mode="padding",
     ):
         assert batch_size > 0, "Batch size is required to be larger than 0!"
         assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
@@ -35,6 +36,7 @@
         self.batch_size = batch_size
         self.buffer_size = buffer_size
         self.sort_size = sort_size
+        self.batch_mode = batch_mode
 
     def set_epoch(self, epoch):
         self.epoch = epoch
@@ -46,53 +48,134 @@
         max_lengths = 0
         batch_lengths = 0
 
-        if self.buffer_size == -1:
-            for d in self.datapipe:
-                if d[0] > self.batch_size:
-                    continue
-                buffer.append(d)
-            buffer.sort()
-            for sample in buffer:
-                length, _, token = sample
-                if length > max_lengths:
-                    max_lengths = length
-                batch_lengths = max_lengths * (len(batch) + 1)
-                if batch_lengths > self.batch_size:
-                    bucket.append(batch)
-                    batch = []
-                    max_lengths = length
-                batch.append(token)
-            random.shuffle(bucket)
-            if bucket:
-                for batch_sample in bucket:
-                    yield batch_sample
-            if batch:
-                yield batch
-
-        elif self.buffer_size == 0:
-            for d in self.datapipe:
-                if d[0] > self.batch_size:
-                    continue
-                length, _, token = d
-                if length > self.batch_size:
-                    continue
-                if length > max_lengths:
-                    max_lengths = length
-                batch_lengths = max_lengths * (len(batch) + 1)
-                if batch_lengths > self.batch_size:
-                    yield batch
-                    batch = []
-                    max_lengths = length
-                batch.append(token)
-            if batch:
-                yield batch
-
-        else:
+        if self.batch_mode == "clipping":
+            assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
             for d in self.datapipe:
                 if d[0] > self.batch_size:
                     continue
                 buffer.append(d)
                 if len(buffer) == self.buffer_size:
+                    random.shuffle(buffer)
+                    for sample in buffer:
+                        bucket.append(sample)
+                        if len(bucket) == self.sort_size:
+                            bucket.sort()
+                            for x in bucket:
+                                length, _, token = x
+                                if length < min_lengths:
+                                    min_lengths = length
+                                batch_lengths = min_lengths * (len(batch) + 1)
+                                if batch_lengths > self.batch_size:
+                                    yield batch
+                                    batch = []
+                                    min_lengths = length
+                                batch.append(token)
+                            bucket = []
+                    buffer = []
+
+            if buffer:
+                random.shuffle(buffer)
+                for sample in buffer:
+                    bucket.append(sample)
+                    if len(bucket) == self.sort_size:
+                        bucket.sort()
+                        for x in bucket:
+                            length, _, token = x
+                            if length < min_lengths:
+                                min_lengths = length
+                            batch_lengths = min_lengths * (len(batch) + 1)
+                            if batch_lengths > self.batch_size:
+                                yield batch
+                                batch = []
+                                min_lengths = length
+                            batch.append(token)
+                        bucket = []
+                buffer = []
+
+            if bucket:
+                bucket.sort()
+                for x in bucket:
+                    length, _, token = x
+                    if length < min_lengths:
+                        min_lengths = length
+                    batch_lengths = min_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        yield batch
+                        batch = []
+                        min_lengths = length
+                    batch.append(token)
+                bucket = []
+
+            if batch:
+                yield batch
+
+        else:
+            if self.buffer_size == -1:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    buffer.append(d)
+                buffer.sort()
+                for sample in buffer:
+                    length, _, token = sample
+                    if length > max_lengths:
+                        max_lengths = length
+                    batch_lengths = max_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        bucket.append(batch)
+                        batch = []
+                        max_lengths = length
+                    batch.append(token)
+                random.shuffle(bucket)
+                if bucket:
+                    for batch_sample in bucket:
+                        yield batch_sample
+                if batch:
+                    yield batch
+
+            elif self.buffer_size == 0:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    length, _, token = d
+                    if length > self.batch_size:
+                        continue
+                    if length > max_lengths:
+                        max_lengths = length
+                    batch_lengths = max_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        yield batch
+                        batch = []
+                        max_lengths = length
+                    batch.append(token)
+                if batch:
+                    yield batch
+
+            else:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    buffer.append(d)
+                    if len(buffer) == self.buffer_size:
+                        random.shuffle(buffer)
+                        for sample in buffer:
+                            bucket.append(sample)
+                            if len(bucket) == self.sort_size:
+                                bucket.sort()
+                                for x in bucket:
+                                    length, _, token = x
+                                    if length > max_lengths:
+                                        max_lengths = length
+                                    batch_lengths = max_lengths * (len(batch) + 1)
+                                    if batch_lengths > self.batch_size:
+                                        yield batch
+                                        batch = []
+                                        max_lengths = length
+                                    batch.append(token)
+                                bucket = []
+                        buffer = []
+
+                if buffer:
                     random.shuffle(buffer)
                     for sample in buffer:
                         bucket.append(sample)
@@ -111,38 +194,19 @@
                             bucket = []
                     buffer = []
 
-            if buffer:
-                random.shuffle(buffer)
-                for sample in buffer:
-                    bucket.append(sample)
-                    if len(bucket) == self.sort_size:
-                        bucket.sort()
-                        for x in bucket:
-                            length, _, token = x
-                            if length > max_lengths:
-                                max_lengths = length
-                            batch_lengths = max_lengths * (len(batch) + 1)
-                            if batch_lengths > self.batch_size:
-                                yield batch
-                                batch = []
-                                max_lengths = length
-                            batch.append(token)
-                        bucket = []
-                buffer = []
+                if bucket:
+                    bucket.sort()
+                    for x in bucket:
+                        length, _, token = x
+                        if length > max_lengths:
+                            max_lengths = length
+                        batch_lengths = max_lengths * (len(batch) + 1)
+                        if batch_lengths > self.batch_size:
+                            yield batch
+                            batch = []
+                            max_lengths = length
+                        batch.append(token)
+                    bucket = []
 
-            if bucket:
-                bucket.sort()
-                for x in bucket:
-                    length, _, token = x
-                    if length > max_lengths:
-                        max_lengths = length
-                    batch_lengths = max_lengths * (len(batch) + 1)
-                    if batch_lengths > self.batch_size:
-                        yield batch
-                        batch = []
-                        max_lengths = length
-                    batch.append(token)
-                bucket = []
-
-            if batch:
-                yield batch
+                if batch:
+                    yield batch

--
Gitblit v1.9.1