游雁
2023-10-10 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3
funasr/datasets/small_datasets/dataset.py
@@ -9,25 +9,19 @@
from typing import Collection
from typing import Dict
from typing import Mapping
from typing import Tuple
from typing import Union
from typing import Union, List, Tuple
import humanfriendly
import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.sound_scp import SoundScpReader
from funasr.utils.sized_dict import SizedDict
class AdapterForSoundScpReader(collections.abc.Mapping):
    # Read-only Mapping adapter over an audio loader (e.g. SoundScpReader or a
    # kaldiio scp dict), converting loaded values on access.
    # NOTE(review): only __init__ is visible in this hunk; the Mapping protocol
    # methods (__getitem__/__iter__/__len__) are defined outside this view.
    def __init__(self, loader, dtype=None):
        # typeguard runtime check of the annotated argument types.
        assert check_argument_types()
        self.loader = loader  # underlying key -> audio mapping
        self.dtype = dtype    # target dtype for returned arrays; None keeps the loader's dtype
        self.rate = None      # sampling rate; presumably filled in on first read — TODO confirm
@@ -111,11 +105,10 @@
            ] = None,
            float_dtype: str = "float32",
            int_dtype: str = "long",
            max_cache_size: Union[float, int, str] = 0.0,
            max_cache_fd: int = 0,
            dest_sample_rate: int = 16000,
            speed_perturb: Union[list, tuple] = None,
            mode: str = "train",
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
@@ -126,8 +119,11 @@
        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.max_cache_fd = max_cache_fd
        self.dest_sample_rate = dest_sample_rate
        self.speed_perturb = speed_perturb
        self.mode = mode
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))
        self.loader_dict = {}
        self.debug_info = {}
@@ -141,17 +137,9 @@
            if len(self.loader_dict[name]) == 0:
                raise RuntimeError(f"{path} has no samples")
        if isinstance(max_cache_size, str):
            max_cache_size = humanfriendly.parse_size(max_cache_size)
        self.max_cache_size = max_cache_size
        if max_cache_size > 0:
            self.cache = SizedDict(shared=True)
        else:
            self.cache = None
    def _build_loader(
            self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, List[int], numbers.Number]]:
        """Helper function to instantiate Loader.
        Args:
@@ -159,13 +147,15 @@
            loader_type:  loader_type. sound, npy, text, etc
        """
        if loader_type == "sound":
            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False)
            speed_perturb = self.speed_perturb if self.mode == "train" else None
            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
                                    speed_perturb=speed_perturb)
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "kaldi_ark":
            loader = kaldiio.load_scp(path, max_cache_fd=self.max_cache_fd)
            loader = kaldiio.load_scp(path)
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "npy":
            return NpyScpReader()
            return NpyScpReader(path)
        elif loader_type == "text":
            text_loader = {}
            with open(path, "r", encoding="utf-8") as f:
@@ -179,6 +169,19 @@
                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
                    text_loader[k] = v
            return text_loader
        elif loader_type == "text_int":
            text_int_loader = {}
            with open(path, "r", encoding="utf-8") as f:
                for linenum, line in enumerate(f, 1):
                    sps = line.rstrip().split(maxsplit=1)
                    if len(sps) == 1:
                        k, v = sps[0], ""
                    else:
                        k, v = sps
                    if k in text_int_loader:
                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
                    text_int_loader[k] = [int(i) for i in v.split()]
            return text_int_loader
        else:
            raise RuntimeError(f"Not supported: loader_type={loader_type}")
@@ -200,16 +203,11 @@
        return _mes
    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
        assert check_argument_types()
        # Change integer-id to string-id
        if isinstance(uid, int):
            d = next(iter(self.loader_dict.values()))
            uid = list(d)[uid]
        if self.cache is not None and uid in self.cache:
            data = self.cache[uid]
            return uid, data
        data = {}
        # 1. Load data from each loaders
@@ -261,9 +259,5 @@
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            data[name] = value
        if self.cache is not None and self.cache.size < self.max_cache_size:
            self.cache[uid] = data
        retval = uid, data
        assert check_return_type(retval)
        return retval