speech_asr
2023-04-18 05d4176e88bdde0ce615cb22daf7af725a496020
update
1个文件已修改
2个文件已添加
1个文件已重命名
313 ■■■■■ 已修改文件
funasr/datasets/small_datasets/build_dataloader.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/small_datasets/collate_fn.py 93 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/small_datasets/sequence_iter_factory.py 189 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/build_dataloader.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/small_datasets/build_dataloader.py
File was renamed from funasr/datasets/small_datasets/build_loader.py
@@ -1,16 +1,26 @@
import logging
import os
import numpy as np
import torch
from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
from funasr.datasets.small_datasets.dataset import ESPnetDataset
from funasr.datasets.small_datasets.preprocessor import build_preprocess
from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
from funasr.datasets.small_datasets.preprocessor import build_preprocess
from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFactory
def build_dataloader(args, mode="train"):
    # preprocess
    preprocess_fn = build_preprocess(args, train=mode == "train")
    # collate
    if args.task_name in ["punc", "lm"]:
        collate_fn = CommonCollateFn(int_pad_value=0)
    else:
        collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
    # dataset
    dest_sample_rate = args.frontend_conf["fs"] if (
            args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
    if mode == "train":
@@ -27,6 +37,7 @@
        dest_sample_rate=dest_sample_rate,
    )
    # sampler
    dataset_conf = args.dataset_conf
    batch_sampler = LengthBatchSampler(
        batch_bins=dataset_conf["batch_size"],
@@ -60,3 +71,14 @@
                    f"{len(batch)} < {world_size}"
                )
        batches = [batch[rank::world_size] for batch in batches]
    # dataloader
    return SequenceIterFactory(
        dataset=dataset,
        batches=batches,
        seed=args.seed,
        shuffle=mode == "train",
        num_workers=args.num_workers,
        collate_fn=collate_fn,
        pin_memory=args.ngpu > 0,
    )
funasr/datasets/small_datasets/collate_fn.py
New file
@@ -0,0 +1,93 @@
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.modules.nets_utils import pad_list
class CommonCollateFn:
    """Functor class of common_collate_fn().

    Stores the padding options once so that the instance can be handed to a
    data loader as its ``collate_fn``; calling the instance forwards the
    mini-batch and the stored options to :func:`common_collate_fn`.
    """

    def __init__(
            self,
            float_pad_value: Union[float, int] = 0.0,
            int_pad_value: int = -32768,
            not_sequence: Collection[str] = (),
            max_sample_size=None
    ):
        """Store the collate options.

        Args:
            float_pad_value: Padding value used for arrays whose dtype kind
                is not ``"i"``.
            int_pad_value: Padding value used for arrays whose dtype kind
                is ``"i"``.
            not_sequence: Keys for which no ``<key>_lengths`` entry is
                produced.
            max_sample_size: Stored on the instance; it is not forwarded to
                ``common_collate_fn`` by ``__call__``.
        """
        assert check_argument_types()
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        # Copy into a set so membership checks are cheap and the caller's
        # collection is not shared.
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        # Report each padding value under its own label (the int label used
        # to print the float padding value).
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
            self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        """Collate ``data`` with the options stored on this instance.

        Args:
            data: Collection of ``(utterance_id, {name: ndarray})`` pairs.

        Returns:
            Whatever ``common_collate_fn`` returns for ``data``: a tuple of
            the utterance ids and a dict of tensors.
        """
        return common_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
