From 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 10 十月 2023 17:17:29 +0800
Subject: [PATCH] v0.8.0
---
funasr/datasets/small_datasets/dataset.py | 289 +++++++++++----------------------------------------------
1 files changed, 55 insertions(+), 234 deletions(-)
diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py
index 7ed37fa..bee9f50 100644
--- a/funasr/datasets/small_datasets/dataset.py
+++ b/funasr/datasets/small_datasets/dataset.py
@@ -1,43 +1,27 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-from abc import ABC
-from abc import abstractmethod
import collections
import copy
-import functools
import logging
import numbers
-import re
-from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Mapping
-from typing import Tuple
-from typing import Union
+from typing import Union, List, Tuple
-import h5py
-import humanfriendly
import kaldiio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
-from typeguard import check_argument_types
-from typeguard import check_return_type
from funasr.fileio.npy_scp import NpyScpReader
-from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
-from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset
-from funasr.fileio.read_text import load_num_sequence_text
-from funasr.fileio.read_text import read_2column_text
from funasr.fileio.sound_scp import SoundScpReader
-from funasr.utils.sized_dict import SizedDict
class AdapterForSoundScpReader(collections.abc.Mapping):
def __init__(self, loader, dtype=None):
- assert check_argument_types()
self.loader = loader
self.dtype = dtype
self.rate = None
@@ -88,25 +72,6 @@
return array
-class H5FileWrapper:
- def __init__(self, path: str):
- self.path = path
- self.h5_file = h5py.File(path, "r")
-
- def __repr__(self) -> str:
- return str(self.h5_file)
-
- def __len__(self) -> int:
- return len(self.h5_file)
-
- def __iter__(self):
- return iter(self.h5_file)
-
- def __getitem__(self, key) -> np.ndarray:
- value = self.h5_file[key]
- return value[()]
-
-
def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
# The file is as follows:
# utterance_id_A /some/where/a.wav
@@ -127,158 +92,23 @@
return AdapterForSoundScpReader(loader, float_dtype)
-def rand_int_loader(filepath, loader_type):
- # e.g. rand_int_3_10
- try:
- low, high = map(int, loader_type[len("rand_int_") :].split("_"))
- except ValueError:
- raise RuntimeError(f"e.g rand_int_3_10: but got {loader_type}")
- return IntRandomGenerateDataset(filepath, low, high)
-
-
-DATA_TYPES = {
- "sound": dict(
- func=sound_loader,
- kwargs=["dest_sample_rate","float_dtype"],
- help="Audio format types which supported by sndfile wav, flac, etc."
- "\n\n"
- " utterance_id_a a.wav\n"
- " utterance_id_b b.wav\n"
- " ...",
- ),
- "kaldi_ark": dict(
- func=kaldi_loader,
- kwargs=["max_cache_fd"],
- help="Kaldi-ark file type."
- "\n\n"
- " utterance_id_A /some/where/a.ark:123\n"
- " utterance_id_B /some/where/a.ark:456\n"
- " ...",
- ),
- "npy": dict(
- func=NpyScpReader,
- kwargs=[],
- help="Npy file format."
- "\n\n"
- " utterance_id_A /some/where/a.npy\n"
- " utterance_id_B /some/where/b.npy\n"
- " ...",
- ),
- "text_int": dict(
- func=functools.partial(load_num_sequence_text, loader_type="text_int"),
- kwargs=[],
- help="A text file in which is written a sequence of interger numbers "
- "separated by space."
- "\n\n"
- " utterance_id_A 12 0 1 3\n"
- " utterance_id_B 3 3 1\n"
- " ...",
- ),
- "csv_int": dict(
- func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
- kwargs=[],
- help="A text file in which is written a sequence of interger numbers "
- "separated by comma."
- "\n\n"
- " utterance_id_A 100,80\n"
- " utterance_id_B 143,80\n"
- " ...",
- ),
- "text_float": dict(
- func=functools.partial(load_num_sequence_text, loader_type="text_float"),
- kwargs=[],
- help="A text file in which is written a sequence of float numbers "
- "separated by space."
- "\n\n"
- " utterance_id_A 12. 3.1 3.4 4.4\n"
- " utterance_id_B 3. 3.12 1.1\n"
- " ...",
- ),
- "csv_float": dict(
- func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
- kwargs=[],
- help="A text file in which is written a sequence of float numbers "
- "separated by comma."
- "\n\n"
- " utterance_id_A 12.,3.1,3.4,4.4\n"
- " utterance_id_B 3.,3.12,1.1\n"
- " ...",
- ),
- "text": dict(
- func=read_2column_text,
- kwargs=[],
- help="Return text as is. The text must be converted to ndarray "
- "by 'preprocess'."
- "\n\n"
- " utterance_id_A hello world\n"
- " utterance_id_B foo bar\n"
- " ...",
- ),
- "hdf5": dict(
- func=H5FileWrapper,
- kwargs=[],
- help="A HDF5 file which contains arrays at the first level or the second level."
- " >>> f = h5py.File('file.h5')\n"
- " >>> array1 = f['utterance_id_A']\n"
- " >>> array2 = f['utterance_id_B']\n",
- ),
- "rand_float": dict(
- func=FloatRandomGenerateDataset,
- kwargs=[],
- help="Generate random float-ndarray which has the given shapes "
- "in the file."
- "\n\n"
- " utterance_id_A 3,4\n"
- " utterance_id_B 10,4\n"
- " ...",
- ),
- "rand_int_\\d+_\\d+": dict(
- func=rand_int_loader,
- kwargs=["loader_type"],
- help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
- "shapes in the path. "
- "Give the lower and upper value by the file type. e.g. "
- "rand_int_0_10 -> Generate integers from 0 to 10."
- "\n\n"
- " utterance_id_A 3,4\n"
- " utterance_id_B 10,4\n"
- " ...",
- ),
-}
-
-
-class AbsDataset(Dataset, ABC):
- @abstractmethod
- def has_name(self, name) -> bool:
- raise NotImplementedError
-
- @abstractmethod
- def names(self) -> Tuple[str, ...]:
- raise NotImplementedError
-
- @abstractmethod
- def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
- raise NotImplementedError
-
-
-class ESPnetDataset(AbsDataset):
+class ESPnetDataset(Dataset):
"""
- Pytorch Dataset class for FunASR, simplied from ESPnet
+ PyTorch Dataset class for FunASR, modified from ESPnet
"""
def __init__(
- self,
- path_name_type_list: Collection[Tuple[str, str, str]],
- preprocess: Callable[
- [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
- ] = None,
- float_dtype: str = "float32",
- int_dtype: str = "long",
- max_cache_size: Union[float, int, str] = 0.0,
- max_cache_fd: int = 0,
- dest_sample_rate: int = 16000,
+ self,
+ path_name_type_list: Collection[Tuple[str, str, str]],
+ preprocess: Callable[
+ [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
+ ] = None,
+ float_dtype: str = "float32",
+ int_dtype: str = "long",
+ dest_sample_rate: int = 16000,
+ speed_perturb: Union[list, tuple] = None,
+ mode: str = "train",
):
- assert check_argument_types()
if len(path_name_type_list) == 0:
raise ValueError(
'1 or more elements are required for "path_name_type_list"'
@@ -289,8 +119,11 @@
self.float_dtype = float_dtype
self.int_dtype = int_dtype
- self.max_cache_fd = max_cache_fd
self.dest_sample_rate = dest_sample_rate
+ self.speed_perturb = speed_perturb
+ self.mode = mode
+ if self.speed_perturb is not None:
+ logging.info("Using speed_perturb: {}".format(speed_perturb))
self.loader_dict = {}
self.debug_info = {}
@@ -304,54 +137,51 @@
if len(self.loader_dict[name]) == 0:
raise RuntimeError(f"{path} has no samples")
- # TODO(kamo): Should check consistency of each utt-keys?
-
- if isinstance(max_cache_size, str):
- max_cache_size = humanfriendly.parse_size(max_cache_size)
- self.max_cache_size = max_cache_size
- if max_cache_size > 0:
- self.cache = SizedDict(shared=True)
- else:
- self.cache = None
-
def _build_loader(
- self, path: str, loader_type: str
- ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
+ self, path: str, loader_type: str
+ ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, List[int], numbers.Number]]:
"""Helper function to instantiate Loader.
Args:
path: The file path
- loader_type: loader_type. sound, npy, text_int, text_float, etc
+ loader_type: loader_type. One of: sound, kaldi_ark, npy, text, text_int
"""
- for key, dic in DATA_TYPES.items():
- # e.g. loader_type="sound"
- # -> return DATA_TYPES["sound"]["func"](path)
- if re.match(key, loader_type):
- kwargs = {}
- for key2 in dic["kwargs"]:
- if key2 == "loader_type":
- kwargs["loader_type"] = loader_type
- elif key2 == "dest_sample_rate" and loader_type=="sound":
- kwargs["dest_sample_rate"] = self.dest_sample_rate
- elif key2 == "float_dtype":
- kwargs["float_dtype"] = self.float_dtype
- elif key2 == "int_dtype":
- kwargs["int_dtype"] = self.int_dtype
- elif key2 == "max_cache_fd":
- kwargs["max_cache_fd"] = self.max_cache_fd
+ if loader_type == "sound":
+ speed_perturb = self.speed_perturb if self.mode == "train" else None
+ loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
+ speed_perturb=speed_perturb)
+ return AdapterForSoundScpReader(loader, self.float_dtype)
+ elif loader_type == "kaldi_ark":
+ loader = kaldiio.load_scp(path)
+ return AdapterForSoundScpReader(loader, self.float_dtype)
+ elif loader_type == "npy":
+ return NpyScpReader(path)
+ elif loader_type == "text":
+ text_loader = {}
+ with open(path, "r", encoding="utf-8") as f:
+ for linenum, line in enumerate(f, 1):
+ sps = line.rstrip().split(maxsplit=1)
+ if len(sps) == 1:
+ k, v = sps[0], ""
else:
- raise RuntimeError(f"Not implemented keyword argument: {key2}")
-
- func = dic["func"]
- try:
- return func(path, **kwargs)
- except Exception:
- if hasattr(func, "__name__"):
- name = func.__name__
+ k, v = sps
+ if k in text_loader:
+ raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+ text_loader[k] = v
+ return text_loader
+ elif loader_type == "text_int":
+ text_int_loader = {}
+ with open(path, "r", encoding="utf-8") as f:
+ for linenum, line in enumerate(f, 1):
+ sps = line.rstrip().split(maxsplit=1)
+ if len(sps) == 1:
+ k, v = sps[0], ""
else:
- name = str(func)
- logging.error(f"An error happened with {name}({path})")
- raise
+ k, v = sps
+ if k in text_int_loader:
+ raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+ text_int_loader[k] = [int(i) for i in v.split()]
+ return text_int_loader
else:
raise RuntimeError(f"Not supported: loader_type={loader_type}")
@@ -373,16 +203,11 @@
return _mes
def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
- assert check_argument_types()
# Change integer-id to string-id
if isinstance(uid, int):
d = next(iter(self.loader_dict.values()))
uid = list(d)[uid]
-
- if self.cache is not None and uid in self.cache:
- data = self.cache[uid]
- return uid, data
data = {}
# 1. Load data from each loaders
@@ -392,7 +217,7 @@
if isinstance(value, (list, tuple)):
value = np.array(value)
if not isinstance(
- value, (np.ndarray, torch.Tensor, str, numbers.Number)
+ value, (np.ndarray, torch.Tensor, str, numbers.Number)
):
raise TypeError(
f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
@@ -434,9 +259,5 @@
raise NotImplementedError(f"Not supported dtype: {value.dtype}")
data[name] = value
- if self.cache is not None and self.cache.size < self.max_cache_size:
- self.cache[uid] = data
-
retval = uid, data
- assert check_return_type(retval)
return retval
--
Gitblit v1.9.1