From da9ac240cb3298248acb3262ed96b87fa3c1fa56 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 18 四月 2023 17:52:50 +0800
Subject: [PATCH] update

---
 funasr/datasets/small_datasets/length_batch_sampler.py |  147 +++++++++++++++++++++++++++++++++++++++++++++++++
 funasr/datasets/small_datasets/build_loader.py         |    6 -
 2 files changed, 148 insertions(+), 5 deletions(-)

diff --git a/funasr/datasets/small_datasets/build_loader.py b/funasr/datasets/small_datasets/build_loader.py
index d1bc21a..a7181a4 100644
--- a/funasr/datasets/small_datasets/build_loader.py
+++ b/funasr/datasets/small_datasets/build_loader.py
@@ -6,7 +6,7 @@
 
 from funasr.datasets.small_datasets.dataset import ESPnetDataset
 from funasr.datasets.small_datasets.preprocessor import build_preprocess
-from funasr.samplers.length_batch_sampler import LengthBatchSampler
+from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
 
 
 def build_dataloader(args, mode="train"):
@@ -26,10 +26,6 @@
         preprocess=preprocess_fn,
         dest_sample_rate=dest_sample_rate,
     )
-    if os.path.exists(os.path.join(data_path_and_name_and_type[0][0].parent, "utt2category")):
-        utt2category_file = os.path.join(data_path_and_name_and_type[0][0].parent, "utt2category")
-    else:
-        utt2category_file = None
 
     dataset_conf = args.dataset_conf
     batch_sampler = LengthBatchSampler(
diff --git a/funasr/datasets/small_datasets/length_batch_sampler.py b/funasr/datasets/small_datasets/length_batch_sampler.py
new file mode 100644
index 0000000..8ee8bdc
--- /dev/null
+++ b/funasr/datasets/small_datasets/length_batch_sampler.py
@@ -0,0 +1,147 @@
+from typing import Iterator
+from typing import List
+from typing import Dict
+from typing import Tuple
+from typing import Union
+
+from typeguard import check_argument_types
+
+from funasr.fileio.read_text import load_num_sequence_text
+from funasr.samplers.abs_sampler import AbsSampler
+
+
+class LengthBatchSampler(AbsSampler):
+    def __init__(
+        self,
+        batch_bins: int,
+        shape_files: Union[Tuple[str, ...], List[str], Dict],
+        min_batch_size: int = 1,
+        sort_in_batch: str = "descending",
+        sort_batch: str = "ascending",
+        drop_last: bool = False,
+        padding: bool = True,
+    ):
+        assert check_argument_types()
+        assert batch_bins > 0
+        if sort_batch != "ascending" and sort_batch != "descending":
+            raise ValueError(
+                f"sort_batch must be ascending or descending: {sort_batch}"
+            )
+        if sort_in_batch != "descending" and sort_in_batch != "ascending":
+            raise ValueError(
+                f"sort_in_batch must be ascending or descending: {sort_in_batch}"
+            )
+
+        self.batch_bins = batch_bins
+        self.shape_files = shape_files
+        self.sort_in_batch = sort_in_batch
+        self.sort_batch = sort_batch
+        self.drop_last = drop_last
+
+        # utt2shape: (Length, ...)
+        #    uttA 100,...
+        #    uttB 201,...
+        if isinstance(shape_files, dict):
+            utt2shapes = [shape_files]
+        else:
+            utt2shapes = [
+                load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
+            ]
+
+        first_utt2shape = utt2shapes[0]
+        for s, d in zip(shape_files, utt2shapes):
+            if set(d) != set(first_utt2shape):
+                raise RuntimeError(
+                    f"keys are mismatched between {s} != {shape_files[0]}"
+                )
+
+        # Sort samples in ascending order
+        # (shape order should be like (Length, Dim))
+        keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
+        if len(keys) == 0:
+            raise RuntimeError(f"0 lines found: {shape_files[0]}")
+
+        # Decide batch-sizes
+        batch_sizes = []
+        current_batch_keys = []
+        for key in keys:
+            current_batch_keys.append(key)
+            # shape: (Length, dim1, dim2, ...)
+            if padding:
+                # bins = bs x max_length
+                bins = sum(len(current_batch_keys) * sh[key][0] for sh in utt2shapes)
+            else:
+                # bins = sum of lengths
+                bins = sum(d[k][0] for k in current_batch_keys for d in utt2shapes)
+
+            if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
+                batch_sizes.append(len(current_batch_keys))
+                current_batch_keys = []
+        else:
+            if len(current_batch_keys) != 0 and (
+                not self.drop_last or len(batch_sizes) == 0
+            ):
+                batch_sizes.append(len(current_batch_keys))
+
+        if len(batch_sizes) == 0:
+            # Maybe we can't reach here
+            raise RuntimeError("0 batches")
+
+        # If the last batch-size is smaller than minimum batch_size,
+        # the samples are redistributed to the other mini-batches
+        if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
+            for i in range(batch_sizes.pop(-1)):
+                batch_sizes[-(i % len(batch_sizes)) - 1] += 1
+
+        if not self.drop_last:
+            # Bug check
+            assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"
+
+        # Set mini-batch
+        self.batch_list = []
+        iter_bs = iter(batch_sizes)
+        bs = next(iter_bs)
+        minibatch_keys = []
+        for key in keys:
+            minibatch_keys.append(key)
+            if len(minibatch_keys) == bs:
+                if sort_in_batch == "descending":
+                    minibatch_keys.reverse()
+                elif sort_in_batch == "ascending":
+                    # Key are already sorted in ascending
+                    pass
+                else:
+                    raise ValueError(
+                        "sort_in_batch must be ascending"
+                        f" or descending: {sort_in_batch}"
+                    )
+                self.batch_list.append(tuple(minibatch_keys))
+                minibatch_keys = []
+                try:
+                    bs = next(iter_bs)
+                except StopIteration:
+                    break
+
+        if sort_batch == "ascending":
+            pass
+        elif sort_batch == "descending":
+            self.batch_list.reverse()
+        else:
+            raise ValueError(
+                f"sort_batch must be ascending or descending: {sort_batch}"
+            )
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}("
+            f"N-batch={len(self)}, "
+            f"batch_bins={self.batch_bins}, "
+            f"sort_in_batch={self.sort_in_batch}, "
+            f"sort_batch={self.sort_batch})"
+        )
+
+    def __len__(self):
+        return len(self.batch_list)
+
+    def __iter__(self) -> Iterator[Tuple[str, ...]]:
+        return iter(self.batch_list)

--
Gitblit v1.9.1