From 97a689d65da434345a641a909f13b78e5690c86b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 18 五月 2023 19:35:08 +0800
Subject: [PATCH] Merge pull request #526 from alibaba-damo-academy/dev_infer

---
 funasr/datasets/small_datasets/sequence_iter_factory.py |  189 +++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 189 insertions(+), 0 deletions(-)

diff --git a/funasr/datasets/small_datasets/sequence_iter_factory.py b/funasr/datasets/small_datasets/sequence_iter_factory.py
new file mode 100644
index 0000000..3ebcc5a
--- /dev/null
+++ b/funasr/datasets/small_datasets/sequence_iter_factory.py
@@ -0,0 +1,189 @@
+import logging
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
+from funasr.datasets.small_datasets.dataset import ESPnetDataset
+from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
+from funasr.datasets.small_datasets.preprocessor import build_preprocess
+from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.samplers.abs_sampler import AbsSampler
+
+
+class RawSampler(AbsSampler):
+    def __init__(self, batches):
+        self.batches = batches
+
+    def __len__(self):
+        return len(self.batches)
+
+    def __iter__(self):
+        return iter(self.batches)
+
+    def generate(self, seed):
+        return list(self.batches)
+
+
+class SequenceIterFactory(AbsIterFactory):
+    """Build iterator for each epoch, modified from ESPnet
+
+    """
+
+    def __init__(self, args, mode="train"):
+
+        # preprocess
+        preprocess_fn = build_preprocess(args, train=mode == "train")
+
+        # collate
+        if args.task_name in ["punc", "lm"]:
+            collate_fn = CommonCollateFn(int_pad_value=0)
+        else:
+            collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+        # dataset
+        dest_sample_rate = args.frontend_conf["fs"] if (
+                args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
+        if mode == "train":
+            data_path_and_name_and_type = args.train_data_path_and_name_and_type
+            shape_files = args.train_shape_file
+        elif mode == "valid":
+            data_path_and_name_and_type = args.valid_data_path_and_name_and_type
+            shape_files = args.valid_shape_file
+        else:
+            raise NotImplementedError(f"mode={mode}")
+        dataset = ESPnetDataset(
+            data_path_and_name_and_type,
+            preprocess=preprocess_fn,
+            dest_sample_rate=dest_sample_rate,
+            speed_perturb=args.speed_perturb if mode=="train" else None,
+        )
+
+        # sampler
+        dataset_conf = args.dataset_conf
+        batch_sampler = LengthBatchSampler(
+            batch_bins=dataset_conf["batch_conf"]["batch_size"] * args.ngpu,
+            shape_files=shape_files,
+            sort_in_batch=dataset_conf["sort_in_batch"] if hasattr(dataset_conf, "sort_in_batch") else "descending",
+            sort_batch=dataset_conf["sort_batch"] if hasattr(dataset_conf, "sort_batch") else "ascending",
+            drop_last=False,
+            padding=True,
+        )
+
+        batches = list(batch_sampler)
+        bs_list = [len(batch) for batch in batches]
+        logging.info(f"[{mode}] dataset:\n{dataset}")
+        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
+        logging.info(
+            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
+            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
+        )
+
+        if args.scheduler == "tri_stage" and mode == "train":
+            args.max_update = len(bs_list) * args.max_epoch
+            logging.info("Max update: {}".format(args.max_update))
+
+        if args.distributed and mode=="train":
+            world_size = torch.distributed.get_world_size()
+            rank = torch.distributed.get_rank()
+            for batch in batches:
+                if len(batch) < world_size:
+                    raise RuntimeError(
+                        f"The batch-size must be equal or more than world_size: "
+                        f"{len(batch)} < {world_size}"
+                    )
+            batches = [batch[rank::world_size] for batch in batches]
+
+        if not isinstance(batches, AbsSampler):
+            self.sampler = RawSampler(batches)
+        else:
+            self.sampler = batches
+
+        self.dataset = dataset
+        self.num_iters_per_epoch = None
+        self.shuffle = mode == "train"
+        self.seed = args.seed
+        self.num_workers = args.dataset_conf.get("num_workers", 8)
+        self.collate_fn = collate_fn
+        self.pin_memory = args.ngpu > 0
+
+    def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
+        if shuffle is None:
+            shuffle = self.shuffle
+
+        if self.num_iters_per_epoch is not None:
+            N = len(self.sampler)
+            # If corpus size is larger than the num_per_epoch
+            if self.num_iters_per_epoch < N:
+                N = len(self.sampler)
+                real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)
+
+                if offset >= self.num_iters_per_epoch:
+                    current_batches = self.sampler.generate(real_epoch + self.seed)
+                    if shuffle:
+                        np.random.RandomState(real_epoch + self.seed).shuffle(
+                            current_batches
+                        )
+                    batches = current_batches[
+                              offset - self.num_iters_per_epoch: offset
+                              ]
+                else:
+                    prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
+                    current_batches = self.sampler.generate(real_epoch + self.seed)
+                    if shuffle:
+                        np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
+                            prev_batches
+                        )
+                        np.random.RandomState(real_epoch + self.seed).shuffle(
+                            current_batches
+                        )
+                    batches = (
+                            prev_batches[offset - self.num_iters_per_epoch:]
+                            + current_batches[:offset]
+                    )
+
+            # If corpus size is less than the num_per_epoch
+            else:
+                _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
+                _remain = self.num_iters_per_epoch
+                batches = []
+                current_batches = self.sampler.generate(_epoch + self.seed)
+                if shuffle:
+                    np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
+                while _remain > 0:
+
+                    _batches = current_batches[_cursor: _cursor + _remain]
+                    batches += _batches
+                    if _cursor + _remain >= N:
+                        _epoch += 1
+                        _cursor = 0
+                        current_batches = self.sampler.generate(_epoch + self.seed)
+                        if shuffle:
+                            np.random.RandomState(_epoch + self.seed).shuffle(
+                                current_batches
+                            )
+                    else:
+                        _cursor = _cursor + _remain
+                    _remain -= len(_batches)
+
+                assert len(batches) == self.num_iters_per_epoch
+
+        else:
+            batches = self.sampler.generate(epoch + self.seed)
+            if shuffle:
+                np.random.RandomState(epoch + self.seed).shuffle(batches)
+
+        # For backward compatibility for pytorch DataLoader
+        if self.collate_fn is not None:
+            kwargs = dict(collate_fn=self.collate_fn)
+        else:
+            kwargs = {}
+
+        return DataLoader(
+            dataset=self.dataset,
+            batch_sampler=batches,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            **kwargs,
+        )

--
Gitblit v1.9.1