From a9e857e45250b16af60d5fe3efcd06e685f6506a Mon Sep 17 00:00:00 2001
From: lzr265946 <lzr265946@alibaba-inc.com>
Date: 星期六, 03 十二月 2022 16:39:38 +0800
Subject: [PATCH] update funasr 0.1.3

---
 funasr/tasks/abs_task.py |   61 ++++++++++++++++++++++++++++++
 1 files changed, 60 insertions(+), 1 deletions(-)

diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 5ea78c3..d716423 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -38,6 +38,7 @@
 from funasr.datasets.dataset import DATA_TYPES
 from funasr.datasets.dataset import ESPnetDataset
 from funasr.datasets.iterable_dataset import IterableESPnetDataset
+from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope
 from funasr.iterators.abs_iter_factory import AbsIterFactory
 from funasr.iterators.chunk_iter_factory import ChunkIterFactory
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
@@ -1026,7 +1027,7 @@
     @classmethod
     def check_task_requirements(
             cls,
-            dataset: Union[AbsDataset, IterableESPnetDataset],
+            dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope],
             allow_variable_data_keys: bool,
             train: bool,
             inference: bool = False,
@@ -1748,6 +1749,64 @@
             **kwargs,
         )
 
+    @classmethod
+    def build_streaming_iterator_modelscope(
+            cls,
+            data_path_and_name_and_type,
+            preprocess_fn,
+            collate_fn,
+            key_file: str = None,
+            batch_size: int = 1,
+            dtype: str = np.float32,
+            num_workers: int = 1,
+            allow_variable_data_keys: bool = False,
+            ngpu: int = 0,
+            inference: bool = False,
+            sample_rate: Union[dict, int] = 16000
+    ) -> DataLoader:
+        """Build DataLoader using iterable dataset"""
+        assert check_argument_types()
+        # For backward compatibility for pytorch DataLoader
+        if collate_fn is not None:
+            kwargs = dict(collate_fn=collate_fn)
+        else:
+            kwargs = {}
+
+        audio_data = data_path_and_name_and_type[0]
+        if isinstance(audio_data, bytes):
+            dataset = IterableESPnetBytesModelScope(
+                data_path_and_name_and_type,
+                float_dtype=dtype,
+                preprocess=preprocess_fn,
+                key_file=key_file,
+                sample_rate=sample_rate
+            )
+        else:
+            dataset = IterableESPnetDatasetModelScope(
+                data_path_and_name_and_type,
+                float_dtype=dtype,
+                preprocess=preprocess_fn,
+                key_file=key_file,
+                sample_rate=sample_rate
+            )
+
+        if dataset.apply_utt2category:
+            kwargs.update(batch_size=1)
+        else:
+            kwargs.update(batch_size=batch_size)
+
+        cls.check_task_requirements(dataset,
+                                    allow_variable_data_keys,
+                                    train=False,
+                                    inference=inference)
+
+        return DataLoader(
+            dataset=dataset,
+            pin_memory=ngpu > 0,
+            num_workers=num_workers,
+            **kwargs,
+        )
+
     # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
     @classmethod
     def build_model_from_file(

--
Gitblit v1.9.1