游雁
2023-05-30 d2314e8de6ef3b615b640e407432ea4e42a6e51d
funasr/datasets/small_datasets/dataset.py
@@ -9,10 +9,8 @@
from typing import Collection
from typing import Dict
from typing import Mapping
from typing import Tuple
from typing import Union
from typing import Union, List, Tuple
import humanfriendly
import kaldiio
import numpy as np
import torch
@@ -22,7 +20,6 @@
from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.sound_scp import SoundScpReader
from funasr.utils.sized_dict import SizedDict
class AdapterForSoundScpReader(collections.abc.Mapping):
@@ -111,9 +108,9 @@
            ] = None,
            float_dtype: str = "float32",
            int_dtype: str = "long",
            max_cache_size: Union[float, int, str] = 0.0,
            max_cache_fd: int = 0,
            dest_sample_rate: int = 16000,
            speed_perturb: Union[list, tuple] = None,
            mode: str = "train",
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
@@ -126,8 +123,11 @@
        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.max_cache_fd = max_cache_fd
        self.dest_sample_rate = dest_sample_rate
        self.speed_perturb = speed_perturb
        self.mode = mode
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))
        self.loader_dict = {}
        self.debug_info = {}
@@ -141,17 +141,9 @@
            if len(self.loader_dict[name]) == 0:
                raise RuntimeError(f"{path} has no samples")
        if isinstance(max_cache_size, str):
            max_cache_size = humanfriendly.parse_size(max_cache_size)
        self.max_cache_size = max_cache_size
        if max_cache_size > 0:
            self.cache = SizedDict(shared=True)
        else:
            self.cache = None
    def _build_loader(
            self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, List[int], numbers.Number]]:
        """Helper function to instantiate Loader.
        Args:
@@ -159,13 +151,15 @@
            loader_type:  loader_type. sound, npy, text, etc
        """
        if loader_type == "sound":
            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False)
            speed_perturb = self.speed_perturb if self.mode == "train" else None
            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
                                    speed_perturb=speed_perturb)
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "kaldi_ark":
            loader = kaldiio.load_scp(path, max_cache_fd=self.max_cache_fd)
            loader = kaldiio.load_scp(path)
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "npy":
            return NpyScpReader()
            return NpyScpReader(path)
        elif loader_type == "text":
            text_loader = {}
            with open(path, "r", encoding="utf-8") as f:
@@ -179,6 +173,19 @@
                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
                    text_loader[k] = v
            return text_loader
        elif loader_type == "text_int":
            text_int_loader = {}
            with open(path, "r", encoding="utf-8") as f:
                for linenum, line in enumerate(f, 1):
                    sps = line.rstrip().split(maxsplit=1)
                    if len(sps) == 1:
                        k, v = sps[0], ""
                    else:
                        k, v = sps
                    if k in text_int_loader:
                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
                    text_int_loader[k] = [int(i) for i in v.split()]
            return text_int_loader
        else:
            raise RuntimeError(f"Not supported: loader_type={loader_type}")
@@ -206,10 +213,6 @@
        if isinstance(uid, int):
            d = next(iter(self.loader_dict.values()))
            uid = list(d)[uid]
        if self.cache is not None and uid in self.cache:
            data = self.cache[uid]
            return uid, data
        data = {}
        # 1. Load data from each loaders
@@ -260,9 +263,6 @@
            else:
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            data[name] = value
        if self.cache is not None and self.cache.size < self.max_cache_size:
            self.cache[uid] = data
        retval = uid, data
        assert check_return_type(retval)