From fcc9c89eaba9a4e36c54958aeedeec7ab3756cd7 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 13 Feb 2023 17:43:31 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/datasets/iterable_dataset.py | 326 ++++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 232 insertions(+), 94 deletions(-)
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 319dd7f..fa0adeb 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -11,13 +11,16 @@
import kaldiio
import numpy as np
-import soundfile
import torch
+import torchaudio
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
+import os.path
from funasr.datasets.dataset import ESPnetDataset
+
+SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
def load_kaldi(input):
retval = kaldiio.load_mat(input)
@@ -42,9 +45,32 @@
return array
+def load_bytes(input):
+ middle_data = np.frombuffer(input, dtype=np.int16)
+ middle_data = np.asarray(middle_data)
+ if middle_data.dtype.kind not in 'iu':
+ raise TypeError("'middle_data' must be an array of integers")
+ dtype = np.dtype('float32')
+ if dtype.kind != 'f':
+ raise TypeError("'dtype' must be a floating point type")
+
+ i = np.iinfo(middle_data.dtype)
+ abs_max = 2 ** (i.bits - 1)
+ offset = i.min + abs_max
+ array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+ return array
+
+def load_pcm(input):
+ with open(input,"rb") as f:
+ bytes = f.read()
+ return load_bytes(bytes)
+
DATA_TYPES = {
- "sound": lambda x: soundfile.read(x)[0],
+ "sound": lambda x: torchaudio.load(x)[0][0].numpy(),
+ "pcm": load_pcm,
"kaldi_ark": load_kaldi,
+ "bytes": load_bytes,
+ "waveform": lambda x: x,
"npy": np.load,
"text_int": lambda x: np.loadtxt(
StringIO(x), ndmin=1, dtype=np.long, delimiter=" "
@@ -73,14 +99,15 @@
"""
def __init__(
- self,
- path_name_type_list: Collection[Tuple[str, str, str]],
- preprocess: Callable[
- [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
- ] = None,
- float_dtype: str = "float32",
- int_dtype: str = "long",
- key_file: str = None,
+ self,
+ path_name_type_list: Collection[Tuple[any, str, str]],
+ preprocess: Callable[
+ [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
+ ] = None,
+ float_dtype: str = "float32",
+ fs: dict = None,
+ int_dtype: str = "long",
+ key_file: str = None,
):
assert check_argument_types()
if len(path_name_type_list) == 0:
@@ -94,19 +121,28 @@
self.float_dtype = float_dtype
self.int_dtype = int_dtype
self.key_file = key_file
+ self.fs = fs
self.debug_info = {}
non_iterable_list = []
self.path_name_type_list = []
- for path, name, _type in path_name_type_list:
- if name in self.debug_info:
- raise RuntimeError(f'"{name}" is duplicated for data-key')
+ if not isinstance(path_name_type_list[0], Tuple):
+ path = path_name_type_list[0]
+ name = path_name_type_list[1]
+ _type = path_name_type_list[2]
self.debug_info[name] = path, _type
if _type not in DATA_TYPES:
non_iterable_list.append((path, name, _type))
else:
self.path_name_type_list.append((path, name, _type))
+ else:
+ for path, name, _type in path_name_type_list:
+ self.debug_info[name] = path, _type
+ if _type not in DATA_TYPES:
+ non_iterable_list.append((path, name, _type))
+ else:
+ self.path_name_type_list.append((path, name, _type))
if len(non_iterable_list) != 0:
# Some types doesn't support iterable mode
@@ -119,10 +155,7 @@
else:
self.non_iterable_dataset = None
- if Path(Path(path_name_type_list[0][0]).parent, "utt2category").exists():
- self.apply_utt2category = True
- else:
- self.apply_utt2category = False
+ self.apply_utt2category = False
def has_name(self, name) -> bool:
return name in self.debug_info
@@ -139,99 +172,204 @@
return _mes
def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
- if self.key_file is not None:
- uid_iter = (
- line.rstrip().split(maxsplit=1)[0]
- for line in open(self.key_file, encoding="utf-8")
- )
- elif len(self.path_name_type_list) != 0:
- uid_iter = (
- line.rstrip().split(maxsplit=1)[0]
- for line in open(self.path_name_type_list[0][0], encoding="utf-8")
- )
- else:
- uid_iter = iter(self.non_iterable_dataset)
-
- files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
-
- worker_info = torch.utils.data.get_worker_info()
-
- linenum = 0
count = 0
- for count, uid in enumerate(uid_iter, 1):
- # If num_workers>=1, split keys
- if worker_info is not None:
- if (count - 1) % worker_info.num_workers != worker_info.id:
- continue
-
- # 1. Read a line from each file
- while True:
- keys = []
- values = []
- for f in files:
- linenum += 1
- try:
- line = next(f)
- except StopIteration:
- raise RuntimeError(f"{uid} is not found in the files")
- sps = line.rstrip().split(maxsplit=1)
- if len(sps) != 2:
- raise RuntimeError(
- f"This line doesn't include a space:"
- f" {f}:L{linenum}: {line})"
- )
- key, value = sps
- keys.append(key)
- values.append(value)
-
- for k_idx, k in enumerate(keys):
- if k != keys[0]:
- raise RuntimeError(
- f"Keys are mismatched. Text files (idx={k_idx}) is "
- f"not sorted or not having same keys at L{linenum}"
- )
-
- # If the key is matched, break the loop
- if len(keys) == 0 or keys[0] == uid:
- break
-
- # 2. Load the entry from each line and create a dict
+ if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
data = {}
- # 2.a. Load data streamingly
- for value, (path, name, _type) in zip(values, self.path_name_type_list):
- func = DATA_TYPES[_type]
- # Load entry
- array = func(value)
- data[name] = array
- if self.non_iterable_dataset is not None:
- # 2.b. Load data from non-iterable dataset
- _, from_non_iterable = self.non_iterable_dataset[uid]
- data.update(from_non_iterable)
+ value = self.path_name_type_list[0][0]
+ uid = 'utt_id'
+ name = self.path_name_type_list[0][1]
+ _type = self.path_name_type_list[0][2]
+ func = DATA_TYPES[_type]
+ array = func(value)
+ if self.fs is not None and name == "speech":
+ audio_fs = self.fs["audio_fs"]
+ model_fs = self.fs["model_fs"]
+ if audio_fs is not None and model_fs is not None:
+ array = torch.from_numpy(array)
+ array = array.unsqueeze(0)
+ array = torchaudio.transforms.Resample(orig_freq=audio_fs,
+ new_freq=model_fs)(array)
+ array = array.squeeze(0).numpy()
+ data[name] = array
- # 3. [Option] Apply preprocessing
- # e.g. funasr.train.preprocessor:CommonPreprocessor
if self.preprocess is not None:
data = self.preprocess(uid, data)
-
- # 4. Force data-precision
for name in data:
+ count += 1
value = data[name]
if not isinstance(value, np.ndarray):
raise RuntimeError(
- f"All values must be converted to np.ndarray object "
- f'by preprocessing, but "{name}" is still {type(value)}.'
- )
-
+ f'All values must be converted to np.ndarray object '
+ f'by preprocessing, but "{name}" is still {type(value)}.')
# Cast to desired type
- if value.dtype.kind == "f":
+ if value.dtype.kind == 'f':
value = value.astype(self.float_dtype)
- elif value.dtype.kind == "i":
+ elif value.dtype.kind == 'i':
value = value.astype(self.int_dtype)
else:
- raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+ raise NotImplementedError(
+ f'Not supported dtype: {value.dtype}')
data[name] = value
yield uid, data
+ elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
+ data = {}
+ value = self.path_name_type_list[0][0]
+ uid = os.path.basename(self.path_name_type_list[0][0]).split(".")[0]
+ name = self.path_name_type_list[0][1]
+ _type = self.path_name_type_list[0][2]
+ if _type == "sound":
+ audio_type = os.path.basename(value).split(".")[1].lower()
+ if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
+ raise NotImplementedError(
+ f'Not supported audio type: {audio_type}')
+ if audio_type == "pcm":
+ _type = "pcm"
+
+ func = DATA_TYPES[_type]
+ array = func(value)
+ if self.fs is not None and name == "speech":
+ audio_fs = self.fs["audio_fs"]
+ model_fs = self.fs["model_fs"]
+ if audio_fs is not None and model_fs is not None:
+ array = torch.from_numpy(array)
+ array = array.unsqueeze(0)
+ array = torchaudio.transforms.Resample(orig_freq=audio_fs,
+ new_freq=model_fs)(array)
+ array = array.squeeze(0).numpy()
+ data[name] = array
+
+ if self.preprocess is not None:
+ data = self.preprocess(uid, data)
+ for name in data:
+ count += 1
+ value = data[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f'All values must be converted to np.ndarray object '
+ f'by preprocessing, but "{name}" is still {type(value)}.')
+ # Cast to desired type
+ if value.dtype.kind == 'f':
+ value = value.astype(self.float_dtype)
+ elif value.dtype.kind == 'i':
+ value = value.astype(self.int_dtype)
+ else:
+ raise NotImplementedError(
+ f'Not supported dtype: {value.dtype}')
+ data[name] = value
+
+ yield uid, data
+
+ else:
+ if self.key_file is not None:
+ uid_iter = (
+ line.rstrip().split(maxsplit=1)[0]
+ for line in open(self.key_file, encoding="utf-8")
+ )
+ elif len(self.path_name_type_list) != 0:
+ uid_iter = (
+ line.rstrip().split(maxsplit=1)[0]
+ for line in open(self.path_name_type_list[0][0], encoding="utf-8")
+ )
+ else:
+ uid_iter = iter(self.non_iterable_dataset)
+
+ files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
+
+ worker_info = torch.utils.data.get_worker_info()
+
+ linenum = 0
+ for count, uid in enumerate(uid_iter, 1):
+ # If num_workers>=1, split keys
+ if worker_info is not None:
+ if (count - 1) % worker_info.num_workers != worker_info.id:
+ continue
+
+ # 1. Read a line from each file
+ while True:
+ keys = []
+ values = []
+ for f in files:
+ linenum += 1
+ try:
+ line = next(f)
+ except StopIteration:
+ raise RuntimeError(f"{uid} is not found in the files")
+ sps = line.rstrip().split(maxsplit=1)
+ if len(sps) != 2:
+ raise RuntimeError(
+ f"This line doesn't include a space:"
+ f" {f}:L{linenum}: {line})"
+ )
+ key, value = sps
+ keys.append(key)
+ values.append(value)
+
+ for k_idx, k in enumerate(keys):
+ if k != keys[0]:
+ raise RuntimeError(
+ f"Keys are mismatched. Text files (idx={k_idx}) is "
+ f"not sorted or not having same keys at L{linenum}"
+ )
+
+ # If the key is matched, break the loop
+ if len(keys) == 0 or keys[0] == uid:
+ break
+
+ # 2. Load the entry from each line and create a dict
+ data = {}
+ # 2.a. Load data streamingly
+ for value, (path, name, _type) in zip(values, self.path_name_type_list):
+ if _type == "sound":
+ audio_type = os.path.basename(value).split(".")[1].lower()
+ if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
+ raise NotImplementedError(
+ f'Not supported audio type: {audio_type}')
+ if audio_type == "pcm":
+ _type = "pcm"
+ func = DATA_TYPES[_type]
+ # Load entry
+ array = func(value)
+ if self.fs is not None and name == "speech":
+ audio_fs = self.fs["audio_fs"]
+ model_fs = self.fs["model_fs"]
+ if audio_fs is not None and model_fs is not None:
+ array = torch.from_numpy(array)
+ array = array.unsqueeze(0)
+ array = torchaudio.transforms.Resample(orig_freq=audio_fs,
+ new_freq=model_fs)(array)
+ array = array.squeeze(0).numpy()
+ data[name] = array
+ if self.non_iterable_dataset is not None:
+ # 2.b. Load data from non-iterable dataset
+ _, from_non_iterable = self.non_iterable_dataset[uid]
+ data.update(from_non_iterable)
+
+ # 3. [Option] Apply preprocessing
+ # e.g. funasr.train.preprocessor:CommonPreprocessor
+ if self.preprocess is not None:
+ data = self.preprocess(uid, data)
+
+ # 4. Force data-precision
+ for name in data:
+ value = data[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f"All values must be converted to np.ndarray object "
+ f'by preprocessing, but "{name}" is still {type(value)}.'
+ )
+
+ # Cast to desired type
+ if value.dtype.kind == "f":
+ value = value.astype(self.float_dtype)
+ elif value.dtype.kind == "i":
+ value = value.astype(self.int_dtype)
+ else:
+ raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+ data[name] = value
+
+ yield uid, data
+
if count == 0:
raise RuntimeError("No iteration")
--
Gitblit v1.9.1