zhifu gao
2023-05-18 97a689d65da434345a641a909f13b78e5690c86b
funasr/datasets/small_datasets/sequence_iter_factory.py
New file
@@ -0,0 +1,189 @@
import logging
import numpy as np
import torch
from torch.utils.data import DataLoader
from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
from funasr.datasets.small_datasets.dataset import ESPnetDataset
from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
from funasr.datasets.small_datasets.preprocessor import build_preprocess
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.samplers.abs_sampler import AbsSampler
class RawSampler(AbsSampler):
    """Sampler that replays a fixed, precomputed list of mini-batches.

    Each element of ``batches`` is one mini-batch exactly as produced upstream
    (here: by ``LengthBatchSampler``, possibly sliced per distributed rank).
    """

    def __init__(self, batches):
        # The precomputed mini-batches, stored as given (no copy is made here).
        self.batches = batches

    def __len__(self):
        """Number of mini-batches."""
        return len(self.batches)

    def __iter__(self):
        """Iterate over the stored mini-batches in their stored order."""
        return iter(self.batches)

    def generate(self, seed):
        """Return a new list holding the stored mini-batches.

        ``seed`` is accepted for interface compatibility but is not used: the
        batches are always the same. A fresh list is returned so callers may
        shuffle it in place without disturbing ``self.batches``.
        """
        return [batch for batch in self.batches]
