jmwang66
2023-02-06 9befa9e508d5ca95cb5faa29cd20d23e04e525c9
funasr/datasets/large_datasets/datapipes/batch.py
@@ -24,7 +24,8 @@
            batch_size=8000,
            len_fn=_default_len_fn,
            buffer_size=10240,
            sort_size=500
            sort_size=500,
            batch_mode="padding",
    ):
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
@@ -35,6 +36,7 @@
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.sort_size = sort_size
        self.batch_mode = batch_mode
    def set_epoch(self, epoch):
        """Record ``epoch`` on this instance as ``self.epoch``.

        Args:
            epoch: Value to store; it is assigned as-is, with no validation
                or conversion.
        """
        # NOTE(review): presumably read elsewhere to vary shuffling per epoch;
        # no consumer of self.epoch is visible in this excerpt -- confirm.
        self.epoch = epoch
@@ -46,6 +48,68 @@
        max_lengths = 0
        batch_lengths = 0
        if self.batch_mode == "clipping":
            assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
            for d in self.datapipe:
                if d[0] > self.batch_size:
                    continue
                buffer.append(d)
                if len(buffer) == self.buffer_size:
                    random.shuffle(buffer)
                    for sample in buffer:
                        bucket.append(sample)
                        if len(bucket) == self.sort_size:
                            bucket.sort()
                            for x in bucket:
                                length, _, token = x
                                if length < min_lengths:
                                    min_lengths = length
                                batch_lengths = min_lengths * (len(batch) + 1)
                                if batch_lengths > self.batch_size:
                                    yield batch
                                    batch = []
                                    min_lengths = length
                                batch.append(token)
                            bucket = []
                    buffer = []
            if buffer:
                random.shuffle(buffer)
                for sample in buffer:
                    bucket.append(sample)
                    if len(bucket) == self.sort_size:
                        bucket.sort()
                        for x in bucket:
                            length, _, token = x
                            if length < min_lengths:
                                min_lengths = length
                            batch_lengths = min_lengths * (len(batch) + 1)
                            if batch_lengths > self.batch_size:
                                yield batch
                                batch = []
                                min_lengths = length
                            batch.append(token)
                        bucket = []
                buffer = []
            if bucket:
                bucket.sort()
                for x in bucket:
                    length, _, token = x
                    if length < min_lengths:
                        min_lengths = length
                    batch_lengths = min_lengths * (len(batch) + 1)
                    if batch_lengths > self.batch_size:
                        yield batch
                        batch = []
                        min_lengths = length
                    batch.append(token)
                bucket = []
            if batch:
                yield batch
        else:
        if self.buffer_size == -1:
            for d in self.datapipe:
                if d[0] > self.batch_size: