From a9e857e45250b16af60d5fe3efcd06e685f6506a Mon Sep 17 00:00:00 2001
From: lzr265946 <lzr265946@alibaba-inc.com>
Date: Sat, 03 Dec 2022 16:39:38 +0800
Subject: [PATCH] update funasr 0.1.3
---
funasr/tasks/abs_task.py | 61 ++++++++++++++++++++++++++++++
1 file changed, 60 insertions(+), 1 deletion(-)
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 5ea78c3..d716423 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -38,6 +38,7 @@
from funasr.datasets.dataset import DATA_TYPES
from funasr.datasets.dataset import ESPnetDataset
from funasr.datasets.iterable_dataset import IterableESPnetDataset
+from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
@@ -1026,7 +1027,7 @@
@classmethod
def check_task_requirements(
cls,
- dataset: Union[AbsDataset, IterableESPnetDataset],
+ dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope],
allow_variable_data_keys: bool,
train: bool,
inference: bool = False,
@@ -1748,6 +1749,64 @@
**kwargs,
)
+ @classmethod
+ def build_streaming_iterator_modelscope(
+ cls,
+ data_path_and_name_and_type,
+ preprocess_fn,
+ collate_fn,
+ key_file: Optional[str] = None,
+ batch_size: int = 1,
+ dtype: str = "float32",
+ num_workers: int = 1,
+ allow_variable_data_keys: bool = False,
+ ngpu: int = 0,
+ inference: bool = False,
+ sample_rate: Union[dict, int] = 16000
+ ) -> DataLoader:
+ """Build DataLoader using iterable dataset"""
+ assert check_argument_types()
+ # For backward compatibility for pytorch DataLoader
+ if collate_fn is not None:
+ kwargs = dict(collate_fn=collate_fn)
+ else:
+ kwargs = {}
+
+ audio_data = data_path_and_name_and_type[0]
+ if isinstance(audio_data, bytes):
+ dataset = IterableESPnetBytesModelScope(
+ data_path_and_name_and_type,
+ float_dtype=dtype,
+ preprocess=preprocess_fn,
+ key_file=key_file,
+ sample_rate=sample_rate
+ )
+ else:
+ dataset = IterableESPnetDatasetModelScope(
+ data_path_and_name_and_type,
+ float_dtype=dtype,
+ preprocess=preprocess_fn,
+ key_file=key_file,
+ sample_rate=sample_rate
+ )
+
+ if dataset.apply_utt2category:
+ kwargs.update(batch_size=1)
+ else:
+ kwargs.update(batch_size=batch_size)
+
+ cls.check_task_requirements(dataset,
+ allow_variable_data_keys,
+ train=False,
+ inference=inference)
+
+ return DataLoader(
+ dataset=dataset,
+ pin_memory=ngpu > 0,
+ num_workers=num_workers,
+ **kwargs,
+ )
+
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
@classmethod
def build_model_from_file(
--
Gitblit v1.9.1