class SequenceIterFactory(AbsIterFactory):
    """Build iterator for each epoch, modified from ESPnet.

    All mini-batches are computed once, at construction time, by
    ``LengthBatchSampler``. Each call to :meth:`build_iter` then regenerates the
    batch list and (optionally) shuffles its order with a ``RandomState`` seeded
    from ``epoch + seed``, so the per-epoch order is reproducible.
    """

    def __init__(self, args, mode="train"):
        """Create the dataset, collate function and batch sampler for a split.

        Args:
            args: Training configuration namespace. Reads ``task_name``,
                ``frontend_conf``, ``train_/valid_data_path_and_name_and_type``,
                ``train_/valid_shape_file``, ``speed_perturb``, ``dataset_conf``
                (a dict), ``ngpu``, ``scheduler``, ``max_epoch``,
                ``distributed`` and ``seed``. When
                ``scheduler == "tri_stage"`` and ``mode == "train"``, this
                constructor also writes ``args.max_update``.
            mode: ``"train"`` or ``"valid"``.

        Raises:
            NotImplementedError: if ``mode`` is neither ``"train"`` nor ``"valid"``.
            RuntimeError: in distributed training, if any batch holds fewer
                samples than the world size (it could not be split across ranks).
        """
        # Resolve split-specific inputs first so an unsupported mode fails fast.
        if mode == "train":
            data_path_and_name_and_type = args.train_data_path_and_name_and_type
            shape_files = args.train_shape_file
        elif mode == "valid":
            data_path_and_name_and_type = args.valid_data_path_and_name_and_type
            shape_files = args.valid_shape_file
        else:
            raise NotImplementedError(f"mode={mode}")
        is_train = mode == "train"

        # preprocess
        preprocess_fn = build_preprocess(args, train=is_train)

        # collate: "punc"/"lm" tasks pad integer fields with 0; all other tasks
        # pad float fields with 0.0 and integer fields with -1.
        if args.task_name in ["punc", "lm"]:
            collate_fn = CommonCollateFn(int_pad_value=0)
        else:
            collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

        # dataset: target sample rate is the frontend's "fs" when configured,
        # otherwise 16000. Speed perturbation is applied only when training.
        frontend_conf = args.frontend_conf
        if frontend_conf is not None and "fs" in frontend_conf:
            dest_sample_rate = frontend_conf["fs"]
        else:
            dest_sample_rate = 16000
        dataset = ESPnetDataset(
            data_path_and_name_and_type,
            preprocess=preprocess_fn,
            dest_sample_rate=dest_sample_rate,
            speed_perturb=args.speed_perturb if is_train else None,
        )

        # sampler
        dataset_conf = args.dataset_conf
        # ``dataset_conf`` is a dict: optional keys must be read with ``.get``.
        # (``hasattr`` checks attributes, not keys, so it always returned False
        # and silently ignored any configured sort order.)
        # ``ngpu`` is 0 on CPU; clamp to 1 so batch_bins is never zero.
        batch_sampler = LengthBatchSampler(
            batch_bins=dataset_conf["batch_conf"]["batch_size"] * max(args.ngpu, 1),
            shape_files=shape_files,
            sort_in_batch=dataset_conf.get("sort_in_batch", "descending"),
            sort_batch=dataset_conf.get("sort_batch", "ascending"),
            drop_last=False,
            padding=True,
        )

        batches = list(batch_sampler)
        bs_list = [len(batch) for batch in batches]

        logging.info(f"[{mode}] dataset:\n{dataset}")
        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
        logging.info(
            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
        )

        # The tri-stage scheduler needs the total number of optimizer updates.
        if args.scheduler == "tri_stage" and is_train:
            args.max_update = len(bs_list) * args.max_epoch
            logging.info("Max update: {}".format(args.max_update))

        if args.distributed and is_train:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            for batch in batches:
                if len(batch) < world_size:
                    raise RuntimeError(
                        f"The batch-size must be equal or more than world_size: "
                        f"{len(batch)} < {world_size}"
                    )
            # Every rank keeps a strided slice of each global batch.
            batches = [batch[rank::world_size] for batch in batches]

        # ``batches`` is always a plain list here, so wrap it in a sampler.
        self.sampler = RawSampler(batches)
        self.dataset = dataset
        # None: every epoch iterates over the whole batch list.
        self.num_iters_per_epoch = None
        self.shuffle = is_train
        self.seed = args.seed
        self.num_workers = dataset_conf.get("num_workers", 8)
        self.collate_fn = collate_fn
        self.pin_memory = args.ngpu > 0

    def _generate_batches(self, seed, shuffle):
        """Return the sampler's batches for ``seed``, shuffled with that seed if requested."""
        batches = self.sampler.generate(seed)
        if shuffle:
            np.random.RandomState(seed).shuffle(batches)
        return batches

    def _batches_for_fixed_iters(self, epoch, shuffle):
        """Pick exactly ``num_iters_per_epoch`` batches for ``epoch``.

        Successive epochs consume the corpus as one continuous stream of
        passes, where pass ``k`` is ordered by seed ``k + self.seed``. An epoch
        may therefore span the end of one pass and the start of the next.
        """
        n_iters = self.num_iters_per_epoch
        N = len(self.sampler)

        if n_iters < N:
            # Corpus is larger than one epoch: this epoch ends at stream
            # position n_iters * epoch, i.e. ``offset`` into pass ``real_epoch``.
            real_epoch, offset = divmod(n_iters * epoch, N)
            if offset >= n_iters:
                current = self._generate_batches(real_epoch + self.seed, shuffle)
                return current[offset - n_iters: offset]
            # The window straddles the previous pass's tail and this pass's head.
            prev = self._generate_batches(real_epoch - 1 + self.seed, shuffle)
            current = self._generate_batches(real_epoch + self.seed, shuffle)
            return prev[offset - n_iters:] + current[:offset]

        # Corpus is not larger than one epoch: cycle through it (possibly
        # several times), starting where the previous epoch stopped.
        pass_idx, cursor = divmod(n_iters * (epoch - 1), N)
        remain = n_iters
        batches = []
        current = self._generate_batches(pass_idx + self.seed, shuffle)
        while remain > 0:
            chunk = current[cursor: cursor + remain]
            batches += chunk
            if cursor + remain >= N:
                # This pass is exhausted; move on to a freshly ordered pass.
                pass_idx += 1
                cursor = 0
                current = self._generate_batches(pass_idx + self.seed, shuffle)
            else:
                cursor += remain
            remain -= len(chunk)
        assert len(batches) == n_iters
        return batches

    def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
        """Build the ``DataLoader`` for ``epoch``.

        Args:
            epoch: Epoch number; ``epoch + self.seed`` seeds batch generation
                and shuffling, so the order is reproducible per epoch.
            shuffle: Whether to shuffle the batch order. Defaults to
                ``self.shuffle`` (True only for the train split).

        Returns:
            A ``DataLoader`` whose ``batch_sampler`` is the selected batch list.
        """
        if shuffle is None:
            shuffle = self.shuffle

        if self.num_iters_per_epoch is None:
            batches = self._generate_batches(epoch + self.seed, shuffle)
        else:
            batches = self._batches_for_fixed_iters(epoch, shuffle)

        # Only pass collate_fn when set (backward compatibility for DataLoader).
        kwargs = {} if self.collate_fn is None else dict(collate_fn=self.collate_fn)
        return DataLoader(
            dataset=self.dataset,
            batch_sampler=batches,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            **kwargs,
        )