From 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 16 五月 2023 15:07:20 +0800
Subject: [PATCH] Merge pull request #518 from alibaba-damo-academy/dev_wjm2
---
funasr/datasets/small_datasets/dataset.py | 258 +++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 258 insertions(+), 0 deletions(-)
diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py
new file mode 100644
index 0000000..e14e4f1
--- /dev/null
+++ b/funasr/datasets/small_datasets/dataset.py
@@ -0,0 +1,258 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import collections
+import copy
+import logging
+import numbers
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import Mapping
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import kaldiio
+import numpy as np
+import torch
+from torch.utils.data.dataset import Dataset
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.npy_scp import NpyScpReader
+from funasr.fileio.sound_scp import SoundScpReader
+
+
class AdapterForSoundScpReader(collections.abc.Mapping):
    """Mapping adapter that normalizes audio-loader outputs to ndarrays.

    Underlying loaders (SoundScpReader, kaldiio) return either a bare
    ``np.ndarray`` or a 2-tuple pairing the samples with a sampling rate,
    in either order. This adapter always yields a plain ``np.ndarray``,
    optionally cast to ``dtype``, and verifies that every tuple-style item
    carries the same sampling rate.
    """

    def __init__(self, loader, dtype=None):
        # NOTE: the original called ``assert check_argument_types()`` here,
        # but with an untyped signature it validated nothing (and asserts
        # are stripped under ``python -O``), so it was removed.
        self.loader = loader
        self.dtype = dtype
        # Sampling rate of the first tuple-style item seen; used to detect
        # rate mismatches across the corpus.
        self.rate = None

    def keys(self):
        return self.loader.keys()

    def __len__(self):
        return len(self.loader)

    def __iter__(self):
        return iter(self.loader)

    def __getitem__(self, key: str) -> np.ndarray:
        """Return the waveform/feature array for ``key``.

        Raises:
            RuntimeError: if a tuple item is not (int, ndarray)/(ndarray, int),
                or if its sampling rate differs from previously seen items.
        """
        retval = self.loader[key]

        if isinstance(retval, tuple):
            assert len(retval) == 2, len(retval)
            if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
                # sound scp case: (rate, samples)
                rate, array = retval
            elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
                # Extended ark format case: (samples, rate)
                array, rate = retval
            else:
                raise RuntimeError(
                    f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
                )

            if self.rate is not None and self.rate != rate:
                raise RuntimeError(
                    f"Sampling rates are mismatched: {self.rate} != {rate}"
                )
            self.rate = rate
            # Multichannel wave file
            # array: (NSample, Channel) or (Nsample,)
        else:
            # Normal ark case: the loader already returned a bare ndarray.
            assert isinstance(retval, np.ndarray), type(retval)
            array = retval

        # Cast once here instead of duplicating the conversion per branch
        # (the original repeated the astype() call in both branches).
        if self.dtype is not None:
            array = array.astype(self.dtype)

        assert isinstance(array, np.ndarray), type(array)
        return array
+
+
def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
    """Build a mapping from utterance-id to waveform ndarray.

    The scp file looks like:
        utterance_id_A /some/where/a.wav
        utterance_id_B /some/where/a.flac

    NOTE(kamo): SoundScpReader doesn't support Kaldi-style pipes
    such as "cat a.wav |". The audio signal is normalized to [-1,1] range.
    """
    scp_reader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)

    # SoundScpReader.__getitem__() yields Tuple[int, ndarray]; wrap it so
    # callers receive plain ndarrays instead.
    return AdapterForSoundScpReader(scp_reader, float_dtype)
+
+
def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
    """Build a mapping over a Kaldi scp, yielding plain ndarrays per key."""
    ark_reader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
    return AdapterForSoundScpReader(ark_reader, float_dtype)
