New file

import logging
from typing import Optional

import numpy as np
import torch
from torch.utils.data import DataLoader

from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
from funasr.datasets.small_datasets.dataset import ESPnetDataset
from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
from funasr.datasets.small_datasets.preprocessor import build_preprocess
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.samplers.abs_sampler import AbsSampler


class RawSampler(AbsSampler):
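    """Wrap a precomputed list of batches as an AbsSampler.

    ``generate`` ignores the seed and returns a shallow copy of the batches;
    per-epoch shuffling is handled by the caller (``build_iter``).
    """
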
    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        return iter(self.batches)

    def generate(self, seed):
        return list(self.batches)


class SequenceIterFactory(AbsIterFactory):
| | | """Build iterator for each epoch. |
| | | |
| | | |
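    Example (a minimal sketch; assumes ``args`` is the parsed training config
    carrying the fields read in ``__init__``)::

        iter_factory = SequenceIterFactory(args, mode="train")
        for epoch in range(1, args.max_epoch + 1):
            for batch in iter_factory.build_iter(epoch):
                ...  # one training step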
| | | """ |
| | | |
| | | def __init__(self, args, mode="train"): |
| | | |
| | | # preprocess |
| | | preprocess_fn = build_preprocess(args, train=mode == "train") |
| | | |
| | | # collate |
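        # Token-only tasks (punctuation, language model) pad token ids with 0;
        # other tasks pad float features with 0.0 and pad integer targets with
        # -1, which is commonly treated as the ignore index by the loss.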
        if args.task_name in ["punc", "lm"]:
            collate_fn = CommonCollateFn(int_pad_value=0)
        else:
            collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

        # dataset
        dest_sample_rate = (
            args.frontend_conf.get("fs", 16000)
            if args.frontend_conf is not None
            else 16000
        )
| | | if mode == "train": |
| | | data_path_and_name_and_type = args.train_data_path_and_name_and_type |
| | | shape_files = args.train_shape_file |
| | | elif mode == "valid": |
| | | data_path_and_name_and_type = args.valid_data_path_and_name_and_type |
| | | shape_files = args.valid_shape_file |
| | | else: |
| | | raise NotImplementedError(f"mode={mode}") |
| | | dataset = ESPnetDataset( |
| | | data_path_and_name_and_type, |
| | | preprocess=preprocess_fn, |
| | | dest_sample_rate=dest_sample_rate, |
| | | ) |
| | | |
| | | # sampler |
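        # Utterances are bucketed by the lengths recorded in the shape files;
        # ``batch_size`` is passed as the ``batch_bins`` budget, so the number
        # of utterances per mini-batch varies with sequence length.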
        dataset_conf = args.dataset_conf
        batch_sampler = LengthBatchSampler(
            batch_bins=dataset_conf["batch_size"],
            shape_files=shape_files,
            sort_in_batch=dataset_conf.get("sort_in_batch", "descending"),
            sort_batch=dataset_conf.get("sort_batch", "ascending"),
            drop_last=False,
            padding=True,
        )

        batches = list(batch_sampler)
        bs_list = [len(batch) for batch in batches]
        logging.info(f"[{mode}] dataset:\n{dataset}")
        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
        logging.info(
            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
        )

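        # tri_stage LR scheduling needs the total number of updates up front:
        # batches per epoch times the number of epochs.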
        if args.scheduler == "tri_stage" and mode == "train":
            args.max_update = len(bs_list) * args.max_epoch
            logging.info(f"Max update: {args.max_update}")

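        # In distributed mode, each rank keeps a strided slice of every batch
        # (batch[rank::world_size]), so all ranks run the same number of steps
        # per epoch.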
        if args.distributed:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            for batch in batches:
                if len(batch) < world_size:
                    raise RuntimeError(
                        f"The batch size must be equal to or greater than "
                        f"the world size: {len(batch)} < {world_size}"
                    )
            batches = [batch[rank::world_size] for batch in batches]

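        # ``batches`` is a plain list at this point, so wrap it in RawSampler
        # to expose the ``generate(seed)`` interface used by ``build_iter``.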
        if not isinstance(batches, AbsSampler):
            self.sampler = RawSampler(batches)
        else:
            self.sampler = batches

        self.dataset = dataset
        # None -> build_iter yields one full pass over all batches per epoch
        self.num_iters_per_epoch = None
        self.shuffle = (mode == "train")
        self.seed = args.seed
        self.num_workers = args.num_workers
        self.collate_fn = collate_fn
        self.pin_memory = args.ngpu > 0

    def build_iter(self, epoch: int, shuffle: Optional[bool] = None) -> DataLoader:
        if shuffle is None:
            shuffle = self.shuffle

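        # Shuffling is seeded with ``seed + epoch``, so the batch order is
        # reproducible for a given (seed, epoch) pair.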
        if self.num_iters_per_epoch is not None:
            N = len(self.sampler)
            # If the corpus holds more batches than num_iters_per_epoch
            if self.num_iters_per_epoch < N:
                real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)

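                # This epoch's window of num_iters_per_epoch batches may begin
                # in the previous pass over the corpus; if so, it is stitched
                # together from two differently seeded shuffles.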
                if offset >= self.num_iters_per_epoch:
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = current_batches[
                        offset - self.num_iters_per_epoch : offset
                    ]
                else:
                    prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
                            prev_batches
                        )
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = (
                        prev_batches[offset - self.num_iters_per_epoch :]
                        + current_batches[:offset]
                    )

            # If the corpus holds fewer batches than num_iters_per_epoch
            else:
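                # Cycle through the corpus repeatedly, reshuffling with a new
                # seed at every wrap-around, until num_iters_per_epoch batches
                # have been collected.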
                _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
                _remain = self.num_iters_per_epoch
                batches = []
                current_batches = self.sampler.generate(_epoch + self.seed)
                if shuffle:
                    np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
                while _remain > 0:
                    _batches = current_batches[_cursor : _cursor + _remain]
                    batches += _batches
                    if _cursor + _remain >= N:
                        _epoch += 1
                        _cursor = 0
                        current_batches = self.sampler.generate(_epoch + self.seed)
                        if shuffle:
                            np.random.RandomState(_epoch + self.seed).shuffle(
                                current_batches
                            )
                    else:
                        _cursor = _cursor + _remain
                    _remain -= len(_batches)

                assert len(batches) == self.num_iters_per_epoch

        else:
            batches = self.sampler.generate(epoch + self.seed)
            if shuffle:
                np.random.RandomState(epoch + self.seed).shuffle(batches)

        # For backward compatibility with the PyTorch DataLoader: pass
        # collate_fn only when it is set.
        if self.collate_fn is not None:
            kwargs = dict(collate_fn=self.collate_fn)
        else:
            kwargs = {}

        return DataLoader(
            dataset=self.dataset,
            batch_sampler=batches,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            **kwargs,
        )