def common_collate_fn(
        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Pad the per-utterance arrays of a mini-batch and convert them to tensors.

    Args:
        data: Collection of ``(utterance_id, {name: ndarray})`` pairs. All
            dicts must share the same keys and no key may end with
            ``"_lengths"``.
        float_pad_value: Padding value for arrays whose dtype kind is not
            ``"i"``.
        int_pad_value: Padding value for arrays whose dtype kind is ``"i"``.
        not_sequence: Keys for which no ``<key>_lengths`` entry is added.

    Returns:
        ``(uttids, batch)`` where ``uttids`` lists the utterance ids in input
        order and ``batch`` maps each key to the result of ``pad_list`` and,
        unless the key is in ``not_sequence``, ``<key>_lengths`` to a long
        tensor holding the first-axis size of every sample.
    """
    assert check_argument_types()

    # Split the (id, sample) pairs into two parallel lists.
    uttids = []
    samples = []
    for uttid, sample in data:
        uttids.append(uttid)
        samples.append(sample)

    assert all(
        set(sample) == set(samples[0]) for sample in samples
    ), "dict-keys mismatching"
    first = samples[0]
    assert not any(
        name.endswith("_lengths") for name in first
    ), f"*_lengths is reserved: {list(first)}"

    batch = {}
    for key, first_value in first.items():
        # The dtype of the first sample decides which padding value is used.
        pad_value = int_pad_value if first_value.dtype.kind == "i" else float_pad_value

        tensors = [torch.from_numpy(sample[key]) for sample in samples]
        batch[key] = pad_list(tensors, pad_value)

        if key in not_sequence:
            continue
        batch[key + "_lengths"] = torch.tensor(
            [sample[key].shape[0] for sample in samples], dtype=torch.long
        )

    output = (uttids, batch)
    assert check_return_type(output)
    return output
def crop_to_max_size(feature, target_size):
    """Return a random contiguous window of ``target_size`` items from ``feature``.

    ``feature`` is returned unchanged when it is not longer than
    ``target_size``. Otherwise the window start is drawn with
    ``np.random.randint`` from every position that keeps the window inside
    ``feature``.
    """
    excess = len(feature) - target_size
    if excess > 0:
        # randint's upper bound is exclusive, hence the +1.
        offset = np.random.randint(0, excess + 1)
        return feature[offset:offset + target_size]
    return feature
funasr/datasets/small_datasets/sequence_iter_factory.py
New file
@@ -0,0 +1,189 @@
import logging
import numpy as np
import torch
from torch.utils.data import DataLoader
from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
from funasr.datasets.small_datasets.dataset import ESPnetDataset
from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
from funasr.datasets.small_datasets.preprocessor import build_preprocess
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.samplers.abs_sampler import AbsSampler
class RawSampler(AbsSampler):
    """Sampler backed by an already-computed sequence of batches."""

    def __init__(self, batches):
        # Kept by reference; generate() hands out copies.
        self.batches = batches

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)

    def generate(self, seed):
        """Return a new list of the stored batches; ``seed`` is not used."""
        return [*self.batches]
class SequenceIterFactory(AbsIterFactory):
    """Build iterator for each epoch.

    The constructor assembles the preprocessing function, the collate
    function, the dataset and the list of mini-batches from ``args``.
    :meth:`build_iter` then returns a ``DataLoader`` over that batch list for
    the requested epoch, shuffled with an epoch-dependent seed when enabled.
    """

    def __init__(self, args, mode="train"):
        """Create the dataset and the mini-batch list for ``mode``.

        Args:
            args: Configuration object. Attributes read here: ``task_name``,
                ``frontend_conf``, ``dataset_conf``, ``scheduler``,
                ``max_epoch``, ``distributed``, ``seed``, ``num_workers``,
                ``ngpu`` and the ``train_*`` / ``valid_*`` data path and
                shape file attributes matching ``mode``. When
                ``args.scheduler == "tri_stage"`` and ``mode == "train"``,
                ``args.max_update`` is overwritten.
            mode: ``"train"`` or ``"valid"``.

        Raises:
            NotImplementedError: If ``mode`` is neither ``"train"`` nor
                ``"valid"``.
            RuntimeError: In distributed mode, if a mini-batch holds fewer
                samples than the world size.
        """
        # preprocess
        preprocess_fn = build_preprocess(args, train=mode == "train")

        # collate
        if args.task_name in ["punc", "lm"]:
            collate_fn = CommonCollateFn(int_pad_value=0)
        else:
            collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

        # dataset
        dest_sample_rate = args.frontend_conf["fs"] if (
                args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
        if mode == "train":
            data_path_and_name_and_type = args.train_data_path_and_name_and_type
            shape_files = args.train_shape_file
        elif mode == "valid":
            data_path_and_name_and_type = args.valid_data_path_and_name_and_type
            shape_files = args.valid_shape_file
        else:
            raise NotImplementedError(f"mode={mode}")
        dataset = ESPnetDataset(
            data_path_and_name_and_type,
            preprocess=preprocess_fn,
            dest_sample_rate=dest_sample_rate,
        )

        # sampler
        dataset_conf = args.dataset_conf
        # dataset_conf is subscripted like a mapping, so optional entries are
        # detected with a key-membership test; hasattr() does not see the
        # keys of a dict and would always fall back to the defaults.
        batch_sampler = LengthBatchSampler(
            batch_bins=dataset_conf["batch_size"],
            shape_files=shape_files,
            sort_in_batch=dataset_conf["sort_in_batch"] if "sort_in_batch" in dataset_conf else "descending",
            sort_batch=dataset_conf["sort_batch"] if "sort_batch" in dataset_conf else "ascending",
            drop_last=False,
            padding=True,
        )

        batches = list(batch_sampler)
        bs_list = [len(batch) for batch in batches]
        logging.info(f"[{mode}] dataset:\n{dataset}")
        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
        logging.info(
            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
        )

        # The tri_stage scheduler gets its total update count from the
        # number of training batches times the number of epochs.
        if args.scheduler == "tri_stage" and mode == "train":
            args.max_update = len(bs_list) * args.max_epoch
            logging.info("Max update: {}".format(args.max_update))

        if args.distributed:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            for batch in batches:
                if len(batch) < world_size:
                    raise RuntimeError(
                        f"The batch-size must be equal or more than world_size: "
                        f"{len(batch)} < {world_size}"
                    )
            # Each rank keeps every world_size-th sample of every batch.
            batches = [batch[rank::world_size] for batch in batches]

        if not isinstance(batches, AbsSampler):
            self.sampler = RawSampler(batches)
        else:
            self.sampler = batches

        self.dataset = dataset
        self.num_iters_per_epoch = None
        self.shuffle = mode == "train"
        self.seed = args.seed
        self.num_workers = args.num_workers
        self.collate_fn = collate_fn
        self.pin_memory = args.ngpu > 0

    def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
        """Return the ``DataLoader`` for ``epoch``.

        Args:
            epoch: Epoch number; combined with ``self.seed`` to seed the
                sampler and the shuffling.
                NOTE(review): the fixed-length branches use ``epoch - 1``,
                which suggests epochs are 1-based - confirm with the caller.
            shuffle: Whether to shuffle the batch order; ``None`` means use
                ``self.shuffle``.

        Returns:
            A ``DataLoader`` over ``self.dataset`` whose ``batch_sampler`` is
            the batch list selected for this epoch.
        """
        if shuffle is None:
            shuffle = self.shuffle

        # num_iters_per_epoch is None after __init__, so this branch only
        # runs when the attribute has been set afterwards.
        if self.num_iters_per_epoch is not None:
            N = len(self.sampler)
            # If corpus size is larger than the num_per_epoch
            if self.num_iters_per_epoch < N:
                real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)

                if offset >= self.num_iters_per_epoch:
                    # The whole window lies inside one pass over the corpus.
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = current_batches[
                              offset - self.num_iters_per_epoch: offset
                              ]
                else:
                    # The window straddles two passes: take the tail of the
                    # previous pass and the head of the current one.
                    prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
                            prev_batches
                        )
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = (
                            prev_batches[offset - self.num_iters_per_epoch:]
                            + current_batches[:offset]
                    )
            # If corpus size is less than the num_per_epoch
            else:
                _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
                _remain = self.num_iters_per_epoch
                batches = []
                current_batches = self.sampler.generate(_epoch + self.seed)
                if shuffle:
                    np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
                # Keep appending batches, regenerating the batch list each
                # time the cursor reaches the end, until enough are collected.
                while _remain > 0:
                    _batches = current_batches[_cursor: _cursor + _remain]
                    batches += _batches
                    if _cursor + _remain >= N:
                        _epoch += 1
                        _cursor = 0
                        current_batches = self.sampler.generate(_epoch + self.seed)
                        if shuffle:
                            np.random.RandomState(_epoch + self.seed).shuffle(
                                current_batches
                            )
                    else:
                        _cursor = _cursor + _remain
                    _remain -= len(_batches)
                assert len(batches) == self.num_iters_per_epoch
        else:
            batches = self.sampler.generate(epoch + self.seed)
            if shuffle:
                np.random.RandomState(epoch + self.seed).shuffle(batches)

        # For backward compatibility for pytorch DataLoader
        if self.collate_fn is not None:
            kwargs = dict(collate_fn=self.collate_fn)
        else:
            kwargs = {}

        return DataLoader(
            dataset=self.dataset,
            batch_sampler=batches,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            **kwargs,
        )
funasr/utils/build_dataloader.py
@@ -1,9 +1,10 @@
from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
from funasr.datasets.small_datasets.build_dataloader import build_dataloader
def build_dataloader(args):
    if args.dataset_type == "small":
        pass
        train_iter_factory = LargeDataLoader(args, mode="train")
        valid_iter_factory = LargeDataLoader(args, mode="valid")
    elif args.dataset_type == "large":
        train_iter_factory = LargeDataLoader(args, mode="train")
        valid_iter_factory = LargeDataLoader(args,  mode="valid")