+
+
class ESPnetDataset(Dataset):
    """
    Pytorch Dataset class for FunASR, modified from ESPnet

    Maps an utterance id (str, or int index into the first loader's keys)
    to a dict of numpy arrays, with one entry per (path, name, type)
    triple given at construction.
    """

    def __init__(
        self,
        path_name_type_list: Collection[Tuple[str, str, str]],
        preprocess: Callable[
            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
        ] = None,
        float_dtype: str = "float32",
        int_dtype: str = "long",
        dest_sample_rate: int = 16000,
        speed_perturb: Union[list, tuple] = None,
        mode: str = "train",
    ):
        """Build one loader per (path, name, type) triple.

        Args:
            path_name_type_list: triples like ("wav.scp", "speech", "sound");
                ``name`` becomes the data-key, ``type`` selects the loader
                (see ``_build_loader``).
            preprocess: optional callable ``(uid, data) -> data`` applied in
                ``__getitem__`` (e.g. funasr.train.preprocessor:CommonPreprocessor).
            float_dtype: numpy dtype float arrays are cast to in ``__getitem__``.
            int_dtype: numpy dtype integer arrays are cast to in ``__getitem__``.
            dest_sample_rate: target sampling rate handed to the "sound" loader.
            speed_perturb: speed-perturbation factors; only applied when
                ``mode == "train"`` (see ``_build_loader``).
            mode: "train" enables speed perturbation; any other value disables it.

        Raises:
            ValueError: if ``path_name_type_list`` is empty.
            RuntimeError: if a data-key name is duplicated or a file has no samples.
        """
        assert check_argument_types()
        if len(path_name_type_list) == 0:
            raise ValueError(
                '1 or more elements are required for "path_name_type_list"'
            )

        # Deep-copy so later mutation of the caller's list cannot affect us.
        path_name_type_list = copy.deepcopy(path_name_type_list)
        self.preprocess = preprocess

        self.float_dtype = float_dtype
        self.int_dtype = int_dtype
        self.dest_sample_rate = dest_sample_rate
        self.speed_perturb = speed_perturb
        self.mode = mode
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))

        # loader_dict: data-key name -> Mapping[uid, value].
        # debug_info keeps (path, type) per name for error reporting.
        self.loader_dict = {}
        self.debug_info = {}
        for path, name, _type in path_name_type_list:
            if name in self.loader_dict:
                raise RuntimeError(f'"{name}" is duplicated for data-key')

            loader = self._build_loader(path, _type)
            self.loader_dict[name] = loader
            self.debug_info[name] = path, _type
            if len(self.loader_dict[name]) == 0:
                raise RuntimeError(f"{path} has no samples")

    def _build_loader(
        self, path: str, loader_type: str
    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
        """Helper function to instantiate Loader.

        Args:
            path: The file path
            loader_type: loader_type. sound, npy, text, etc

        Returns:
            A mapping from utterance id to the loaded value.

        Raises:
            RuntimeError: on a duplicated key in a text file, or an
                unsupported ``loader_type``.
        """
        if loader_type == "sound":
            # Speed perturbation is a training-time augmentation only.
            speed_perturb = self.speed_perturb if self.mode == "train" else None
            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
                                    speed_perturb=speed_perturb)
            # SoundScpReader yields (rate, ndarray); the adapter unwraps it
            # and casts to float_dtype.
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "kaldi_ark":
            loader = kaldiio.load_scp(path)
            return AdapterForSoundScpReader(loader, self.float_dtype)
        elif loader_type == "npy":
            return NpyScpReader(path)
        elif loader_type == "text":
            # Eagerly load a "<key> <value...>" file into a plain dict.
            text_loader = {}
            with open(path, "r", encoding="utf-8") as f:
                for linenum, line in enumerate(f, 1):
                    sps = line.rstrip().split(maxsplit=1)
                    if len(sps) == 1:
                        # Key with no value: store the empty string.
                        k, v = sps[0], ""
                    else:
                        k, v = sps
                    if k in text_loader:
                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
                    text_loader[k] = v
            return text_loader
        else:
            raise RuntimeError(f"Not supported: loader_type={loader_type}")

    def has_name(self, name) -> bool:
        """Return True if a loader named ``name`` was registered."""
        return name in self.loader_dict

    def names(self) -> Tuple[str, ...]:
        """Return all registered data-key names."""
        return tuple(self.loader_dict)

    def __iter__(self):
        # Iterate utterance ids, taken from the first registered loader;
        # all loaders are presumed to share the same key set — not checked here.
        return iter(next(iter(self.loader_dict.values())))

    def __repr__(self):
        # One line per data-key with its source path and loader type.
        _mes = self.__class__.__name__
        _mes += "("
        for name, (path, _type) in self.debug_info.items():
            _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
        _mes += f"\n preprocess: {self.preprocess})"
        return _mes

    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
        """Load one example.

        Args:
            uid: utterance id, or an int index into the first loader's keys.

        Returns:
            ``(uid, data)`` where ``data`` maps each data-key name to an
            ``np.ndarray`` of ``float_dtype`` or ``int_dtype``.
        """
        assert check_argument_types()

        # Change integer-id to string-id
        if isinstance(uid, int):
            d = next(iter(self.loader_dict.values()))
            # NOTE(review): list(d) materializes all keys on every int access
            # — O(n) per call; fine for small datasets.
            uid = list(d)[uid]

        data = {}
        # 1. Load data from each loaders
        for name, loader in self.loader_dict.items():
            try:
                value = loader[uid]
                if isinstance(value, (list, tuple)):
                    value = np.array(value)
                if not isinstance(
                    value, (np.ndarray, torch.Tensor, str, numbers.Number)
                ):
                    raise TypeError(
                        f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
                    )
            except Exception:
                # Attach the offending file/type/uid before re-raising.
                path, _type = self.debug_info[name]
                logging.error(
                    f"Error happened with path={path}, type={_type}, id={uid}"
                )
                raise

            # torch.Tensor is converted to ndarray
            if isinstance(value, torch.Tensor):
                value = value.numpy()
            elif isinstance(value, numbers.Number):
                value = np.array([value])
            data[name] = value

        # 2. [Option] Apply preprocessing
        # e.g. funasr.train.preprocessor:CommonPreprocessor
        if self.preprocess is not None:
            data = self.preprocess(uid, data)

        # 3. Force data-precision
        for name in data:
            value = data[name]
            if not isinstance(value, np.ndarray):
                raise RuntimeError(
                    f"All values must be converted to np.ndarray object "
                    f'by preprocessing, but "{name}" is still {type(value)}.'
                )

            # Cast to desired type
            if value.dtype.kind == "f":
                value = value.astype(self.float_dtype)
            elif value.dtype.kind == "i":
                value = value.astype(self.int_dtype)
            else:
                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
            data[name] = value

        retval = uid, data
        assert check_return_type(retval)
        return retval
--
Gitblit v1.9.1