From 9f90bad3f58c86143e630a9d11d8434adaa62904 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Mon, 17 Apr 2023 17:11:05 +0800
Subject: [PATCH] update
---
/dev/null | 349 --------------
funasr/datasets/small_datasets/preprocessor.py | 826 ++++++++++++++++++++++++++++++++++
funasr/bin/train.py | 1
funasr/datasets/small_datasets/build_loader.py | 16
funasr/datasets/small_datasets/dataset.py | 243 +--------
5 files changed, 878 insertions(+), 557 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index dbfebd7..9b93820 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -25,6 +25,7 @@
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
# ddp related
parser.add_argument(
diff --git a/funasr/datasets/iterable_dataset_modelscope.py b/funasr/datasets/iterable_dataset_modelscope.py
deleted file mode 100644
index 860492c..0000000
--- a/funasr/datasets/iterable_dataset_modelscope.py
+++ /dev/null
@@ -1,349 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-# Part of the implementation is borrowed from espnet/espnet.
-"""Iterable dataset module."""
-import copy
-from io import StringIO
-from pathlib import Path
-from typing import Callable, Collection, Dict, Iterator, Tuple, Union
-
-import kaldiio
-import numpy as np
-import soundfile
-import torch
-from funasr.datasets.dataset import ESPnetDataset
-from torch.utils.data.dataset import IterableDataset
-from typeguard import check_argument_types
-
-from funasr.utils import wav_utils
-
-
-def load_kaldi(input):
- retval = kaldiio.load_mat(input)
- if isinstance(retval, tuple):
- assert len(retval) == 2, len(retval)
- if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
- # sound scp case
- rate, array = retval
- elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
- # Extended ark format case
- array, rate = retval
- else:
- raise RuntimeError(
- f'Unexpected type: {type(retval[0])}, {type(retval[1])}')
-
- # Multichannel wave fie
- # array: (NSample, Channel) or (Nsample)
-
- else:
- # Normal ark case
- assert isinstance(retval, np.ndarray), type(retval)
- array = retval
- return array
-
-
-DATA_TYPES = {
- 'sound':
- lambda x: soundfile.read(x)[0],
- 'kaldi_ark':
- load_kaldi,
- 'npy':
- np.load,
- 'text_int':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '),
- 'csv_int':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','),
- 'text_float':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
- ),
- 'csv_float':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
- ),
- 'text':
- lambda x: x,
-}
-
-
-class IterableESPnetDatasetModelScope(IterableDataset):
- """Pytorch Dataset class for ESPNet.
-
- Examples:
- >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
- ... ('token_int', 'output', 'text_int')],
- ... )
- >>> for uid, data in dataset:
- ... data
- {'input': per_utt_array, 'output': per_utt_array}
- """
- def __init__(self,
- path_name_type_list: Collection[Tuple[any, str, str]],
- preprocess: Callable[[str, Dict[str, np.ndarray]],
- Dict[str, np.ndarray]] = None,
- float_dtype: str = 'float32',
- int_dtype: str = 'long',
- key_file: str = None,
- sample_rate: Union[dict, int] = 16000):
- assert check_argument_types()
- if len(path_name_type_list) == 0:
- raise ValueError(
- '1 or more elements are required for "path_name_type_list"')
-
- self.preprocess = preprocess
-
- self.float_dtype = float_dtype
- self.int_dtype = int_dtype
- self.key_file = key_file
- self.sample_rate = sample_rate
-
- self.debug_info = {}
- non_iterable_list = []
- self.path_name_type_list = []
-
- path_list = path_name_type_list[0]
- name = path_name_type_list[1]
- _type = path_name_type_list[2]
- if name in self.debug_info:
- raise RuntimeError(f'"{name}" is duplicated for data-key')
- self.debug_info[name] = path_list, _type
- # for path, name, _type in path_name_type_list:
- for path in path_list:
- self.path_name_type_list.append((path, name, _type))
-
- if len(non_iterable_list) != 0:
- # Some types doesn't support iterable mode
- self.non_iterable_dataset = ESPnetDataset(
- path_name_type_list=non_iterable_list,
- preprocess=preprocess,
- float_dtype=float_dtype,
- int_dtype=int_dtype,
- )
- else:
- self.non_iterable_dataset = None
-
- self.apply_utt2category = False
-
- def has_name(self, name) -> bool:
- return name in self.debug_info
-
- def names(self) -> Tuple[str, ...]:
- return tuple(self.debug_info)
-
- def __repr__(self):
- _mes = self.__class__.__name__
- _mes += '('
- for name, (path, _type) in self.debug_info.items():
- _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
- _mes += f'\n preprocess: {self.preprocess})'
- return _mes
-
- def __iter__(
- self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
- torch.set_printoptions(profile='default')
- count = len(self.path_name_type_list)
- for idx in range(count):
- # 2. Load the entry from each line and create a dict
- data = {}
- # 2.a. Load data streamingly
-
- # value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
- value = self.path_name_type_list[idx][0]['file']
- uid = self.path_name_type_list[idx][0]['key']
- # name: speech
- name = self.path_name_type_list[idx][1]
- _type = self.path_name_type_list[idx][2]
- func = DATA_TYPES[_type]
- array = func(value)
-
- # 2.b. audio resample
- if _type == 'sound':
- audio_sr: int = 16000
- model_sr: int = 16000
- if isinstance(self.sample_rate, int):
- model_sr = self.sample_rate
- else:
- if 'audio_sr' in self.sample_rate:
- audio_sr = self.sample_rate['audio_sr']
- if 'model_sr' in self.sample_rate:
- model_sr = self.sample_rate['model_sr']
- array = wav_utils.torch_resample(array, audio_sr, model_sr)
-
- # array: [ 1.25122070e-03 ... ]
- data[name] = array
-
- # 3. [Option] Apply preprocessing
- # e.g. espnet2.train.preprocessor:CommonPreprocessor
- if self.preprocess is not None:
- data = self.preprocess(uid, data)
- # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
-
- # 4. Force data-precision
- for name in data:
- # value is np.ndarray data
- value = data[name]
- if not isinstance(value, np.ndarray):
- raise RuntimeError(
- f'All values must be converted to np.ndarray object '
- f'by preprocessing, but "{name}" is still {type(value)}.'
- )
-
- # Cast to desired type
- if value.dtype.kind == 'f':
- value = value.astype(self.float_dtype)
- elif value.dtype.kind == 'i':
- value = value.astype(self.int_dtype)
- else:
- raise NotImplementedError(
- f'Not supported dtype: {value.dtype}')
- data[name] = value
-
- yield uid, data
-
- if count == 0:
- raise RuntimeError('No iteration')
-
-
-class IterableESPnetBytesModelScope(IterableDataset):
- """Pytorch audio bytes class for ESPNet.
-
- Examples:
- >>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'),
- ... ('token_int', 'output', 'text_int')],
- ... )
- >>> for uid, data in dataset:
- ... data
- {'input': per_utt_array, 'output': per_utt_array}
- """
- def __init__(self,
- path_name_type_list: Collection[Tuple[any, str, str]],
- preprocess: Callable[[str, Dict[str, np.ndarray]],
- Dict[str, np.ndarray]] = None,
- float_dtype: str = 'float32',
- int_dtype: str = 'long',
- key_file: str = None,
- sample_rate: Union[dict, int] = 16000):
- assert check_argument_types()
- if len(path_name_type_list) == 0:
- raise ValueError(
- '1 or more elements are required for "path_name_type_list"')
-
- self.preprocess = preprocess
-
- self.float_dtype = float_dtype
- self.int_dtype = int_dtype
- self.key_file = key_file
- self.sample_rate = sample_rate
-
- self.debug_info = {}
- non_iterable_list = []
- self.path_name_type_list = []
-
- audio_data = path_name_type_list[0]
- name = path_name_type_list[1]
- _type = path_name_type_list[2]
- if name in self.debug_info:
- raise RuntimeError(f'"{name}" is duplicated for data-key')
- self.debug_info[name] = audio_data, _type
- self.path_name_type_list.append((audio_data, name, _type))
-
- if len(non_iterable_list) != 0:
- # Some types doesn't support iterable mode
- self.non_iterable_dataset = ESPnetDataset(
- path_name_type_list=non_iterable_list,
- preprocess=preprocess,
- float_dtype=float_dtype,
- int_dtype=int_dtype,
- )
- else:
- self.non_iterable_dataset = None
-
- self.apply_utt2category = False
-
- if float_dtype == 'float32':
- self.np_dtype = np.float32
-
- def has_name(self, name) -> bool:
- return name in self.debug_info
-
- def names(self) -> Tuple[str, ...]:
- return tuple(self.debug_info)
-
- def __repr__(self):
- _mes = self.__class__.__name__
- _mes += '('
- for name, (path, _type) in self.debug_info.items():
- _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
- _mes += f'\n preprocess: {self.preprocess})'
- return _mes
-
- def __iter__(
- self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
-
- torch.set_printoptions(profile='default')
- # 2. Load the entry from each line and create a dict
- data = {}
- # 2.a. Load data streamingly
-
- value = self.path_name_type_list[0][0]
- uid = 'pcm_data'
- # name: speech
- name = self.path_name_type_list[0][1]
- _type = self.path_name_type_list[0][2]
- func = DATA_TYPES[_type]
- # array: [ 1.25122070e-03 ... ]
- # data[name] = np.frombuffer(value, dtype=self.np_dtype)
-
- # 2.b. byte(PCM16) to float32
- middle_data = np.frombuffer(value, dtype=np.int16)
- middle_data = np.asarray(middle_data)
- if middle_data.dtype.kind not in 'iu':
- raise TypeError("'middle_data' must be an array of integers")
- dtype = np.dtype('float32')
- if dtype.kind != 'f':
- raise TypeError("'dtype' must be a floating point type")
-
- i = np.iinfo(middle_data.dtype)
- abs_max = 2**(i.bits - 1)
- offset = i.min + abs_max
- array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
- dtype=self.np_dtype)
-
- # 2.c. audio resample
- if _type == 'sound':
- audio_sr: int = 16000
- model_sr: int = 16000
- if isinstance(self.sample_rate, int):
- model_sr = self.sample_rate
- else:
- if 'audio_sr' in self.sample_rate:
- audio_sr = self.sample_rate['audio_sr']
- if 'model_sr' in self.sample_rate:
- model_sr = self.sample_rate['model_sr']
- array = wav_utils.torch_resample(array, audio_sr, model_sr)
-
- data[name] = array
-
- # 3. [Option] Apply preprocessing
- # e.g. espnet2.train.preprocessor:CommonPreprocessor
- if self.preprocess is not None:
- data = self.preprocess(uid, data)
- # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
-
- # 4. Force data-precision
- for name in data:
- # value is np.ndarray data
- value = data[name]
- if not isinstance(value, np.ndarray):
- raise RuntimeError(
- f'All values must be converted to np.ndarray object '
- f'by preprocessing, but "{name}" is still {type(value)}.')
-
- # Cast to desired type
- if value.dtype.kind == 'f':
- value = value.astype(self.float_dtype)
- elif value.dtype.kind == 'i':
- value = value.astype(self.int_dtype)
- else:
- raise NotImplementedError(
- f'Not supported dtype: {value.dtype}')
- data[name] = value
-
- yield uid, data
diff --git a/funasr/datasets/small_datasets/build_loader.py b/funasr/datasets/small_datasets/build_loader.py
new file mode 100644
index 0000000..012113f
--- /dev/null
+++ b/funasr/datasets/small_datasets/build_loader.py
@@ -0,0 +1,16 @@
+import torch
+from funasr.datasets.small_datasets.dataset import ESPnetDataset
+from funasr.datasets.small_datasets.build_preprocess import build_preprocess
+
+def build_dataloader(args):
+    # args.frontend_conf may be None; the conditional expression below already handles that,
+    # so no guarding "if" is needed (it would leave dest_sample_rate unbound in the None case).
+    dest_sample_rate = args.frontend_conf["fs"] if (args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
+ preprocess_fn = build_preprocess()
+    return ESPnetDataset(
+        args.data_path_and_name_and_type,
+        float_dtype=args.train_dtype,
+        preprocess=preprocess_fn,
+        max_cache_size=args.max_cache_size,
+        max_cache_fd=args.max_cache_fd,
+        dest_sample_rate=dest_sample_rate,
+    )
diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py
index 7ed37fa..9bf0630 100644
--- a/funasr/datasets/small_datasets/dataset.py
+++ b/funasr/datasets/small_datasets/dataset.py
@@ -1,15 +1,10 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-from abc import ABC
-from abc import abstractmethod
import collections
import copy
-import functools
import logging
import numbers
-import re
-from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
@@ -17,7 +12,6 @@
from typing import Tuple
from typing import Union
-import h5py
import humanfriendly
import kaldiio
import numpy as np
@@ -27,10 +21,6 @@
from typeguard import check_return_type
from funasr.fileio.npy_scp import NpyScpReader
-from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
-from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset
-from funasr.fileio.read_text import load_num_sequence_text
-from funasr.fileio.read_text import read_2column_text
from funasr.fileio.sound_scp import SoundScpReader
from funasr.utils.sized_dict import SizedDict
@@ -88,25 +78,6 @@
return array
-class H5FileWrapper:
- def __init__(self, path: str):
- self.path = path
- self.h5_file = h5py.File(path, "r")
-
- def __repr__(self) -> str:
- return str(self.h5_file)
-
- def __len__(self) -> int:
- return len(self.h5_file)
-
- def __iter__(self):
- return iter(self.h5_file)
-
- def __getitem__(self, key) -> np.ndarray:
- value = self.h5_file[key]
- return value[()]
-
-
def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
# The file is as follows:
# utterance_id_A /some/where/a.wav
@@ -127,156 +98,22 @@
return AdapterForSoundScpReader(loader, float_dtype)
-def rand_int_loader(filepath, loader_type):
- # e.g. rand_int_3_10
- try:
- low, high = map(int, loader_type[len("rand_int_") :].split("_"))
- except ValueError:
- raise RuntimeError(f"e.g rand_int_3_10: but got {loader_type}")
- return IntRandomGenerateDataset(filepath, low, high)
-
-
-DATA_TYPES = {
- "sound": dict(
- func=sound_loader,
- kwargs=["dest_sample_rate","float_dtype"],
- help="Audio format types which supported by sndfile wav, flac, etc."
- "\n\n"
- " utterance_id_a a.wav\n"
- " utterance_id_b b.wav\n"
- " ...",
- ),
- "kaldi_ark": dict(
- func=kaldi_loader,
- kwargs=["max_cache_fd"],
- help="Kaldi-ark file type."
- "\n\n"
- " utterance_id_A /some/where/a.ark:123\n"
- " utterance_id_B /some/where/a.ark:456\n"
- " ...",
- ),
- "npy": dict(
- func=NpyScpReader,
- kwargs=[],
- help="Npy file format."
- "\n\n"
- " utterance_id_A /some/where/a.npy\n"
- " utterance_id_B /some/where/b.npy\n"
- " ...",
- ),
- "text_int": dict(
- func=functools.partial(load_num_sequence_text, loader_type="text_int"),
- kwargs=[],
- help="A text file in which is written a sequence of interger numbers "
- "separated by space."
- "\n\n"
- " utterance_id_A 12 0 1 3\n"
- " utterance_id_B 3 3 1\n"
- " ...",
- ),
- "csv_int": dict(
- func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
- kwargs=[],
- help="A text file in which is written a sequence of interger numbers "
- "separated by comma."
- "\n\n"
- " utterance_id_A 100,80\n"
- " utterance_id_B 143,80\n"
- " ...",
- ),
- "text_float": dict(
- func=functools.partial(load_num_sequence_text, loader_type="text_float"),
- kwargs=[],
- help="A text file in which is written a sequence of float numbers "
- "separated by space."
- "\n\n"
- " utterance_id_A 12. 3.1 3.4 4.4\n"
- " utterance_id_B 3. 3.12 1.1\n"
- " ...",
- ),
- "csv_float": dict(
- func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
- kwargs=[],
- help="A text file in which is written a sequence of float numbers "
- "separated by comma."
- "\n\n"
- " utterance_id_A 12.,3.1,3.4,4.4\n"
- " utterance_id_B 3.,3.12,1.1\n"
- " ...",
- ),
- "text": dict(
- func=read_2column_text,
- kwargs=[],
- help="Return text as is. The text must be converted to ndarray "
- "by 'preprocess'."
- "\n\n"
- " utterance_id_A hello world\n"
- " utterance_id_B foo bar\n"
- " ...",
- ),
- "hdf5": dict(
- func=H5FileWrapper,
- kwargs=[],
- help="A HDF5 file which contains arrays at the first level or the second level."
- " >>> f = h5py.File('file.h5')\n"
- " >>> array1 = f['utterance_id_A']\n"
- " >>> array2 = f['utterance_id_B']\n",
- ),
- "rand_float": dict(
- func=FloatRandomGenerateDataset,
- kwargs=[],
- help="Generate random float-ndarray which has the given shapes "
- "in the file."
- "\n\n"
- " utterance_id_A 3,4\n"
- " utterance_id_B 10,4\n"
- " ...",
- ),
- "rand_int_\\d+_\\d+": dict(
- func=rand_int_loader,
- kwargs=["loader_type"],
- help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
- "shapes in the path. "
- "Give the lower and upper value by the file type. e.g. "
- "rand_int_0_10 -> Generate integers from 0 to 10."
- "\n\n"
- " utterance_id_A 3,4\n"
- " utterance_id_B 10,4\n"
- " ...",
- ),
-}
-
-
-class AbsDataset(Dataset, ABC):
- @abstractmethod
- def has_name(self, name) -> bool:
- raise NotImplementedError
-
- @abstractmethod
- def names(self) -> Tuple[str, ...]:
- raise NotImplementedError
-
- @abstractmethod
- def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
- raise NotImplementedError
-
-
-class ESPnetDataset(AbsDataset):
+class ESPnetDataset(Dataset):
"""
- Pytorch Dataset class for FunASR, simplied from ESPnet
+ Pytorch Dataset class for FunASR, modified from ESPnet
"""
def __init__(
- self,
- path_name_type_list: Collection[Tuple[str, str, str]],
- preprocess: Callable[
- [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
- ] = None,
- float_dtype: str = "float32",
- int_dtype: str = "long",
- max_cache_size: Union[float, int, str] = 0.0,
- max_cache_fd: int = 0,
- dest_sample_rate: int = 16000,
+ self,
+ path_name_type_list: Collection[Tuple[str, str, str]],
+ preprocess: Callable[
+ [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
+ ] = None,
+ float_dtype: str = "float32",
+ int_dtype: str = "long",
+ max_cache_size: Union[float, int, str] = 0.0,
+ max_cache_fd: int = 0,
+ dest_sample_rate: int = 16000,
):
assert check_argument_types()
if len(path_name_type_list) == 0:
@@ -304,8 +141,6 @@
if len(self.loader_dict[name]) == 0:
raise RuntimeError(f"{path} has no samples")
- # TODO(kamo): Should check consistency of each utt-keys?
-
if isinstance(max_cache_size, str):
max_cache_size = humanfriendly.parse_size(max_cache_size)
self.max_cache_size = max_cache_size
@@ -315,43 +150,35 @@
self.cache = None
def _build_loader(
- self, path: str, loader_type: str
+ self, path: str, loader_type: str
) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
"""Helper function to instantiate Loader.
Args:
path: The file path
- loader_type: loader_type. sound, npy, text_int, text_float, etc
+ loader_type: loader_type. sound, npy, text, etc
"""
- for key, dic in DATA_TYPES.items():
- # e.g. loader_type="sound"
- # -> return DATA_TYPES["sound"]["func"](path)
- if re.match(key, loader_type):
- kwargs = {}
- for key2 in dic["kwargs"]:
- if key2 == "loader_type":
- kwargs["loader_type"] = loader_type
- elif key2 == "dest_sample_rate" and loader_type=="sound":
- kwargs["dest_sample_rate"] = self.dest_sample_rate
- elif key2 == "float_dtype":
- kwargs["float_dtype"] = self.float_dtype
- elif key2 == "int_dtype":
- kwargs["int_dtype"] = self.int_dtype
- elif key2 == "max_cache_fd":
- kwargs["max_cache_fd"] = self.max_cache_fd
+ if loader_type == "sound":
+ loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False)
+ return AdapterForSoundScpReader(loader, self.float_dtype)
+ elif loader_type == "kaldi_ark":
+ loader = kaldiio.load_scp(path, max_cache_fd=self.max_cache_fd)
+ return AdapterForSoundScpReader(loader, self.float_dtype)
+ elif loader_type == "npy":
+            return NpyScpReader(path)
+ elif loader_type == "text":
+ text_loader = {}
+ with open(path, "r", encoding="utf-8") as f:
+ for linenum, line in enumerate(f, 1):
+ sps = line.rstrip().split(maxsplit=1)
+ if len(sps) == 1:
+ k, v = sps[0], ""
else:
- raise RuntimeError(f"Not implemented keyword argument: {key2}")
-
- func = dic["func"]
- try:
- return func(path, **kwargs)
- except Exception:
- if hasattr(func, "__name__"):
- name = func.__name__
- else:
- name = str(func)
- logging.error(f"An error happened with {name}({path})")
- raise
+ k, v = sps
+ if k in text_loader:
+ raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+ text_loader[k] = v
+ return text_loader
else:
raise RuntimeError(f"Not supported: loader_type={loader_type}")
@@ -392,7 +219,7 @@
if isinstance(value, (list, tuple)):
value = np.array(value)
if not isinstance(
- value, (np.ndarray, torch.Tensor, str, numbers.Number)
+ value, (np.ndarray, torch.Tensor, str, numbers.Number)
):
raise TypeError(
f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
new file mode 100644
index 0000000..e06a463
--- /dev/null
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -0,0 +1,826 @@
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import scipy.signal
+import soundfile
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.cleaner import TextCleaner
+from funasr.text.token_id_converter import TokenIDConverter
+
+
+class AbsPreprocessor(ABC):
+ def __init__(self, train: bool):
+ self.train = train
+
+ @abstractmethod
+ def __call__(
+ self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ raise NotImplementedError
+
+
+def forward_segment(text, dic):
+ word_list = []
+ i = 0
+ while i < len(text):
+ longest_word = text[i]
+ for j in range(i + 1, len(text) + 1):
+ word = text[i:j]
+ if word in dic:
+ if len(word) > len(longest_word):
+ longest_word = word
+ word_list.append(longest_word)
+ i += len(longest_word)
+ return word_list
+
+
+def seg_tokenize(txt, seg_dict):
+ out_txt = ""
+ for word in txt:
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+
+def seg_tokenize_wo_pattern(txt, seg_dict):
+ out_txt = ""
+ for word in txt:
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+
+def framing(
+ x,
+ frame_length: int = 512,
+ frame_shift: int = 256,
+ centered: bool = True,
+ padded: bool = True,
+):
+ if x.size == 0:
+ raise ValueError("Input array size is zero")
+ if frame_length < 1:
+ raise ValueError("frame_length must be a positive integer")
+ if frame_length > x.shape[-1]:
+ raise ValueError("frame_length is greater than input length")
+ if 0 >= frame_shift:
+ raise ValueError("frame_shift must be greater than 0")
+
+ if centered:
+ pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [
+ (frame_length // 2, frame_length // 2)
+ ]
+ x = np.pad(x, pad_shape, mode="constant", constant_values=0)
+
+ if padded:
+ # Pad to integer number of windowed segments
+ # I.e make x.shape[-1] = frame_length + (nseg-1)*nstep,
+ # with integer nseg
+ nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
+ pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)]
+ x = np.pad(x, pad_shape, mode="constant", constant_values=0)
+
+ # Created strided array of data segments
+ if frame_length == 1 and frame_length == frame_shift:
+ result = x[..., None]
+ else:
+ shape = x.shape[:-1] + (
+ (x.shape[-1] - frame_length) // frame_shift + 1,
+ frame_length,
+ )
+ strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1])
+ result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
+ return result
+
+
+def detect_non_silence(
+ x: np.ndarray,
+ threshold: float = 0.01,
+ frame_length: int = 1024,
+ frame_shift: int = 512,
+ window: str = "boxcar",
+) -> np.ndarray:
+ """Power based voice activity detection.
+
+ Args:
+ x: (Channel, Time)
+ >>> x = np.random.randn(1000)
+ >>> detect = detect_non_silence(x)
+ >>> assert x.shape == detect.shape
+ >>> assert detect.dtype == np.bool
+ """
+ if x.shape[-1] < frame_length:
+        return np.full(x.shape, fill_value=True, dtype=bool)
+
+ if x.dtype.kind == "i":
+ x = x.astype(np.float64)
+ # framed_w: (C, T, F)
+ framed_w = framing(
+ x,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ centered=False,
+ padded=True,
+ )
+ framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
+ # power: (C, T)
+ power = (framed_w ** 2).mean(axis=-1)
+ # mean_power: (C, 1)
+ mean_power = np.mean(power, axis=-1, keepdims=True)
+ if np.all(mean_power == 0):
+        return np.full(x.shape, fill_value=True, dtype=bool)
+ # detect_frames: (C, T)
+ detect_frames = power / mean_power > threshold
+ # detects: (C, T, F)
+ detects = np.broadcast_to(
+ detect_frames[..., None], detect_frames.shape + (frame_shift,)
+ )
+ # detects: (C, TF)
+ detects = detects.reshape(*detect_frames.shape[:-1], -1)
+ # detects: (C, TF)
+ return np.pad(
+ detects,
+ [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
+ mode="edge",
+ )
+
+
+class CommonPreprocessor(AbsPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: str = "text",
+ split_with_space: bool = False,
+ seg_dict_file: str = None,
+ ):
+ super().__init__(train)
+ self.train = train
+ self.speech_name = speech_name
+ self.text_name = text_name
+ self.speech_volume_normalize = speech_volume_normalize
+ self.rir_apply_prob = rir_apply_prob
+ self.noise_apply_prob = noise_apply_prob
+ self.split_with_space = split_with_space
+ self.seg_dict = None
+ if seg_dict_file is not None:
+ self.seg_dict = {}
+ with open(seg_dict_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ s = line.strip().split()
+ key = s[0]
+ value = s[1:]
+ self.seg_dict[key] = " ".join(value)
+
+ if token_type is not None:
+ if token_list is None:
+ raise ValueError("token_list is required if token_type is not None")
+ self.text_cleaner = TextCleaner(text_cleaner)
+
+ self.tokenizer = build_tokenizer(
+ token_type=token_type,
+ bpemodel=bpemodel,
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ g2p_type=g2p_type,
+ )
+ self.token_id_converter = TokenIDConverter(
+ token_list=token_list,
+ unk_symbol=unk_symbol,
+ )
+ else:
+ self.text_cleaner = None
+ self.tokenizer = None
+ self.token_id_converter = None
+
+ if train and rir_scp is not None:
+ self.rirs = []
+ with open(rir_scp, "r", encoding="utf-8") as f:
+ for line in f:
+ sps = line.strip().split(None, 1)
+ if len(sps) == 1:
+ self.rirs.append(sps[0])
+ else:
+ self.rirs.append(sps[1])
+ else:
+ self.rirs = None
+
+ if train and noise_scp is not None:
+ self.noises = []
+ with open(noise_scp, "r", encoding="utf-8") as f:
+ for line in f:
+ sps = line.strip().split(None, 1)
+ if len(sps) == 1:
+ self.noises.append(sps[0])
+ else:
+ self.noises.append(sps[1])
+ sps = noise_db_range.split("_")
+ if len(sps) == 1:
+                self.noise_db_low = self.noise_db_high = float(sps[0])
+ elif len(sps) == 2:
+ self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
+ else:
+ raise ValueError(
+ "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
+ )
+ else:
+ self.noises = None
+
+ def _speech_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, Union[str, np.ndarray]]:
+ assert check_argument_types()
+ if self.speech_name in data:
+ if self.train and (self.rirs is not None or self.noises is not None):
+ speech = data[self.speech_name]
+ nsamples = len(speech)
+
+ # speech: (Nmic, Time)
+ if speech.ndim == 1:
+ speech = speech[None, :]
+ else:
+ speech = speech.T
+ # Calc power on non shlence region
+ power = (speech[detect_non_silence(speech)] ** 2).mean()
+
+ # 1. Convolve RIR
+ if self.rirs is not None and self.rir_apply_prob >= np.random.random():
+ rir_path = np.random.choice(self.rirs)
+ if rir_path is not None:
+ rir, _ = soundfile.read(
+ rir_path, dtype=np.float64, always_2d=True
+ )
+
+ # rir: (Nmic, Time)
+ rir = rir.T
+
+ # speech: (Nmic, Time)
+ # Note that this operation doesn't change the signal length
+ speech = scipy.signal.convolve(speech, rir, mode="full")[
+ :, : speech.shape[1]
+ ]
+ # Reverse mean power to the original power
+ power2 = (speech[detect_non_silence(speech)] ** 2).mean()
+ speech = np.sqrt(power / max(power2, 1e-10)) * speech
+
+ # 2. Add Noise
+ if (
+ self.noises is not None
+ and self.noise_apply_prob >= np.random.random()
+ ):
+ noise_path = np.random.choice(self.noises)
+ if noise_path is not None:
+ noise_db = np.random.uniform(
+ self.noise_db_low, self.noise_db_high
+ )
+ with soundfile.SoundFile(noise_path) as f:
+ if f.frames == nsamples:
+ noise = f.read(dtype=np.float64, always_2d=True)
+ elif f.frames < nsamples:
+ offset = np.random.randint(0, nsamples - f.frames)
+ # noise: (Time, Nmic)
+ noise = f.read(dtype=np.float64, always_2d=True)
+ # Repeat noise
+ noise = np.pad(
+ noise,
+ [(offset, nsamples - f.frames - offset), (0, 0)],
+ mode="wrap",
+ )
+ else:
+ offset = np.random.randint(0, f.frames - nsamples)
+ f.seek(offset)
+ # noise: (Time, Nmic)
+ noise = f.read(
+ nsamples, dtype=np.float64, always_2d=True
+ )
+ if len(noise) != nsamples:
+ raise RuntimeError(f"Something wrong: {noise_path}")
+ # noise: (Nmic, Time)
+ noise = noise.T
+
+ noise_power = (noise ** 2).mean()
+ scale = (
+ 10 ** (-noise_db / 20)
+ * np.sqrt(power)
+ / np.sqrt(max(noise_power, 1e-10))
+ )
+ speech = speech + scale * noise
+
+ speech = speech.T
+ ma = np.max(np.abs(speech))
+ if ma > 1.0:
+ speech /= ma
+ data[self.speech_name] = speech
+
+ if self.speech_volume_normalize is not None:
+ speech = data[self.speech_name]
+ ma = np.max(np.abs(speech))
+ data[self.speech_name] = speech * self.speech_volume_normalize / ma
+ assert check_return_type(data)
+ return data
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ if self.text_name in data and self.tokenizer is not None:
+ text = data[self.text_name]
+ text = self.text_cleaner(text)
+ if self.split_with_space:
+ tokens = text.strip().split(" ")
+ if self.seg_dict is not None:
+ tokens = forward_segment("".join(tokens), self.seg_dict)
+ tokens = seg_tokenize(tokens, self.seg_dict)
+ else:
+ tokens = self.tokenizer.text2tokens(text)
+ text_ints = self.token_id_converter.tokens2ids(tokens)
+ data[self.text_name] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
+ def __call__(
+ self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ assert check_argument_types()
+
+ data = self._speech_process(data)
+ data = self._text_process(data)
+ return data
+
+
# FIXME
class LMPreprocessor(CommonPreprocessor):
    """Preprocessor for language-model training.

    Identical to :class:`CommonPreprocessor` except that, when splitting on
    spaces with a segmentation dictionary, tokens are segmented without the
    pattern-based forward-segmentation step.
    """

    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_with_space: bool = False,
        seg_dict_file: str = None,
    ):
        # Positional forwarding: argument order must match CommonPreprocessor.
        super().__init__(
            train, token_type, token_list, bpemodel, text_cleaner,
            g2p_type, unk_symbol, space_symbol, non_linguistic_symbols,
            delimiter, rir_scp, rir_apply_prob, noise_scp, noise_apply_prob,
            noise_db_range, speech_volume_normalize, speech_name, text_name,
            split_with_space, seg_dict_file,
        )

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Tokenize text; with a seg_dict, segment without pattern matching."""
        if self.tokenizer is not None and self.text_name in data:
            cleaned = self.text_cleaner(data[self.text_name])
            if not self.split_with_space:
                token_seq = self.tokenizer.text2tokens(cleaned)
            else:
                token_seq = cleaned.strip().split(" ")
                if self.seg_dict is not None:
                    token_seq = seg_tokenize_wo_pattern(token_seq, self.seg_dict)
            data[self.text_name] = np.array(
                self.token_id_converter.tokens2ids(token_seq), dtype=np.int64
            )
        assert check_return_type(data)
        return data
+
+
class CommonPreprocessor_multi(AbsPreprocessor):
    """Preprocessor that tokenizes several text fields of one sample.

    Every name listed in ``text_name`` that is present in the data dict is
    cleaned, tokenized and converted into an int64 token-id array.  The
    speech entry is currently passed through unchanged.
    """

    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        speech_name: str = "speech",
        text_name: List[str] = None,
    ):
        super().__init__(train)
        self.train = train
        self.speech_name = speech_name
        # None sentinel instead of a mutable default list shared across calls.
        self.text_name = ["text"] if text_name is None else text_name

        if token_type is not None:
            if token_list is None:
                raise ValueError("token_list is required if token_type is not None")
            self.text_cleaner = TextCleaner(text_cleaner)

            self.tokenizer = build_tokenizer(
                token_type=token_type,
                bpemodel=bpemodel,
                delimiter=delimiter,
                space_symbol=space_symbol,
                non_linguistic_symbols=non_linguistic_symbols,
                g2p_type=g2p_type,
            )
            self.token_id_converter = TokenIDConverter(
                token_list=token_list,
                unk_symbol=unk_symbol,
            )
        else:
            self.text_cleaner = None
            self.tokenizer = None
            self.token_id_converter = None

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Tokenize every configured text field found in *data*, in place."""
        for text_n in self.text_name:
            if text_n in data and self.tokenizer is not None:
                text = data[text_n]
                text = self.text_cleaner(text)
                tokens = self.tokenizer.text2tokens(text)
                text_ints = self.token_id_converter.tokens2ids(tokens)
                data[text_n] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data

    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Process one sample; *uid* is unused but kept for the interface."""
        assert check_argument_types()

        if self.speech_name in data:
            # Nothing now: candidates:
            # - STFT
            # - Fbank
            # - CMVN
            # - Data augmentation
            pass

        data = self._text_process(data)
        return data
+
+
class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
    """Preprocessor holding one tokenizer per text field.

    ``token_type``, ``token_list``, ``bpemodel`` and ``text_name`` are
    parallel lists of equal length: text field ``i`` is processed by
    tokenizer ``i``.  Fields whose token type is None are left untouched.
    """

    def __init__(
        self,
        train: bool,
        token_type: List[str] = None,
        token_list: List[Union[Path, str, Iterable[str]]] = None,
        bpemodel: List[Union[Path, str, Iterable[str]]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: List[str] = None,
    ):
        # None sentinels instead of mutable default lists shared across calls.
        token_type = [None] if token_type is None else token_type
        token_list = [None] if token_list is None else token_list
        bpemodel = [None] if bpemodel is None else bpemodel
        text_name = ["text"] if text_name is None else text_name
        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
        super().__init__(
            train=train,
            token_type=token_type[0],
            token_list=token_list[0],
            bpemodel=bpemodel[0],
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name[0],
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
        )

        assert (
            len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
        self.num_tokenizer = len(token_type)
        self.tokenizer = []
        self.token_id_converter = []

        for i in range(self.num_tokenizer):
            if token_type[i] is not None:
                if token_list[i] is None:
                    raise ValueError("token_list is required if token_type is not None")

                self.tokenizer.append(
                    build_tokenizer(
                        token_type=token_type[i],
                        bpemodel=bpemodel[i],
                        delimiter=delimiter,
                        space_symbol=space_symbol,
                        non_linguistic_symbols=non_linguistic_symbols,
                        g2p_type=g2p_type,
                    )
                )
                self.token_id_converter.append(
                    TokenIDConverter(
                        token_list=token_list[i],
                        unk_symbol=unk_symbol,
                    )
                )
            else:
                # No tokenizer configured for this field: leave it as-is.
                self.tokenizer.append(None)
                self.token_id_converter.append(None)

        self.text_cleaner = TextCleaner(text_cleaner)
        self.text_name = text_name  # override the text_name from CommonPreprocessor

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Tokenize each configured text field with its matching tokenizer."""
        for i in range(self.num_tokenizer):
            text_name = self.text_name[i]
            if text_name in data and self.tokenizer[i] is not None:
                text = data[text_name]
                text = self.text_cleaner(text)
                tokens = self.tokenizer[i].text2tokens(text)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        assert check_return_type(data)
        return data
+
+
class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
    """Preprocessor for code-mixed (e.g. Chinese/English) text.

    The raw text is first split into "words" — runs of ASCII characters stay
    together, every multi-byte character becomes its own word — then handed
    to the word tokenizer.  The split form is also stored under
    ``split_text_name`` so callers can retrieve it afterwards.
    """

    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_text_name: str = "split_text",
        split_with_space: bool = False,
        seg_dict_file: str = None,
    ):
        super().__init__(
            train=train,
            token_type="word",  # Force to use word.
            token_list=token_list,
            bpemodel=bpemodel,
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
            speech_name=speech_name,
            text_name=text_name,
            split_with_space=split_with_space,
            seg_dict_file=seg_dict_file,
        )
        # The data field name for split text.
        self.split_text_name = split_text_name

    @classmethod
    def split_words(cls, text: str):
        """Split *text* so ASCII runs stay joined and each multi-byte
        character is emitted as its own word."""
        result = []
        for chunk in text.split():
            # chunk contains no whitespace.
            ascii_buf = []
            for ch in chunk:
                if len(ch.encode()) == 1:
                    # ASCII character: extend the current English word.
                    ascii_buf.append(ch)
                else:
                    # Chinese character: flush any pending ASCII word first,
                    # then emit the character on its own.
                    if ascii_buf:
                        result.append("".join(ascii_buf))
                        ascii_buf = []
                    result.append(ch)
            if ascii_buf:
                result.append("".join(ascii_buf))
        return result

    def __call__(
        self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
    ) -> Dict[str, Union[list, np.ndarray]]:
        """Split code-mixed text, run speech/text processing, and stash the
        split form under ``self.split_text_name``."""
        assert check_argument_types()
        raw = data[self.text_name]
        split_text = self.split_words(raw) if isinstance(raw, str) else raw
        data[self.text_name] = " ".join(split_text)
        data = self._speech_process(data)
        data = self._text_process(data)
        data[self.split_text_name] = split_text
        return data

    def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
        """Remove and return the stored split-text entry from *data*."""
        return data.pop(self.split_text_name)
+
+
class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
    """Preprocessor for punctuation training with one tokenizer per text field.

    Like :class:`MutliTokenizerCommonPreprocessor`, but a trailing
    ``"vad:<index>"`` token in a text field is stripped and stored under
    ``vad_name`` as an int64 array (``-1`` when the index is empty).
    """

    def __init__(
        self,
        train: bool,
        token_type: List[str] = None,
        token_list: List[Union[Path, str, Iterable[str]]] = None,
        bpemodel: List[Union[Path, str, Iterable[str]]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: List[str] = None,
        vad_name: str = "vad_indexes",
    ):
        # None sentinels instead of mutable default lists shared across calls.
        token_type = [None] if token_type is None else token_type
        token_list = [None] if token_list is None else token_list
        bpemodel = [None] if bpemodel is None else bpemodel
        text_name = ["text"] if text_name is None else text_name
        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
        super().__init__(
            train=train,
            token_type=token_type[0],
            token_list=token_list[0],
            bpemodel=bpemodel[0],
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name[0],
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
        )

        assert (
            len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
        self.num_tokenizer = len(token_type)
        self.tokenizer = []
        self.token_id_converter = []

        for i in range(self.num_tokenizer):
            if token_type[i] is not None:
                if token_list[i] is None:
                    raise ValueError("token_list is required if token_type is not None")

                self.tokenizer.append(
                    build_tokenizer(
                        token_type=token_type[i],
                        bpemodel=bpemodel[i],
                        delimiter=delimiter,
                        space_symbol=space_symbol,
                        non_linguistic_symbols=non_linguistic_symbols,
                        g2p_type=g2p_type,
                    )
                )
                self.token_id_converter.append(
                    TokenIDConverter(
                        token_list=token_list[i],
                        unk_symbol=unk_symbol,
                    )
                )
            else:
                # No tokenizer configured for this field: leave it as-is.
                self.tokenizer.append(None)
                self.token_id_converter.append(None)

        self.text_cleaner = TextCleaner(text_cleaner)
        self.text_name = text_name  # override the text_name from CommonPreprocessor
        self.vad_name = vad_name

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Tokenize each text field; extract a trailing "vad:<index>" token
        into ``data[self.vad_name]`` when present."""
        for i in range(self.num_tokenizer):
            text_name = self.text_name[i]
            if text_name in data and self.tokenizer[i] is not None:
                text = data[text_name]
                text = self.text_cleaner(text)
                tokens = self.tokenizer[i].text2tokens(text)
                # Guard against an empty token list before peeking at [-1].
                if tokens and "vad:" in tokens[-1]:
                    vad = tokens[-1][4:]
                    tokens = tokens[:-1]
                    if len(vad) == 0:
                        # Empty index means "no split point"; encode as -1.
                        vad = -1
                    else:
                        vad = int(vad)
                    data[self.vad_name] = np.array([vad], dtype=np.int64)
                text_ints = self.token_id_converter[i].tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        # BUG FIX: the original fell off the end and implicitly returned None,
        # unlike the parallel MutliTokenizerCommonPreprocessor._text_process.
        assert check_return_type(data)
        return data
+
+
def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Split *words* into consecutive chunks of at most *word_limit* items.

    A list no longer than the limit is returned unchanged inside a
    single-element list.
    """
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    return [
        words[start:start + word_limit]
        for start in range(0, len(words), word_limit)
    ]
+
+
def build_preprocess(args):
    """Build the preprocess function for ``args.task_name``.

    Only the "asr" task is recognized (currently a no-op returning None);
    any other task name raises ValueError.
    """
    if args.task_name != "asr":
        raise ValueError(f"Not supported task={args.task_name}")
--
Gitblit v1.9.1