hnluo
2023-06-05 594b79f59e7eefa6955c729f6264c8c99d1d9571
funasr/datasets/small_datasets/dataset.py
@@ -9,8 +9,7 @@
from typing import Collection
from typing import Dict
from typing import Mapping
from typing import Tuple
from typing import Union
from typing import Union, List, Tuple
import kaldiio
import numpy as np
@@ -110,6 +109,8 @@
            float_dtype: str = "float32",
            int_dtype: str = "long",
            dest_sample_rate: int = 16000,
            speed_perturb: Union[list, tuple] = None,
            mode: str = "train",
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
@@ -123,6 +124,10 @@
        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.dest_sample_rate = dest_sample_rate
        self.speed_perturb = speed_perturb
        self.mode = mode
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))
        self.loader_dict = {}
        self.debug_info = {}
@@ -138,7 +143,7 @@
    def _build_loader(
            self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, List[int], numbers.Number]]:
        """Helper function to instantiate Loader.
        Args:
@@ -146,13 +151,15 @@
            loader_type:  loader_type. sound, npy, text, etc
        """
        if loader_type == "sound":
            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False)
            speed_perturb = self.speed_perturb if self.mode == "train" else None
            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
                                    speed_perturb=speed_perturb)
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "kaldi_ark":
            loader = kaldiio.load_scp(path)
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "npy":
            return NpyScpReader()
            return NpyScpReader(path)
        elif loader_type == "text":
            text_loader = {}
            with open(path, "r", encoding="utf-8") as f:
@@ -166,6 +173,19 @@
                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
                    text_loader[k] = v
            return text_loader
        elif loader_type == "text_int":
            text_int_loader = {}
            with open(path, "r", encoding="utf-8") as f:
                for linenum, line in enumerate(f, 1):
                    sps = line.rstrip().split(maxsplit=1)
                    if len(sps) == 1:
                        k, v = sps[0], ""
                    else:
                        k, v = sps
                    if k in text_int_loader:
                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
                    text_int_loader[k] = [int(i) for i in v.split()]
            return text_int_loader
        else:
            raise RuntimeError(f"Not supported: loader_type={loader_type}")