From 9befa9e508d5ca95cb5faa29cd20d23e04e525c9 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: Mon, 06 Feb 2023 16:42:33 +0800
Subject: [PATCH] data2vec pretrain: add clipping batch mode
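
The default "padding" collate pads every sample in a token batch up to
the longest sample. For data2vec pretraining cropping is preferable to
padding (fairseq's audio pretraining crops to the batch-min size in the
same way), so this patch adds a "clipping" batch_mode that instead
crops every sample in a batch down to the shortest one. Accordingly,
the token batcher sizes a batch by min_length * n_samples rather than
max_length * n_samples. The mode is selected with
dataset_conf.batch_mode (default "padding"); see the aishell2 config
below.

A minimal sketch of the collate behaviour (illustration only:
crop_to_max_size is assumed to take a random window of the target
length along the time axis, like the fairseq helper of the same name):

    import torch

    def crop_to_max_size(t, target_size):
        # assumed semantics: random crop along the time axis
        start = torch.randint(0, t.shape[0] - target_size + 1, (1,)).item()
        return t[start:start + target_size]

    feats = [torch.randn(n, 80) for n in (420, 395, 388)]  # fbank-like
    clip = min(f.shape[0] for f in feats)                  # 388
    batch = torch.stack(
        [f if f.shape[0] == clip else crop_to_max_size(f, clip)
         for f in feats])
    print(batch.shape)  # torch.Size([3, 388, 80])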
---
funasr/datasets/large_datasets/dataset.py | 9 +
egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml | 14 ++
funasr/datasets/large_datasets/utils/clipping.py | 44 ++++++++
funasr/datasets/large_datasets/build_dataloader.py | 5
funasr/datasets/large_datasets/datapipes/batch.py | 219 ++++++++++++++++++++++++++++---------------
5 files changed, 209 insertions(+), 82 deletions(-)
diff --git a/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
index d7ddce6..4052774 100644
--- a/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
+++ b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
@@ -63,3 +63,17 @@
scheduler: tri_stage
scheduler_conf:
phase_ratio: [0.03,0.9,0.07]
+
+# for dataset
+dataset_conf:
+ batch_mode: clipping
+ data_names: speech,none
+ data_types: kaldi_ark,none
+ shuffle: true
+ shuffle_conf:
+ shuffle_size: 12800
+ sort_size: 12800
+ batch_conf:
+ batch_type: token
+ batch_size: 64000
+ num_workers: 8
\ No newline at end of file
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 146723d..8f7fd0b 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -35,15 +35,16 @@
class ArkDataLoader(AbsIterFactory):
def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"):
- symbol_table = read_symbol_table(dict_file)
+ symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
if seg_dict_file is not None:
seg_dict = load_seg_dict(seg_dict_file)
else:
seg_dict = None
self.dataset_conf = dataset_conf
logging.info("dataloader config: {}".format(self.dataset_conf))
+ batch_mode = self.dataset_conf.get("batch_mode", "padding")
self.dataset = Dataset(data_list, symbol_table, seg_dict,
- self.dataset_conf, mode=mode)
+ self.dataset_conf, mode=mode, batch_mode=batch_mode)
def build_iter(self, epoch, shuffle=True):
self.dataset.set_epoch(epoch)
diff --git a/funasr/datasets/large_datasets/datapipes/batch.py b/funasr/datasets/large_datasets/datapipes/batch.py
index 9c85d5e..c980ae3 100644
--- a/funasr/datasets/large_datasets/datapipes/batch.py
+++ b/funasr/datasets/large_datasets/datapipes/batch.py
@@ -24,7 +24,8 @@
batch_size=8000,
len_fn=_default_len_fn,
buffer_size=10240,
- sort_size=500
+ sort_size=500,
+ batch_mode="padding",
):
assert batch_size > 0, "Batch size is required to be larger than 0!"
assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
@@ -35,6 +36,7 @@
self.batch_size = batch_size
self.buffer_size = buffer_size
self.sort_size = sort_size
+ self.batch_mode = batch_mode
def set_epoch(self, epoch):
self.epoch = epoch
@@ -46,53 +48,135 @@
max_lengths = 0
batch_lengths = 0
+ min_lengths = float("inf")  # start above any real length; tracked by the clipping mode below
- if self.buffer_size == -1:
- for d in self.datapipe:
- if d[0] > self.batch_size:
- continue
- buffer.append(d)
- buffer.sort()
- for sample in buffer:
- length, _, token = sample
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- bucket.append(batch)
- batch = []
- max_lengths = length
- batch.append(token)
- random.shuffle(bucket)
- if bucket:
- for batch_sample in bucket:
- yield batch_sample
- if batch:
- yield batch
-
- elif self.buffer_size == 0:
- for d in self.datapipe:
- if d[0] > self.batch_size:
- continue
- length, _, token = d
- if length > self.batch_size:
- continue
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- yield batch
- batch = []
- max_lengths = length
- batch.append(token)
- if batch:
- yield batch
-
- else:
+ if self.batch_mode == "clipping":
+ assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 0"
for d in self.datapipe:
if d[0] > self.batch_size:
continue
buffer.append(d)
if len(buffer) == self.buffer_size:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if buffer:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if bucket:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+
+ if batch:
+ yield batch
+
+ else:
+ if self.buffer_size == -1:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ buffer.sort()
+ for sample in buffer:
+ length, _, token = sample
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ bucket.append(batch)
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ random.shuffle(bucket)
+ if bucket:
+ for batch_sample in bucket:
+ yield batch_sample
+ if batch:
+ yield batch
+
+ elif self.buffer_size == 0:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ length, _, token = d
+ if length > self.batch_size:
+ continue
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ if batch:
+ yield batch
+
+ else:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ if len(buffer) == self.buffer_size:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if buffer:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
@@ -111,38 +195,19 @@
bucket = []
buffer = []
- if buffer:
- random.shuffle(buffer)
- for sample in buffer:
- bucket.append(sample)
- if len(bucket) == self.sort_size:
- bucket.sort()
- for x in bucket:
- length, _, token = x
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- yield batch
- batch = []
- max_lengths = length
- batch.append(token)
- bucket = []
- buffer = []
+ if bucket:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
- if bucket:
- bucket.sort()
- for x in bucket:
- length, _, token = x
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- yield batch
- batch = []
- max_lengths = length
- batch.append(token)
- bucket = []
-
- if batch:
- yield batch
+ if batch:
+ yield batch
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 41d34ab..81c1361 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -13,6 +13,7 @@
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
+from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.tokenize import tokenize
@@ -143,7 +144,8 @@
dict,
seg_dict,
conf,
- mode="train"):
+ mode="train",
+ batch_mode="padding"):
scp_lists = read_lists(data_list_file)
shuffle = conf.get('shuffle', True)
data_names = conf.get("data_names", "speech,text")
@@ -180,8 +182,9 @@
batch_size=batch_size,
len_fn=len_fn,
buffer_size=buffer_size,
- sort_size=sort_size)
+ sort_size=sort_size,
+ batch_mode=batch_mode)
- dataset = MapperIterDataPipe(dataset, fn=padding)
+ dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping)
return dataset
diff --git a/funasr/datasets/large_datasets/utils/clipping.py b/funasr/datasets/large_datasets/utils/clipping.py
new file mode 100644
index 0000000..f5c2940
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/clipping.py
@@ -0,0 +1,44 @@
+import numpy as np
+import torch
+
+from funasr.datasets.collate_fn import crop_to_max_size
+
+
+def clipping(data):
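+ """Collate a batch by cropping every sample to the batch-min length.
+
+ Counterpart of `padding`: longer samples are cropped via crop_to_max_size
+ instead of shorter ones being zero-padded, so no pad frames are produced."""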
+ assert isinstance(data, list)
+ assert "key" in data[0]
+
+ keys = [x["key"] for x in data]
+
+ batch = {}
+ data_names = data[0].keys()
+ for data_name in data_names:
+ if data_name == "key":
+ continue
+ else:
+ if data[0][data_name].dtype.kind == "i":
+ tensor_type = torch.int64
+ else:
+ tensor_type = torch.float32
+
+ tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
+ tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
+
+ length_clip = min(tensor_lengths)
+ tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
+ for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
+ diff = length - length_clip
+ assert diff >= 0
+ if diff == 0:
+ tensor_clip[i] = tensor
+ else:
+ tensor_clip[i] = crop_to_max_size(tensor, length_clip)
+
+ batch[data_name] = tensor_clip
+ batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
+
+ return keys, batch
--
Gitblit v1.9.1