From 831d00aec2434187266489a5f396d88f63709fe0 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 17 四月 2023 16:26:40 +0800
Subject: [PATCH] update

---
 funasr/utils/prepare_data.py                       |  106 ++++++++--
 funasr/bin/train.py                                |    7 
 funasr/utils/build_dataloader.py                   |   11 +
 funasr/datasets/large_datasets/build_dataloader.py |   28 -
 funasr/datasets/small_datasets/dataset.py          |  442 ++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 554 insertions(+), 40 deletions(-)

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 7e43cca..dbfebd7 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -4,6 +4,7 @@
 
 import torch
 
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
 from funasr.utils.build_distributed import build_distributed
 from funasr.utils.prepare_data import prepare_data
@@ -340,4 +341,10 @@
                                                                    distributed_option.dist_rank,
                                                                    distributed_option.local_rank))
 
+    # prepare files for dataloader
     prepare_data(args, distributed_option)
+
+    set_all_random_seed(args.seed)
+    torch.backends.cudnn.enabled = args.cudnn_enabled
+    torch.backends.cudnn.benchmark = args.cudnn_benchmark
+    torch.backends.cudnn.deterministic = args.cudnn_deterministic
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 156f608..f1ec005 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -64,27 +64,17 @@
         return self.sp.DecodePieces(list(tokens))
 
 
-class ArkDataLoader(AbsIterFactory):
-    def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
-                 bpemodel_file=None, mode="train"):
-        symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
-        if seg_dict_file is not None:
-            seg_dict = load_seg_dict(seg_dict_file)
-        else:
-            seg_dict = None
-        if punc_dict_file is not None:
-            punc_dict = read_symbol_table(punc_dict_file)
-        else:
-            punc_dict = None
-        self.dataset_conf = dataset_conf
-        self.frontend_conf = frontend_conf
+class LargeDataLoader(AbsIterFactory):
+    def __init__(self, args, mode="train"):
+        """Load the token/segment/punctuation/BPE resources named in ``args`` and build ``self.dataset``."""
+        symbol_table = read_symbol_table(args.token_list) if args.token_list is not None else None
+        seg_dict = load_seg_dict(args.seg_dict_file) if args.seg_dict_file is not None else None
+        punc_dict = read_symbol_table(args.punc_dict_file) if args.punc_dict_file is not None else None
+        bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel_file) if args.bpemodel_file is not None else None
+        self.dataset_conf, self.frontend_conf = args.dataset_conf, args.frontend_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
         batch_mode = self.dataset_conf.get("batch_mode", "padding")
-        if bpemodel_file is not None:
-            bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
-        else:
-            bpe_tokenizer = None
-        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
+        self.dataset = Dataset(args.data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
                                 self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)
 
     def build_iter(self, epoch, shuffle=True):
diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py
new file mode 100644
index 0000000..7ed37fa
--- /dev/null
+++ b/funasr/datasets/small_datasets/dataset.py
@@ -0,0 +1,442 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+from abc import ABC
+from abc import abstractmethod
+import collections
+import copy
+import functools
+import logging
+import numbers
+import re
+from typing import Any
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import Mapping
+from typing import Tuple
+from typing import Union
+
+import h5py
+import humanfriendly
+import kaldiio
+import numpy as np
+import torch
+from torch.utils.data.dataset import Dataset
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.npy_scp import NpyScpReader
+from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
+from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset
+from funasr.fileio.read_text import load_num_sequence_text
+from funasr.fileio.read_text import read_2column_text
+from funasr.fileio.sound_scp import SoundScpReader
+from funasr.utils.sized_dict import SizedDict
+
+
+class AdapterForSoundScpReader(collections.abc.Mapping):
+    """Read-only mapping that unwraps a reader's items down to the bare ndarray."""
+
+    def __init__(self, loader, dtype=None):
+        assert check_argument_types()
+        self.loader, self.dtype = loader, dtype
+        # Sampling rate taken from (rate, array) items; an item with a different rate raises.
+        self.rate = None
+
+    def keys(self):
+        return self.loader.keys()
+
+    def __len__(self):
+        return len(self.loader)
+
+    def __iter__(self):
+        return iter(self.loader)
+
+    def __getitem__(self, key: str) -> np.ndarray:
+        """Return the array stored under ``key``, cast to ``self.dtype`` when one is set."""
+        item = self.loader[key]
+        if not isinstance(item, tuple):
+            # Normal ark case: the reader already yields a plain array.
+            assert isinstance(item, np.ndarray), type(item)
+            array = item
+        else:
+            assert len(item) == 2, len(item)
+            head, tail = item[0], item[1]
+            if isinstance(head, int) and isinstance(tail, np.ndarray):
+                # sound scp case: (rate, array)
+                rate, array = head, tail
+            elif isinstance(tail, int) and isinstance(head, np.ndarray):
+                # Extended ark format case: (array, rate)
+                array, rate = head, tail
+            else:
+                raise RuntimeError(
+                    f"Unexpected type: {type(head)}, {type(tail)}"
+                )
+
+            if self.rate is not None and self.rate != rate:
+                raise RuntimeError(
+                    f"Sampling rates are mismatched: {self.rate} != {rate}"
+                )
+            self.rate = rate
+
+        # Both branches share the optional cast; a wave array is (NSample, Channel) or (NSample,).
+        if self.dtype is not None:
+            array = array.astype(self.dtype)
+
+        assert isinstance(array, np.ndarray), type(array)
+        return array
+
+
+class H5FileWrapper:
+    """Expose the entries of an HDF5 file, opened read-only, as loaded values."""
+    def __init__(self, path: str):
+        self.path = path
+        self.h5_file = h5py.File(path, "r")
+
+    def __repr__(self) -> str:
+        return str(self.h5_file)
+
+    def __len__(self) -> int:
+        return len(self.h5_file)
+
+    def __iter__(self):
+        return iter(self.h5_file)
+
+    def __getitem__(self, key) -> np.ndarray:
+        return self.h5_file[key][()]  # indexing with () reads the whole entry
+
+
+def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
+    """Build a mapping from utterance id to waveform ndarray for a sound scp file.
+
+    The scp file lists one "utterance_id /some/where/a.wav" (or .flac) entry per line.
+    Pipe-style entries as used by Kaldi (e.g. "cat a.wav |") are not supported by
+    SoundScpReader, and the audio signal is normalized to the [-1, 1] range.
+    """
+    scp_reader = SoundScpReader(
+        path, dest_sample_rate, normalize=True, always_2d=False
+    )
+    # SoundScpReader.__getitem__() returns Tuple[int, ndarray], but callers want
+    # only the ndarray, hence the adapter.
+    return AdapterForSoundScpReader(scp_reader, float_dtype)
+
+
+def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
+    """Wrap the kaldiio scp reader for ``path`` so that items come back as bare ndarrays."""
+    return AdapterForSoundScpReader(kaldiio.load_scp(path, max_cache_fd=max_cache_fd), float_dtype)
+
+
+def rand_int_loader(filepath, loader_type):
+    """Parse "rand_int_<low>_<high>" (e.g. rand_int_3_10) and build the random-int dataset."""
+    try:
+        low, high = (int(part) for part in loader_type[len("rand_int_") :].split("_"))
+    except ValueError:
+        raise RuntimeError(f"e.g rand_int_3_10: but got {loader_type}")
+    return IntRandomGenerateDataset(filepath, low, high)
+
+
+DATA_TYPES = {
+    "sound": dict(
+        func=sound_loader,
+        kwargs=["dest_sample_rate","float_dtype"],
+        help="Audio format types which supported by sndfile wav, flac, etc."
+        "\n\n"
+        "   utterance_id_a a.wav\n"
+        "   utterance_id_b b.wav\n"
+        "   ...",
+    ),
+    "kaldi_ark": dict(
+        func=kaldi_loader,
+        kwargs=["max_cache_fd"],
+        help="Kaldi-ark file type."
+        "\n\n"
+        "   utterance_id_A /some/where/a.ark:123\n"
+        "   utterance_id_B /some/where/a.ark:456\n"
+        "   ...",
+    ),
+    "npy": dict(
+        func=NpyScpReader,
+        kwargs=[],
+        help="Npy file format."
+        "\n\n"
+        "   utterance_id_A /some/where/a.npy\n"
+        "   utterance_id_B /some/where/b.npy\n"
+        "   ...",
+    ),
+    "text_int": dict(
+        func=functools.partial(load_num_sequence_text, loader_type="text_int"),
+        kwargs=[],
+        help="A text file in which is written a sequence of interger numbers "
+        "separated by space."
+        "\n\n"
+        "   utterance_id_A 12 0 1 3\n"
+        "   utterance_id_B 3 3 1\n"
+        "   ...",
+    ),
+    "csv_int": dict(
+        func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
+        kwargs=[],
+        help="A text file in which is written a sequence of interger numbers "
+        "separated by comma."
+        "\n\n"
+        "   utterance_id_A 100,80\n"
+        "   utterance_id_B 143,80\n"
+        "   ...",
+    ),
+    "text_float": dict(
+        func=functools.partial(load_num_sequence_text, loader_type="text_float"),
+        kwargs=[],
+        help="A text file in which is written a sequence of float numbers "
+        "separated by space."
+        "\n\n"
+        "   utterance_id_A 12. 3.1 3.4 4.4\n"
+        "   utterance_id_B 3. 3.12 1.1\n"
+        "   ...",
+    ),
+    "csv_float": dict(
+        func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
+        kwargs=[],
+        help="A text file in which is written a sequence of float numbers "
+        "separated by comma."
+        "\n\n"
+        "   utterance_id_A 12.,3.1,3.4,4.4\n"
+        "   utterance_id_B 3.,3.12,1.1\n"
+        "   ...",
+    ),
+    "text": dict(
+        func=read_2column_text,
+        kwargs=[],
+        help="Return text as is. The text must be converted to ndarray "
+        "by 'preprocess'."
+        "\n\n"
+        "   utterance_id_A hello world\n"
+        "   utterance_id_B foo bar\n"
+        "   ...",
+    ),
+    "hdf5": dict(
+        func=H5FileWrapper,
+        kwargs=[],
+        help="A HDF5 file which contains arrays at the first level or the second level."
+        "   >>> f = h5py.File('file.h5')\n"
+        "   >>> array1 = f['utterance_id_A']\n"
+        "   >>> array2 = f['utterance_id_B']\n",
+    ),
+    "rand_float": dict(
+        func=FloatRandomGenerateDataset,
+        kwargs=[],
+        help="Generate random float-ndarray which has the given shapes "
+        "in the file."
+        "\n\n"
+        "   utterance_id_A 3,4\n"
+        "   utterance_id_B 10,4\n"
+        "   ...",
+    ),
+    "rand_int_\\d+_\\d+": dict(
+        func=rand_int_loader,
+        kwargs=["loader_type"],
+        help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
+        "shapes in the path. "
+        "Give the lower and upper value by the file type. e.g. "
+        "rand_int_0_10 -> Generate integers from 0 to 10."
+        "\n\n"
+        "   utterance_id_A 3,4\n"
+        "   utterance_id_B 10,4\n"
+        "   ...",
+    ),
+}
+
+
+class AbsDataset(Dataset, ABC):  # abstract base for datasets whose items are dicts keyed by data name
+    @abstractmethod
+    def has_name(self, name) -> bool:  # whether `name` is one of the dataset's data keys
+        raise NotImplementedError
+
+    @abstractmethod
+    def names(self) -> Tuple[str, ...]:  # all data keys of the dataset
+        raise NotImplementedError
+
+    @abstractmethod
+    def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:  # (uid, {name: array}) for one sample
+        raise NotImplementedError
+
+
+class ESPnetDataset(AbsDataset):
+    """PyTorch Dataset class for FunASR, simplified from ESPnet.
+
+    Items are ``(uid, data)`` where ``data`` maps every configured name to an ndarray."""
+
+    def __init__(
+        self,
+        path_name_type_list: Collection[Tuple[str, str, str]],  # (path, name, type) per data key
+        preprocess: Callable[
+            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
+        ] = None,  # optional hook called as preprocess(uid, data) in __getitem__
+        float_dtype: str = "float32",  # dtype every float array is cast to
+        int_dtype: str = "long",  # dtype every signed-integer array is cast to
+        max_cache_size: Union[float, int, str] = 0.0,  # > 0 enables the item cache; str goes through humanfriendly
+        max_cache_fd: int = 0,  # forwarded to loaders that list "max_cache_fd" in DATA_TYPES
+        dest_sample_rate: int = 16000,  # forwarded to the "sound" loader
+    ):
+        assert check_argument_types()
+        if len(path_name_type_list) == 0:
+            raise ValueError(
+                '1 or more elements are required for "path_name_type_list"'
+            )
+
+        path_name_type_list = copy.deepcopy(path_name_type_list)
+        self.preprocess = preprocess
+
+        self.float_dtype = float_dtype
+        self.int_dtype = int_dtype
+        self.max_cache_fd = max_cache_fd
+        self.dest_sample_rate = dest_sample_rate
+
+        self.loader_dict = {}  # name -> loader built by _build_loader
+        self.debug_info = {}  # name -> (path, type), used by __repr__ and error logs
+        for path, name, _type in path_name_type_list:
+            if name in self.loader_dict:
+                raise RuntimeError(f'"{name}" is duplicated for data-key')
+
+            loader = self._build_loader(path, _type)
+            self.loader_dict[name] = loader
+            self.debug_info[name] = path, _type
+            if len(self.loader_dict[name]) == 0:
+                raise RuntimeError(f"{path} has no samples")
+
+            # TODO(kamo): Should check consistency of each utt-keys?
+
+        if isinstance(max_cache_size, str):
+            max_cache_size = humanfriendly.parse_size(max_cache_size)
+        self.max_cache_size = max_cache_size
+        if max_cache_size > 0:
+            self.cache = SizedDict(shared=True)
+        else:
+            self.cache = None  # caching disabled
+
+    def _build_loader(
+        self, path: str, loader_type: str
+    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
+        """Instantiate the loader registered in DATA_TYPES for ``loader_type``.
+
+        Args:
+            path:  The file path handed to the loader as its first argument
+            loader_type:  sound, npy, text_int, text_float, etc.; each DATA_TYPES
+                key is tried in order as a regex via ``re.match`` and the first hit wins
+        """
+        for key, dic in DATA_TYPES.items():
+            # e.g. loader_type="sound"
+            # -> return DATA_TYPES["sound"]["func"](path)
+            if re.match(key, loader_type):
+                kwargs = {}  # only the keyword arguments this loader declares
+                for key2 in dic["kwargs"]:
+                    if key2 == "loader_type":
+                        kwargs["loader_type"] = loader_type
+                    elif key2 == "dest_sample_rate" and loader_type=="sound":
+                        kwargs["dest_sample_rate"] = self.dest_sample_rate
+                    elif key2 == "float_dtype":
+                        kwargs["float_dtype"] = self.float_dtype
+                    elif key2 == "int_dtype":
+                        kwargs["int_dtype"] = self.int_dtype
+                    elif key2 == "max_cache_fd":
+                        kwargs["max_cache_fd"] = self.max_cache_fd
+                    else:
+                        raise RuntimeError(f"Not implemented keyword argument: {key2}")
+
+                func = dic["func"]
+                try:
+                    return func(path, **kwargs)
+                except Exception:
+                    if hasattr(func, "__name__"):
+                        name = func.__name__
+                    else:
+                        name = str(func)  # e.g. functools.partial objects have no __name__
+                    logging.error(f"An error happened with {name}({path})")
+                    raise
+        else:  # for-else: no DATA_TYPES pattern matched loader_type
+            raise RuntimeError(f"Not supported: loader_type={loader_type}")
+
+    def has_name(self, name) -> bool:  # True if a loader is registered under `name`
+        return name in self.loader_dict
+
+    def names(self) -> Tuple[str, ...]:  # data keys in registration order
+        return tuple(self.loader_dict)
+
+    def __iter__(self):  # iterates the first registered loader
+        return iter(next(iter(self.loader_dict.values())))
+
+    def __repr__(self):
+        _mes = self.__class__.__name__
+        _mes += "("
+        for name, (path, _type) in self.debug_info.items():
+            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
+        _mes += f"\n  preprocess: {self.preprocess})"
+        return _mes
+
+    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
+        assert check_argument_types()
+
+        # An integer uid is a position in the first loader's iteration order
+        if isinstance(uid, int):
+            d = next(iter(self.loader_dict.values()))
+            uid = list(d)[uid]
+
+        if self.cache is not None and uid in self.cache:  # cache hit: skip loading and preprocessing
+            data = self.cache[uid]
+            return uid, data
+
+        data = {}
+        # 1. Load the value stored under uid from every loader
+        for name, loader in self.loader_dict.items():
+            try:
+                value = loader[uid]
+                if isinstance(value, (list, tuple)):
+                    value = np.array(value)
+                if not isinstance(
+                    value, (np.ndarray, torch.Tensor, str, numbers.Number)
+                ):
+                    raise TypeError(
+                        f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
+                    )
+            except Exception:
+                path, _type = self.debug_info[name]
+                logging.error(
+                    f"Error happened with path={path}, type={_type}, id={uid}"
+                )
+                raise
+
+            # torch.Tensor is converted to ndarray, a bare number to a 1-element array
+            if isinstance(value, torch.Tensor):
+                value = value.numpy()
+            elif isinstance(value, numbers.Number):
+                value = np.array([value])
+            data[name] = value
+
+        # 2. [Option] Apply preprocessing
+        #   e.g. funasr.train.preprocessor:CommonPreprocessor
+        if self.preprocess is not None:
+            data = self.preprocess(uid, data)
+
+        # 3. Force data-precision (str values must already be arrays by now)
+        for name in data:
+            value = data[name]
+            if not isinstance(value, np.ndarray):
+                raise RuntimeError(
+                    f"All values must be converted to np.ndarray object "
+                    f'by preprocessing, but "{name}" is still {type(value)}.'
+                )
+
+            # Cast to desired type; only float ("f") and signed-int ("i") kinds are accepted
+            if value.dtype.kind == "f":
+                value = value.astype(self.float_dtype)
+            elif value.dtype.kind == "i":
+                value = value.astype(self.int_dtype)
+            else:
+                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+            data[name] = value
+
+        if self.cache is not None and self.cache.size < self.max_cache_size:  # stop filling once the limit is hit
+            self.cache[uid] = data
+
+        retval = uid, data
+        assert check_return_type(retval)
+        return retval
diff --git a/funasr/utils/build_dataloader.py b/funasr/utils/build_dataloader.py
new file mode 100644
index 0000000..59b19ba
--- /dev/null
+++ b/funasr/utils/build_dataloader.py
@@ -0,0 +1,11 @@
+from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
+
+
+def build_dataloader(args):
+    """Return the (train, valid) iterator factories for ``args.dataset_type``; None for "small" for now."""
+    if args.dataset_type == "small":
+        # Small-dataset factories are not implemented yet, so nothing is built here.
+        return None
+    if args.dataset_type == "large":
+        return LargeDataLoader(args, mode="train"), LargeDataLoader(args, mode="valid")
+    raise ValueError(f"Not supported dataset_type={args.dataset_type}")
\ No newline at end of file
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
index a0d97f6..c9a99e5 100644
--- a/funasr/utils/prepare_data.py
+++ b/funasr/utils/prepare_data.py
@@ -1,9 +1,11 @@
-import os
 import logging
+import os
+import shutil
 from multiprocessing import Pool
 
 import numpy as np
 import torch.distributed as dist
+import torchaudio
 
 
 def filter_wav_text(data_dir, dataset):
@@ -34,25 +36,37 @@
                 f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
             else:
                 filter_count += 1
-    logging.info("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".format(len(wav_lines),
-                                                                                                     filter_count,
-                                                                                                     dataset))
+    logging.info(
+        "{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".format(len(wav_lines),
+                                                                                                   filter_count,
+                                                                                                   dataset))
 
 
-def calc_shape_core(root_path, frontend_conf, speech_length_min, speech_length_max, idx):
+def wav2num_frame(wav_path, frontend_conf):
+    """Return ``(n_frames, feature_dim)`` for one wav file; ``n_frames`` is left unrounded."""
+    signal, rate = torchaudio.load(wav_path)
+    denominator = rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"]
+    return (signal.shape[1] * 1000.0) / denominator, frontend_conf["n_mels"] * frontend_conf["lfr_m"]
+
+
+def calc_shape_core(root_path, args, idx):
+    """Write ``speech_shape.<idx>`` with a "name n_frames,feature_dim" line per kept entry of ``wav.scp.<idx>``."""
     wav_scp_file = os.path.join(root_path, "wav.scp.{}".format(idx))
     shape_file = os.path.join(root_path, "speech_shape.{}".format(idx))
     with open(wav_scp_file) as f:
         lines = f.readlines()
+    frontend_conf, dataset_conf = args.frontend_conf, args.dataset_conf
+    conf_get = dataset_conf.get if isinstance(dataset_conf, dict) else (lambda k, d: getattr(dataset_conf, k, d))
+    speech_length_min, speech_length_max = conf_get("speech_length_min", -1), conf_get("speech_length_max", -1)
     with open(shape_file, "w") as f:
         for line in lines:
             sample_name, wav_path = line.strip().split()
-            n_frames, feature_dim, speech_length = wav2num_frame(wav_path, frontend_conf)
+            n_frames, feature_dim = wav2num_frame(wav_path, frontend_conf)
             write_flag = True
-            if speech_length_min > 0 and speech_length < speech_length_min:
-                write_flag = False
-            if speech_length_max > 0 and speech_length > speech_length_max:
-                write_flag = False
+            if speech_length_min > 0 and 0 < n_frames < speech_length_min:  # limits <= 0 mean "no limit"
+                write_flag = False
+            if speech_length_max > 0 and n_frames > speech_length_max:  # may only reject, never re-accept
+                write_flag = False
             if write_flag:
                 f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
                 f.flush()
@@ -61,12 +75,13 @@
 def calc_shape(args, dataset, nj=32):
     shape_path = os.path.join(args.data_dir, dataset, "speech_shape")
     if os.path.exists(shape_path):
-        print('Shape file for small dataset already exists.')
+        logging.info('Shape file for small dataset already exists.')
         return
 
     split_shape_path = os.path.join(args.data_dir, dataset, "shape_files")
-    if os.path
-    os.makedirs(split_shape_path, exist_ok=True)
+    if os.path.exists(split_shape_path):
+        shutil.rmtree(split_shape_path)
+    os.mkdir(split_shape_path)
 
     # split
     wav_scp_file = os.path.join(args.data_dir, dataset, "wav.scp")
@@ -87,21 +102,58 @@
 
     p = Pool(nj)
     for i in range(nj):
-        p.apply_async(calc_shape_core,
-                      args=(shape_path, frontend_conf, speech_length_min, speech_length_max, str(i + 1)))
-    print('Generating shape files, please wait a few minutes...')
+        p.apply_async(calc_shape_core, args=(split_shape_path, args, str(i + 1)))
+    logging.info("Generating shape files, please wait a few minutes...")
     p.close()
     p.join()
 
     # combine
-    file = os.path.join(data_dir, dataset, "speech_shape")
-    with open(file, "w") as f:
+    with open(shape_path, "w") as f:
         for i in range(nj):
-            job_file = os.path.join(shape_path, "speech_shape.{}".format(str(i + 1)))
+            job_file = os.path.join(split_shape_path, "speech_shape.{}".format(str(i + 1)))
             with open(job_file) as job_f:
                 lines = job_f.readlines()
                 f.writelines(lines)
-    print('Generating shape files done.')
+    logging.info('Generating shape files done.')
+
+
+def generate_data_list(data_dir, dataset, nj=100):
+    """Shard ``wav.scp`` and ``text`` of a dataset and index the shards in ``data.list``.
+
+    Shard ``i`` goes to ``<data_dir>/<dataset>/split/<i>/`` and ``data.list`` receives one
+    "<wav.scp path> <text path>" line per shard. Returns early if ``data.list`` already
+    exists; an existing ``split`` directory is deleted and rebuilt.
+    At most ``nj`` shards are made; the count is capped at the number of utterances so
+    that a small dataset does not produce empty shards.
+    """
+    dataset_dir = os.path.join(data_dir, dataset)
+    list_file = os.path.join(dataset_dir, "data.list")
+    if os.path.exists(list_file):
+        logging.info('Data list for large dataset already exists.')
+        return
+    split_path = os.path.join(dataset_dir, "split")
+    if os.path.exists(split_path):
+        shutil.rmtree(split_path)
+    os.mkdir(split_path)
+    with open(os.path.join(dataset_dir, "wav.scp")) as f_wav:
+        wav_lines = f_wav.readlines()
+    with open(os.path.join(dataset_dir, "text")) as f_text:
+        text_lines = f_text.readlines()
+    nj = max(1, min(nj, len(wav_lines)))
+    num_job_lines = len(wav_lines) // nj
+    entries = []
+    for i in range(nj):
+        start = i * num_job_lines
+        end = None if i == nj - 1 else start + num_job_lines  # last shard takes the remainder
+        shard_dir = os.path.join(split_path, str(i + 1))
+        os.mkdir(shard_dir)
+        wav_file, text_file = os.path.join(shard_dir, "wav.scp"), os.path.join(shard_dir, "text")
+        with open(wav_file, "w") as fw, open(text_file, "w") as ft:
+            fw.writelines(wav_lines[start:end])
+            ft.writelines(text_lines[start:end])
+        entries.append(wav_file + " " + text_file + "\n")
+    with open(list_file, "w") as f_data:  # written last, after every shard exists
+        f_data.writelines(entries)
 
 
 def prepare_data(args, distributed_option):
@@ -109,6 +161,18 @@
     if not distributed or distributed_option.dist_rank == 0:
         filter_wav_text(args.data_dir, args.train_set)
         filter_wav_text(args.data_dir, args.dev_set)
-        dist.barrier()
 
         if args.dataset_type == "small" and args.train_shape_file is None:
+            calc_shape(args, args.train_set)
+            calc_shape(args, args.dev_set)
+
+        if args.dataset_type == "large" and args.train_data_file is None:
+            generate_data_list(args.data_dir, args.train_set)
+            generate_data_list(args.data_dir, args.dev_set)
+
+    args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
+    args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]
+    args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
+    args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")
+    if distributed:
+        dist.barrier()

--
Gitblit v1.9.1