lzr265946
2022-12-03 a9e857e45250b16af60d5fe3efcd06e685f6506a
funasr/tasks/abs_task.py
@@ -38,6 +38,7 @@
from funasr.datasets.dataset import DATA_TYPES
from funasr.datasets.dataset import ESPnetDataset
from funasr.datasets.iterable_dataset import IterableESPnetDataset
from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.iterators.chunk_iter_factory import ChunkIterFactory
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
@@ -1026,7 +1027,7 @@
    @classmethod
    def check_task_requirements(
            cls,
            dataset: Union[AbsDataset, IterableESPnetDataset],
            dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope],
            allow_variable_data_keys: bool,
            train: bool,
            inference: bool = False,
@@ -1748,6 +1749,64 @@
            **kwargs,
        )

    @classmethod
    def build_streaming_iterator_modelscope(
            cls,
            data_path_and_name_and_type,
            preprocess_fn,
            collate_fn,
            key_file: str = None,
            batch_size: int = 1,
            dtype: str = np.float32,
            num_workers: int = 1,
            allow_variable_data_keys: bool = False,
            ngpu: int = 0,
            inference: bool = False,
            sample_rate: Union[dict, int] = 16000
    ) -> DataLoader:
        """Build DataLoader using iterable dataset"""
        assert check_argument_types()
        # For backward compatibility with the PyTorch DataLoader
        if collate_fn is not None:
            kwargs = dict(collate_fn=collate_fn)
        else:
            kwargs = {}
        # Raw audio passed as bytes selects the bytes-based ModelScope dataset;
        # otherwise the path/scp-based ModelScope dataset is used.
        audio_data = data_path_and_name_and_type[0]
        if isinstance(audio_data, bytes):
            dataset = IterableESPnetBytesModelScope(
                data_path_and_name_and_type,
                float_dtype=dtype,
                preprocess=preprocess_fn,
                key_file=key_file,
                sample_rate=sample_rate
            )
        else:
            dataset = IterableESPnetDatasetModelScope(
                data_path_and_name_and_type,
                float_dtype=dtype,
                preprocess=preprocess_fn,
                key_file=key_file,
                sample_rate=sample_rate
            )
        # With utt2category applied, keep batch_size=1 so each batch stays
        # within a single category.
        if dataset.apply_utt2category:
            kwargs.update(batch_size=1)
        else:
            kwargs.update(batch_size=batch_size)
        cls.check_task_requirements(dataset,
                                    allow_variable_data_keys,
                                    train=False,
                                    inference=inference)
        return DataLoader(
            dataset=dataset,
            pin_memory=ngpu > 0,
            num_workers=num_workers,
            **kwargs,
        )
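
A minimal usage sketch of the new builder, for orientation only; the task class, the wav.scp path, and the keyword values below are assumptions for this sketch, not code from the patch:

    from funasr.tasks.asr import ASRTask  # assumed AbsTask subclass

    loader = ASRTask.build_streaming_iterator_modelscope(
        [("data/test/wav.scp", "speech", "sound")],  # (path, data name, data type)
        preprocess_fn=None,   # e.g. ASRTask.build_preprocess_fn(args, train=False)
        collate_fn=None,      # e.g. ASRTask.build_collate_fn(args, train=False)
        batch_size=1,
        dtype="float32",
        num_workers=1,
        ngpu=0,
        inference=True,
        sample_rate=16000,
    )
    for keys, batch in loader:  # batch layout depends on the collate_fn
        pass
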
    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(