嘉渊
2023-06-19 289cb1d2c8d2fc5a54e9b0fb07b2c33800408d42
funasr/datasets/small_datasets/dataset.py
@@ -9,8 +9,7 @@
from typing import Collection
from typing import Dict
from typing import Mapping
from typing import Tuple
from typing import Union
from typing import Union, List, Tuple
import kaldiio
import numpy as np
@@ -110,7 +109,7 @@
            float_dtype: str = "float32",
            int_dtype: str = "long",
            dest_sample_rate: int = 16000,
            speed_perturb: tuple = None,
            speed_perturb: Union[list, tuple] = None,
            mode: str = "train",
    ):
        assert check_argument_types()
@@ -127,6 +126,8 @@
        self.dest_sample_rate = dest_sample_rate
        self.speed_perturb = speed_perturb
        self.mode = mode
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))
        self.loader_dict = {}
        self.debug_info = {}
@@ -142,7 +143,7 @@
    def _build_loader(
            self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, List[int], numbers.Number]]:
        """Helper function to instantiate Loader.
        Args:
@@ -151,7 +152,8 @@
        """
        if loader_type == "sound":
            speed_perturb = self.speed_perturb if self.mode == "train" else None
            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False, speed_perturb=speed_perturb)
            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
                                    speed_perturb=speed_perturb)
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "kaldi_ark":
            loader = kaldiio.load_scp(path)
@@ -171,6 +173,19 @@
                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
                    text_loader[k] = v
            return text_loader
        elif loader_type == "text_int":
            text_int_loader = {}
            with open(path, "r", encoding="utf-8") as f:
                for linenum, line in enumerate(f, 1):
                    sps = line.rstrip().split(maxsplit=1)
                    if len(sps) == 1:
                        k, v = sps[0], ""
                    else:
                        k, v = sps
                    if k in text_int_loader:
                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
                    text_int_loader[k] = [int(i) for i in v.split()]
            return text_int_loader
        else:
            raise RuntimeError(f"Not supported: loader_type={loader_type}")