From 6427c834dfd97b1f05c6659cdc7ccf010bf82fe1 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 24 四月 2023 19:50:07 +0800
Subject: [PATCH] update

---
 funasr/datasets/small_datasets/dataset.py |  259 ++++++---------------------------------------------
 1 files changed, 33 insertions(+), 226 deletions(-)

diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py
index 7ed37fa..a7017a5 100644
--- a/funasr/datasets/small_datasets/dataset.py
+++ b/funasr/datasets/small_datasets/dataset.py
@@ -1,15 +1,10 @@
 # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
 #  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 
-from abc import ABC
-from abc import abstractmethod
 import collections
 import copy
-import functools
 import logging
 import numbers
-import re
-from typing import Any
 from typing import Callable
 from typing import Collection
 from typing import Dict
@@ -17,8 +12,6 @@
 from typing import Tuple
 from typing import Union
 
-import h5py
-import humanfriendly
 import kaldiio
 import numpy as np
 import torch
@@ -27,12 +20,7 @@
 from typeguard import check_return_type
 
 from funasr.fileio.npy_scp import NpyScpReader
-from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
-from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset
-from funasr.fileio.read_text import load_num_sequence_text
-from funasr.fileio.read_text import read_2column_text
 from funasr.fileio.sound_scp import SoundScpReader
-from funasr.utils.sized_dict import SizedDict
 
 
 class AdapterForSoundScpReader(collections.abc.Mapping):
@@ -88,25 +76,6 @@
         return array
 
 
-class H5FileWrapper:
-    def __init__(self, path: str):
-        self.path = path
-        self.h5_file = h5py.File(path, "r")
-
-    def __repr__(self) -> str:
-        return str(self.h5_file)
-
-    def __len__(self) -> int:
-        return len(self.h5_file)
-
-    def __iter__(self):
-        return iter(self.h5_file)
-
-    def __getitem__(self, key) -> np.ndarray:
-        value = self.h5_file[key]
-        return value[()]
-
-
 def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
     # The file is as follows:
     #   utterance_id_A /some/where/a.wav
@@ -127,156 +96,20 @@
     return AdapterForSoundScpReader(loader, float_dtype)
 
 
