雾聪
2024-01-08 2acef4bdaea588adee3098a057a395937dff4e6a
funasr/datasets/dataset.py
@@ -16,15 +16,15 @@
from typing import Mapping
from typing import Tuple
from typing import Union
import h5py
try:
    import h5py
except:
    print("If you want use h5py dataset, please pip install h5py, and try it again")
import humanfriendly
import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.npy_scp import NpyScpReader
from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
@@ -37,7 +37,6 @@
class AdapterForSoundScpReader(collections.abc.Mapping):
    def __init__(self, loader, dtype=None):
        assert check_argument_types()
        self.loader = loader
        self.dtype = dtype
        self.rate = None
@@ -107,7 +106,7 @@
        return value[()]
def sound_loader(path, float_dtype=None):
def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
    # The file is as follows:
    #   utterance_id_A /some/where/a.wav
    #   utterance_id_B /some/where/a.flac
@@ -115,7 +114,7 @@
    # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
    # like Kaldi e.g. "cat a.wav |".
    # NOTE(kamo): The audio signal is normalized to [-1,1] range.
    loader = SoundScpReader(path, normalize=True, always_2d=False)
    loader = SoundScpReader(path, normalize=True, always_2d=False, dest_sample_rate = dest_sample_rate)
    # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
    # but ndarray is desired, so Adapter class is inserted here
@@ -139,7 +138,7 @@
DATA_TYPES = {
    "sound": dict(
        func=sound_loader,
        kwargs=["float_dtype"],
        kwargs=["dest_sample_rate","float_dtype"],
        help="Audio format types which supported by sndfile wav, flac, etc."
        "\n\n"
        "   utterance_id_a a.wav\n"
@@ -282,8 +281,8 @@
        int_dtype: str = "long",
        max_cache_size: Union[float, int, str] = 0.0,
        max_cache_fd: int = 0,
        dest_sample_rate: int = 16000,
    ):
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
@@ -295,6 +294,7 @@
        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.max_cache_fd = max_cache_fd
        self.dest_sample_rate = dest_sample_rate
        self.loader_dict = {}
        self.debug_info = {}
@@ -335,6 +335,8 @@
                for key2 in dic["kwargs"]:
                    if key2 == "loader_type":
                        kwargs["loader_type"] = loader_type
                    elif key2 == "dest_sample_rate" and loader_type=="sound":
                        kwargs["dest_sample_rate"] = self.dest_sample_rate
                    elif key2 == "float_dtype":
                        kwargs["float_dtype"] = self.float_dtype
                    elif key2 == "int_dtype":
@@ -375,7 +377,6 @@
        return _mes
    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
        assert check_argument_types()
        # Change integer-id to string-id
        if isinstance(uid, int):
@@ -440,5 +441,4 @@
            self.cache[uid] = data
        retval = uid, data
        assert check_return_type(retval)
        return retval