From 580b11b57ac4b62f7e2acda73813a4e10e8e4cd3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 10 十月 2023 17:17:29 +0800
Subject: [PATCH] v0.8.0

---
 funasr/datasets/iterable_dataset.py |  367 +++++++++++++++++++++++++++++++++++++--------------
 1 files changed, 263 insertions(+), 104 deletions(-)

diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 319dd7f..6398e0c 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -8,16 +8,20 @@
 from typing import Iterator
 from typing import Tuple
 from typing import Union
+from typing import List
 
 import kaldiio
 import numpy as np
-import soundfile
 import torch
+import torchaudio
+import soundfile
 from torch.utils.data.dataset import IterableDataset
-from typeguard import check_argument_types
+import os.path
 
 from funasr.datasets.dataset import ESPnetDataset
 
+
+SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
 
 def load_kaldi(input):
     retval = kaldiio.load_mat(input)
@@ -42,9 +46,41 @@
     return array
 
 
+def load_bytes(input):
+    middle_data = np.frombuffer(input, dtype=np.int16)
+    middle_data = np.asarray(middle_data)
+    if middle_data.dtype.kind not in 'iu':
+        raise TypeError("'middle_data' must be an array of integers")
+    dtype = np.dtype('float32')
+    if dtype.kind != 'f':
+        raise TypeError("'dtype' must be a floating point type")
+
+    i = np.iinfo(middle_data.dtype)
+    abs_max = 2 ** (i.bits - 1)
+    offset = i.min + abs_max
+    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+    return array
+
+def load_pcm(input):
+    with open(input,"rb") as f:
+        bytes = f.read()
+    return load_bytes(bytes)
+
+def load_wav(input):
+    try:
+        return torchaudio.load(input)[0].numpy()
+    except:
+        waveform, _ = soundfile.read(input, dtype='float32')
+        if waveform.ndim == 2:
+            waveform = waveform[:, 0]
+        return np.expand_dims(waveform, axis=0)
+
 DATA_TYPES = {
-    "sound": lambda x: soundfile.read(x)[0],
+    "sound": load_wav,
+    "pcm": load_pcm,
     "kaldi_ark": load_kaldi,
+    "bytes": load_bytes,
+    "waveform": lambda x: x,
     "npy": np.load,
     "text_int": lambda x: np.loadtxt(
         StringIO(x), ndmin=1, dtype=np.long, delimiter=" "
@@ -73,16 +109,17 @@
     """
 
     def __init__(
-        self,
-        path_name_type_list: Collection[Tuple[str, str, str]],
-        preprocess: Callable[
-            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
-        ] = None,
-        float_dtype: str = "float32",
-        int_dtype: str = "long",
-        key_file: str = None,
+            self,
+            path_name_type_list: Collection[Tuple[any, str, str]],
+            preprocess: Callable[
+                [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
+            ] = None,
+            float_dtype: str = "float32",
+            fs: dict = None,
+            mc: bool = False,
+            int_dtype: str = "long",
+            key_file: str = None,
     ):
-        assert check_argument_types()
         if len(path_name_type_list) == 0:
             raise ValueError(
                 '1 or more elements are required for "path_name_type_list"'
@@ -94,19 +131,29 @@
         self.float_dtype = float_dtype
         self.int_dtype = int_dtype
         self.key_file = key_file
+        self.fs = fs
+        self.mc = mc
 
         self.debug_info = {}
         non_iterable_list = []
         self.path_name_type_list = []
 
-        for path, name, _type in path_name_type_list:
-            if name in self.debug_info:
-                raise RuntimeError(f'"{name}" is duplicated for data-key')
+        if not isinstance(path_name_type_list[0], (Tuple, List)):
+            path = path_name_type_list[0]
+            name = path_name_type_list[1]
+            _type = path_name_type_list[2]
             self.debug_info[name] = path, _type
             if _type not in DATA_TYPES:
                 non_iterable_list.append((path, name, _type))
             else:
                 self.path_name_type_list.append((path, name, _type))
+        else:
+            for path, name, _type in path_name_type_list:
+                self.debug_info[name] = path, _type
+                if _type not in DATA_TYPES:
+                    non_iterable_list.append((path, name, _type))
+                else:
+                    self.path_name_type_list.append((path, name, _type))
 
         if len(non_iterable_list) != 0:
             # Some types doesn't support iterable mode
@@ -119,10 +166,7 @@
         else:
             self.non_iterable_dataset = None
 
-        if Path(Path(path_name_type_list[0][0]).parent, "utt2category").exists():
-            self.apply_utt2category = True
-        else:
-            self.apply_utt2category = False
+        self.apply_utt2category = False
 
     def has_name(self, name) -> bool:
         return name in self.debug_info
@@ -139,99 +183,214 @@
         return _mes
 
     def __iter__(self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
-        if self.key_file is not None:
-            uid_iter = (
-                line.rstrip().split(maxsplit=1)[0]
-                for line in open(self.key_file, encoding="utf-8")
-            )
-        elif len(self.path_name_type_list) != 0:
-            uid_iter = (
-                line.rstrip().split(maxsplit=1)[0]
-                for line in open(self.path_name_type_list[0][0], encoding="utf-8")
-            )
-        else:
-            uid_iter = iter(self.non_iterable_dataset)
-
-        files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
-
-        worker_info = torch.utils.data.get_worker_info()
-
-        linenum = 0
         count = 0
-        for count, uid in enumerate(uid_iter, 1):
-            # If num_workers>=1, split keys
-            if worker_info is not None:
-                if (count - 1) % worker_info.num_workers != worker_info.id:
-                    continue
-
-            # 1. Read a line from each file
-            while True:
-                keys = []
-                values = []
-                for f in files:
-                    linenum += 1
-                    try:
-                        line = next(f)
-                    except StopIteration:
-                        raise RuntimeError(f"{uid} is not found in the files")
-                    sps = line.rstrip().split(maxsplit=1)
-                    if len(sps) != 2:
-                        raise RuntimeError(
-                            f"This line doesn't include a space:"
-                            f" {f}:L{linenum}: {line})"
-                        )
-                    key, value = sps
-                    keys.append(key)
-                    values.append(value)
-
-                for k_idx, k in enumerate(keys):
-                    if k != keys[0]:
-                        raise RuntimeError(
-                            f"Keys are mismatched. Text files (idx={k_idx}) is "
-                            f"not sorted or not having same keys at L{linenum}"
-                        )
-
-                # If the key is matched, break the loop
-                if len(keys) == 0 or keys[0] == uid:
-                    break
-
-            # 2. Load the entry from each line and create a dict
+        if len(self.path_name_type_list) != 0 and (self.path_name_type_list[0][2] == "bytes" or self.path_name_type_list[0][2] == "waveform"):
+            linenum = len(self.path_name_type_list)
             data = {}
-            # 2.a. Load data streamingly
-            for value, (path, name, _type) in zip(values, self.path_name_type_list):
+            for i in range(linenum):
+                value = self.path_name_type_list[i][0]
+                uid = 'utt_id'
+                name = self.path_name_type_list[i][1]
+                _type = self.path_name_type_list[i][2]
                 func = DATA_TYPES[_type]
-                # Load entry
                 array = func(value)
+                if self.fs is not None and (name == "speech" or name == "ref_speech"):
+                    audio_fs = self.fs["audio_fs"]
+                    model_fs = self.fs["model_fs"]
+                    if audio_fs is not None and model_fs is not None:
+                        array = torch.from_numpy(array)
+                        array = array.unsqueeze(0)
+                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
+                                                       new_freq=model_fs)(array)
+                        array = array.squeeze(0).numpy()
+
                 data[name] = array
-            if self.non_iterable_dataset is not None:
-                # 2.b. Load data from non-iterable dataset
-                _, from_non_iterable = self.non_iterable_dataset[uid]
-                data.update(from_non_iterable)
 
-            # 3. [Option] Apply preprocessing
-            #   e.g. funasr.train.preprocessor:CommonPreprocessor
-            if self.preprocess is not None:
-                data = self.preprocess(uid, data)
-
-            # 4. Force data-precision
-            for name in data:
-                value = data[name]
-                if not isinstance(value, np.ndarray):
-                    raise RuntimeError(
-                        f"All values must be converted to np.ndarray object "
-                        f'by preprocessing, but "{name}" is still {type(value)}.'
-                    )
-
-                # Cast to desired type
-                if value.dtype.kind == "f":
-                    value = value.astype(self.float_dtype)
-                elif value.dtype.kind == "i":
-                    value = value.astype(self.int_dtype)
-                else:
-                    raise NotImplementedError(f"Not supported dtype: {value.dtype}")
-                data[name] = value
+                if self.preprocess is not None:
+                    data = self.preprocess(uid, data)
+                for name in data:
+                    count += 1
+                    value = data[name]
+                    if not isinstance(value, np.ndarray):
+                        raise RuntimeError(
+                            f'All values must be converted to np.ndarray object '
+                            f'by preprocessing, but "{name}" is still {type(value)}.')
+                    # Cast to desired type
+                    if value.dtype.kind == 'f':
+                        value = value.astype(self.float_dtype)
+                    elif value.dtype.kind == 'i':
+                        value = value.astype(self.int_dtype)
+                    else:
+                        raise NotImplementedError(
+                            f'Not supported dtype: {value.dtype}')
+                    data[name] = value
 
             yield uid, data
 
+        elif len(self.path_name_type_list) != 0 and self.path_name_type_list[0][2] == "sound" and not self.path_name_type_list[0][0].lower().endswith(".scp"):
+            linenum = len(self.path_name_type_list)
+            data = {}
+            for i in range(linenum):
+                value = self.path_name_type_list[i][0]
+                uid = os.path.basename(self.path_name_type_list[i][0]).split(".")[0]
+                name = self.path_name_type_list[i][1]
+                _type = self.path_name_type_list[i][2]
+                if _type == "sound":
+                   audio_type = os.path.basename(value).lower()
+                   if audio_type.rfind(".pcm") >= 0:
+                       _type = "pcm"
+                func = DATA_TYPES[_type]
+                array = func(value)
+                if self.fs is not None and (name == "speech" or name == "ref_speech"):
+                    audio_fs = self.fs["audio_fs"]
+                    model_fs = self.fs["model_fs"]
+                    if audio_fs is not None and model_fs is not None:
+                        array = torch.from_numpy(array)
+                        array = torchaudio.transforms.Resample(orig_freq=audio_fs,
+                                                               new_freq=model_fs)(array)
+                        array = array.numpy()
+                        
+                if _type == "sound":
+                    if self.mc:
+                        data[name] = array.transpose((1, 0))
+                    else:
+                        data[name] = array[0]
+                else:
+                    data[name] = array
+
+                if self.preprocess is not None:
+                    data = self.preprocess(uid, data)
+                for name in data:
+                    count += 1
+                    value = data[name]
+                    if not isinstance(value, np.ndarray):
+                        raise RuntimeError(
+                            f'All values must be converted to np.ndarray object '
+                            f'by preprocessing, but "{name}" is still {type(value)}.')
+                    # Cast to desired type
+                    if value.dtype.kind == 'f':
+                        value = value.astype(self.float_dtype)
+                    elif value.dtype.kind == 'i':
+                        value = value.astype(self.int_dtype)
+                    else:
+                        raise NotImplementedError(
+                            f'Not supported dtype: {value.dtype}')
+                    data[name] = value
+
+            yield uid, data
+
+        else:
+            if self.key_file is not None:
+                uid_iter = (
+                    line.rstrip().split(maxsplit=1)[0]
+                    for line in open(self.key_file, encoding="utf-8")
+                )
+            elif len(self.path_name_type_list) != 0:
+                uid_iter = (
+                    line.rstrip().split(maxsplit=1)[0]
+                    for line in open(self.path_name_type_list[0][0], encoding="utf-8")
+                )
+            else:
+                uid_iter = iter(self.non_iterable_dataset)
+
+            files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list]
+
+            worker_info = torch.utils.data.get_worker_info()
+
+            linenum = 0
+            for count, uid in enumerate(uid_iter, 1):
+                # If num_workers>=1, split keys
+                if worker_info is not None:
+                    if (count - 1) % worker_info.num_workers != worker_info.id:
+                        continue
+
+                # 1. Read a line from each file
+                while True:
+                    keys = []
+                    values = []
+                    for f in files:
+                        linenum += 1
+                        try:
+                            line = next(f)
+                        except StopIteration:
+                            raise RuntimeError(f"{uid} is not found in the files")
+                        sps = line.rstrip().split(maxsplit=1)
+                        if len(sps) != 2:
+                            raise RuntimeError(
+                                f"This line doesn't include a space:"
+                                f" {f}:L{linenum}: {line})"
+                            )
+                        key, value = sps
+                        keys.append(key)
+                        values.append(value)
+
+                    for k_idx, k in enumerate(keys):
+                        if k != keys[0]:
+                            raise RuntimeError(
+                                f"Keys are mismatched. Text files (idx={k_idx}) is "
+                                f"not sorted or not having same keys at L{linenum}"
+                            )
+
+                    # If the key is matched, break the loop
+                    if len(keys) == 0 or keys[0] == uid:
+                        break
+
+                # 2. Load the entry from each line and create a dict
+                data = {}
+                # 2.a. Load data streamingly
+                for value, (path, name, _type) in zip(values, self.path_name_type_list):
+                    if _type == "sound":
+                        audio_type = os.path.basename(value).lower()
+                        if audio_type.rfind(".pcm") >= 0:
+                            _type = "pcm"
+                    func = DATA_TYPES[_type]
+                    # Load entry
+                    array = func(value)
+                    if self.fs is not None and name == "speech":
+                        audio_fs = self.fs["audio_fs"]
+                        model_fs = self.fs["model_fs"]
+                        if audio_fs is not None and model_fs is not None:
+                            array = torch.from_numpy(array)
+                            array = torchaudio.transforms.Resample(orig_freq=audio_fs,
+                                                                   new_freq=model_fs)(array)
+                            array = array.numpy()
+                    if _type == "sound":
+                        if self.mc:
+                            data[name] = array.transpose((1, 0))
+                        else:
+                            data[name] = array[0]
+                    else:
+                        data[name] = array
+                if self.non_iterable_dataset is not None:
+                    # 2.b. Load data from non-iterable dataset
+                    _, from_non_iterable = self.non_iterable_dataset[uid]
+                    data.update(from_non_iterable)
+
+                # 3. [Option] Apply preprocessing
+                #   e.g. funasr.train.preprocessor:CommonPreprocessor
+                if self.preprocess is not None:
+                    data = self.preprocess(uid, data)
+
+                # 4. Force data-precision
+                for name in data:
+                    value = data[name]
+                    if not isinstance(value, np.ndarray):
+                        raise RuntimeError(
+                            f"All values must be converted to np.ndarray object "
+                            f'by preprocessing, but "{name}" is still {type(value)}.'
+                        )
+
+                    # Cast to desired type
+                    if value.dtype.kind == "f":
+                        value = value.astype(self.float_dtype)
+                    elif value.dtype.kind == "i":
+                        value = value.astype(self.int_dtype)
+                    else:
+                        raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+                    data[name] = value
+
+                yield uid, data
+
         if count == 0:
             raise RuntimeError("No iteration")
+

--
Gitblit v1.9.1