-def rand_int_loader(filepath, loader_type):
-    # e.g. rand_int_3_10
-    try:
-        low, high = map(int, loader_type[len("rand_int_") :].split("_"))
-    except ValueError:
-        raise RuntimeError(f"e.g rand_int_3_10: but got {loader_type}")
-    return IntRandomGenerateDataset(filepath, low, high)
-
-
-DATA_TYPES = {
-    "sound": dict(
-        func=sound_loader,
-        kwargs=["dest_sample_rate","float_dtype"],
-        help="Audio format types which supported by sndfile wav, flac, etc."
-        "\n\n"
-        "   utterance_id_a a.wav\n"
-        "   utterance_id_b b.wav\n"
-        "   ...",
-    ),
-    "kaldi_ark": dict(
-        func=kaldi_loader,
-        kwargs=["max_cache_fd"],
-        help="Kaldi-ark file type."
-        "\n\n"
-        "   utterance_id_A /some/where/a.ark:123\n"
-        "   utterance_id_B /some/where/a.ark:456\n"
-        "   ...",
-    ),
-    "npy": dict(
-        func=NpyScpReader,
-        kwargs=[],
-        help="Npy file format."
-        "\n\n"
-        "   utterance_id_A /some/where/a.npy\n"
-        "   utterance_id_B /some/where/b.npy\n"
-        "   ...",
-    ),
-    "text_int": dict(
-        func=functools.partial(load_num_sequence_text, loader_type="text_int"),
-        kwargs=[],
-        help="A text file in which is written a sequence of interger numbers "
-        "separated by space."
-        "\n\n"
-        "   utterance_id_A 12 0 1 3\n"
-        "   utterance_id_B 3 3 1\n"
-        "   ...",
-    ),
-    "csv_int": dict(
-        func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
-        kwargs=[],
-        help="A text file in which is written a sequence of interger numbers "
-        "separated by comma."
-        "\n\n"
-        "   utterance_id_A 100,80\n"
-        "   utterance_id_B 143,80\n"
-        "   ...",
-    ),
-    "text_float": dict(
-        func=functools.partial(load_num_sequence_text, loader_type="text_float"),
-        kwargs=[],
-        help="A text file in which is written a sequence of float numbers "
-        "separated by space."
-        "\n\n"
-        "   utterance_id_A 12. 3.1 3.4 4.4\n"
-        "   utterance_id_B 3. 3.12 1.1\n"
-        "   ...",
-    ),
-    "csv_float": dict(
-        func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
-        kwargs=[],
-        help="A text file in which is written a sequence of float numbers "
-        "separated by comma."
-        "\n\n"
-        "   utterance_id_A 12.,3.1,3.4,4.4\n"
-        "   utterance_id_B 3.,3.12,1.1\n"
-        "   ...",
-    ),
-    "text": dict(
-        func=read_2column_text,
-        kwargs=[],
-        help="Return text as is. The text must be converted to ndarray "
-        "by 'preprocess'."
-        "\n\n"
-        "   utterance_id_A hello world\n"
-        "   utterance_id_B foo bar\n"
-        "   ...",
-    ),
-    "hdf5": dict(
-        func=H5FileWrapper,
-        kwargs=[],
-        help="A HDF5 file which contains arrays at the first level or the second level."
-        "   >>> f = h5py.File('file.h5')\n"
-        "   >>> array1 = f['utterance_id_A']\n"
-        "   >>> array2 = f['utterance_id_B']\n",
-    ),
-    "rand_float": dict(
-        func=FloatRandomGenerateDataset,
-        kwargs=[],
-        help="Generate random float-ndarray which has the given shapes "
-        "in the file."
-        "\n\n"
-        "   utterance_id_A 3,4\n"
-        "   utterance_id_B 10,4\n"
-        "   ...",
-    ),
-    "rand_int_\\d+_\\d+": dict(
-        func=rand_int_loader,
-        kwargs=["loader_type"],
-        help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
-        "shapes in the path. "
-        "Give the lower and upper value by the file type. e.g. "
-        "rand_int_0_10 -> Generate integers from 0 to 10."
-        "\n\n"
-        "   utterance_id_A 3,4\n"
-        "   utterance_id_B 10,4\n"
-        "   ...",
-    ),
-}
-
-
-class AbsDataset(Dataset, ABC):
-    @abstractmethod
-    def has_name(self, name) -> bool:
-        raise NotImplementedError
-
-    @abstractmethod
-    def names(self) -> Tuple[str, ...]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
-        raise NotImplementedError
-
-
-class ESPnetDataset(AbsDataset):
+class ESPnetDataset(Dataset):
     """
-        Pytorch Dataset class for FunASR, simplied from ESPnet
+        Pytorch Dataset class for FunASR, modified from ESPnet
     """
 
     def __init__(
-        self,
-        path_name_type_list: Collection[Tuple[str, str, str]],
-        preprocess: Callable[
-            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
-        ] = None,
-        float_dtype: str = "float32",
-        int_dtype: str = "long",
-        max_cache_size: Union[float, int, str] = 0.0,
-        max_cache_fd: int = 0,
-        dest_sample_rate: int = 16000,
+            self,
+            path_name_type_list: Collection[Tuple[str, str, str]],
+            preprocess: Callable[
+                [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
+            ] = None,
+            float_dtype: str = "float32",
+            int_dtype: str = "long",
+            dest_sample_rate: int = 16000,
     ):
         assert check_argument_types()
         if len(path_name_type_list) == 0:
@@ -289,7 +122,6 @@
 
         self.float_dtype = float_dtype
         self.int_dtype = int_dtype
-        self.max_cache_fd = max_cache_fd
         self.dest_sample_rate = dest_sample_rate
 
         self.loader_dict = {}
@@ -304,54 +136,36 @@
             if len(self.loader_dict[name]) == 0:
                 raise RuntimeError(f"{path} has no samples")
 
-            # TODO(kamo): Should check consistency of each utt-keys?
-
-        if isinstance(max_cache_size, str):
-            max_cache_size = humanfriendly.parse_size(max_cache_size)
-        self.max_cache_size = max_cache_size
-        if max_cache_size > 0:
-            self.cache = SizedDict(shared=True)
-        else:
-            self.cache = None
-
     def _build_loader(
-        self, path: str, loader_type: str
+            self, path: str, loader_type: str
     ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
         """Helper function to instantiate Loader.
 
         Args:
             path:  The file path
-            loader_type:  loader_type. sound, npy, text_int, text_float, etc
+            loader_type:  loader_type. sound, npy, text, etc
         """
-        for key, dic in DATA_TYPES.items():
-            # e.g. loader_type="sound"
-            # -> return DATA_TYPES["sound"]["func"](path)
-            if re.match(key, loader_type):
-                kwargs = {}
-                for key2 in dic["kwargs"]:
-                    if key2 == "loader_type":
-                        kwargs["loader_type"] = loader_type
-                    elif key2 == "dest_sample_rate" and loader_type=="sound":
-                        kwargs["dest_sample_rate"] = self.dest_sample_rate
-                    elif key2 == "float_dtype":
-                        kwargs["float_dtype"] = self.float_dtype
-                    elif key2 == "int_dtype":
-                        kwargs["int_dtype"] = self.int_dtype
-                    elif key2 == "max_cache_fd":
-                        kwargs["max_cache_fd"] = self.max_cache_fd
+        if loader_type == "sound":
+            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False)
+            return AdapterForSoundScpReader(loader, self.float_dtype)
+        elif loader_type == "kaldi_ark":
+            loader = kaldiio.load_scp(path)
+            return AdapterForSoundScpReader(loader, self.float_dtype)
+        elif loader_type == "npy":
+            return NpyScpReader(path)
+        elif loader_type == "text":
+            text_loader = {}
+            with open(path, "r", encoding="utf-8") as f:
+                for linenum, line in enumerate(f, 1):
+                    sps = line.rstrip().split(maxsplit=1)
+                    if len(sps) == 1:
+                        k, v = sps[0], ""
                     else:
-                        raise RuntimeError(f"Not implemented keyword argument: {key2}")
-
-                func = dic["func"]
-                try:
-                    return func(path, **kwargs)
-                except Exception:
-                    if hasattr(func, "__name__"):
-                        name = func.__name__
-                    else:
-                        name = str(func)
-                    logging.error(f"An error happened with {name}({path})")
-                    raise
+                        k, v = sps
+                    if k in text_loader:
+                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+                    text_loader[k] = v
+            return text_loader
         else:
             raise RuntimeError(f"Not supported: loader_type={loader_type}")
 
@@ -380,10 +194,6 @@
             d = next(iter(self.loader_dict.values()))
             uid = list(d)[uid]
 
-        if self.cache is not None and uid in self.cache:
-            data = self.cache[uid]
-            return uid, data
-
         data = {}
         # 1. Load data from each loaders
         for name, loader in self.loader_dict.items():
@@ -392,7 +202,7 @@
                 if isinstance(value, (list, tuple)):
                     value = np.array(value)
                 if not isinstance(
-                    value, (np.ndarray, torch.Tensor, str, numbers.Number)
+                        value, (np.ndarray, torch.Tensor, str, numbers.Number)
                 ):
                     raise TypeError(
                         f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
@@ -433,9 +243,6 @@
             else:
                 raise NotImplementedError(f"Not supported dtype: {value.dtype}")
             data[name] = value
-
-        if self.cache is not None and self.cache.size < self.max_cache_size:
-            self.cache[uid] = data
 
         retval = uid, data
         assert check_return_type(retval)

--
Gitblit v1.9.1