From da9ac240cb3298248acb3262ed96b87fa3c1fa56 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 18 Apr 2023 17:52:50 +0800
Subject: [PATCH] update
---
funasr/datasets/small_datasets/length_batch_sampler.py | 147 +++++++++++++++++++++++++++++++++++++++++++++++++
funasr/datasets/small_datasets/build_loader.py | 6 -
2 files changed, 148 insertions(+), 5 deletions(-)
diff --git a/funasr/datasets/small_datasets/build_loader.py b/funasr/datasets/small_datasets/build_loader.py
index d1bc21a..a7181a4 100644
--- a/funasr/datasets/small_datasets/build_loader.py
+++ b/funasr/datasets/small_datasets/build_loader.py
@@ -6,7 +6,7 @@
from funasr.datasets.small_datasets.dataset import ESPnetDataset
from funasr.datasets.small_datasets.preprocessor import build_preprocess
-from funasr.samplers.length_batch_sampler import LengthBatchSampler
+from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
def build_dataloader(args, mode="train"):
@@ -26,10 +26,6 @@
preprocess=preprocess_fn,
dest_sample_rate=dest_sample_rate,
)
- if os.path.exists(os.path.join(data_path_and_name_and_type[0][0].parent, "utt2category")):
- utt2category_file = os.path.join(data_path_and_name_and_type[0][0].parent, "utt2category")
- else:
- utt2category_file = None
dataset_conf = args.dataset_conf
batch_sampler = LengthBatchSampler(
diff --git a/funasr/datasets/small_datasets/length_batch_sampler.py b/funasr/datasets/small_datasets/length_batch_sampler.py
new file mode 100644
index 0000000..8ee8bdc
--- /dev/null
+++ b/funasr/datasets/small_datasets/length_batch_sampler.py
@@ -0,0 +1,160 @@
+from typing import Iterator
+from typing import List
+from typing import Dict
+from typing import Tuple
+from typing import Union
+
+from typeguard import check_argument_types
+
+from funasr.fileio.read_text import load_num_sequence_text
+from funasr.samplers.abs_sampler import AbsSampler
+
+
+class LengthBatchSampler(AbsSampler):
+    """Batch sampler that packs utterances into length-bounded mini-batches.
+
+    Keys are sorted by length (first shape dimension); consecutive keys are
+    grouped into one batch until the ``batch_bins`` budget is exceeded.
+    """
+
+    def __init__(
+        self,
+        batch_bins: int,
+        shape_files: Union[Tuple[str, ...], List[str], Dict],
+        min_batch_size: int = 1,
+        sort_in_batch: str = "descending",
+        sort_batch: str = "ascending",
+        drop_last: bool = False,
+        padding: bool = True,
+    ):
+        assert check_argument_types()
+        assert batch_bins > 0
+        if sort_batch != "ascending" and sort_batch != "descending":
+            raise ValueError(
+                f"sort_batch must be ascending or descending: {sort_batch}"
+            )
+        if sort_in_batch != "descending" and sort_in_batch != "ascending":
+            raise ValueError(
+                f"sort_in_batch must be ascending or descending: {sort_in_batch}"
+            )
+
+        self.batch_bins = batch_bins
+        self.shape_files = shape_files
+        self.sort_in_batch = sort_in_batch
+        self.sort_batch = sort_batch
+        self.drop_last = drop_last
+
+        # utt2shape: (Length, ...)
+        #    uttA 100,...
+        #    uttB 201,...
+        if isinstance(shape_files, dict):
+            # An already-loaded utt->shape mapping was passed in directly.
+            utt2shapes = [shape_files]
+            shape_names = ["<in-memory shape dict>"]
+        else:
+            utt2shapes = [
+                load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
+            ]
+            shape_names = list(shape_files)
+
+        first_utt2shape = utt2shapes[0]
+        # Every shape source must describe exactly the same utterance set.
+        for s, d in zip(shape_names, utt2shapes):
+            if set(d) != set(first_utt2shape):
+                raise RuntimeError(
+                    f"keys are mismatched between {s} != {shape_names[0]}"
+                )
+
+        # Sort samples in ascending order
+        # (shape order should be like (Length, Dim))
+        keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
+        if len(keys) == 0:
+            raise RuntimeError(f"0 lines found: {shape_names[0]}")
+
+        # Decide batch-sizes
+        batch_sizes = []
+        current_batch_keys = []
+        for key in keys:
+            current_batch_keys.append(key)
+            # shape: (Length, dim1, dim2, ...)
+            if padding:
+                # bins = bs x max_length; keys are ascending, so the current
+                # key has the maximum length within the running batch.
+                bins = sum(len(current_batch_keys) * sh[key][0] for sh in utt2shapes)
+            else:
+                # bins = sum of lengths
+                bins = sum(d[k][0] for k in current_batch_keys for d in utt2shapes)
+
+            if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
+                batch_sizes.append(len(current_batch_keys))
+                current_batch_keys = []
+        else:
+            # for-else: after the loop, flush the leftover keys as the final
+            # (possibly undersized) batch unless drop_last discards it.
+            if len(current_batch_keys) != 0 and (
+                not self.drop_last or len(batch_sizes) == 0
+            ):
+                batch_sizes.append(len(current_batch_keys))
+
+        if len(batch_sizes) == 0:
+            # Maybe we can't reach here
+            raise RuntimeError("0 batches")
+
+        # If the last batch-size is smaller than minimum batch_size,
+        # the samples are redistributed to the other mini-batches
+        if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
+            for i in range(batch_sizes.pop(-1)):
+                batch_sizes[-(i % len(batch_sizes)) - 1] += 1
+
+        if not self.drop_last:
+            # Bug check
+            assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"
+
+        # Set mini-batch
+        self.batch_list = []
+        iter_bs = iter(batch_sizes)
+        bs = next(iter_bs)
+        minibatch_keys = []
+        for key in keys:
+            minibatch_keys.append(key)
+            if len(minibatch_keys) == bs:
+                if sort_in_batch == "descending":
+                    minibatch_keys.reverse()
+                elif sort_in_batch == "ascending":
+                    # Key are already sorted in ascending
+                    pass
+                else:
+                    raise ValueError(
+                        "sort_in_batch must be ascending"
+                        f" or descending: {sort_in_batch}"
+                    )
+                self.batch_list.append(tuple(minibatch_keys))
+                minibatch_keys = []
+                try:
+                    bs = next(iter_bs)
+                except StopIteration:
+                    break
+
+        if sort_batch == "ascending":
+            pass
+        elif sort_batch == "descending":
+            self.batch_list.reverse()
+        else:
+            raise ValueError(
+                f"sort_batch must be ascending or descending: {sort_batch}"
+            )
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}("
+            f"N-batch={len(self)}, "
+            f"batch_bins={self.batch_bins}, "
+            f"sort_in_batch={self.sort_in_batch}, "
+            f"sort_batch={self.sort_batch})"
+        )
+
+    def __len__(self):
+        return len(self.batch_list)
+
+    def __iter__(self) -> Iterator[Tuple[str, ...]]:
+        return iter(self.batch_list)
--
Gitblit v1.9.1