| | |
| | | from funasr.datasets.dataset import DATA_TYPES |
| | | from funasr.datasets.dataset import ESPnetDataset |
| | | from funasr.datasets.iterable_dataset import IterableESPnetDataset |
| | | from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope |
| | | from funasr.iterators.abs_iter_factory import AbsIterFactory |
| | | from funasr.iterators.chunk_iter_factory import ChunkIterFactory |
| | | from funasr.iterators.multiple_iter_factory import MultipleIterFactory |
| | |
| | | @classmethod |
| | | def check_task_requirements( |
| | | cls, |
| | | dataset: Union[AbsDataset, IterableESPnetDataset], |
| | | dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope], |
| | | allow_variable_data_keys: bool, |
| | | train: bool, |
| | | inference: bool = False, |
| | |
| | | **kwargs, |
| | | ) |
| | | |
@classmethod
def build_streaming_iterator_modelscope(
    cls,
    data_path_and_name_and_type,
    preprocess_fn,
    collate_fn,
    key_file: Union[str, None] = None,
    batch_size: int = 1,
    # NOTE(review): default is a NumPy scalar type, not a str — annotation
    # fixed to match; downstream datasets receive it as ``float_dtype``.
    dtype: Union[str, type] = np.float32,
    num_workers: int = 1,
    allow_variable_data_keys: bool = False,
    ngpu: int = 0,
    inference: bool = False,
    sample_rate: Union[dict, int] = 16000,
) -> DataLoader:
    """Build a DataLoader backed by an iterable ModelScope dataset.

    Args:
        data_path_and_name_and_type: Either path/name/type specs, or a
            sequence whose first element is raw ``bytes`` (in-memory audio),
            which selects the bytes-oriented dataset instead.
        preprocess_fn: Per-example preprocessing callable passed to the dataset.
        collate_fn: Optional batch collation callable for the DataLoader.
        key_file: Optional utterance-key file forwarded to the dataset.
        batch_size: Batch size; forced to 1 when the dataset applies
            utt2category grouping.
        dtype: Float dtype forwarded to the dataset as ``float_dtype``.
        num_workers: DataLoader worker processes.
        allow_variable_data_keys: Forwarded to the task requirement check.
        ngpu: When > 0, enables pinned-memory transfer in the DataLoader.
        inference: Forwarded to the task requirement check.
        sample_rate: Expected audio sample rate (or per-key dict of rates).

    Returns:
        A ``torch.utils.data.DataLoader`` over the constructed dataset.
    """
    assert check_argument_types()
    # For backward compatibility for pytorch DataLoader: only pass
    # collate_fn through when the caller supplied one.
    kwargs = dict(collate_fn=collate_fn) if collate_fn is not None else {}

    # Raw bytes input selects the bytes dataset; anything else is treated
    # as path/name/type specs. Both constructors take identical keywords,
    # so pick the class once instead of duplicating the call.
    audio_data = data_path_and_name_and_type[0]
    dataset_cls = (
        IterableESPnetBytesModelScope
        if isinstance(audio_data, bytes)
        else IterableESPnetDatasetModelScope
    )
    dataset = dataset_cls(
        data_path_and_name_and_type,
        float_dtype=dtype,
        preprocess=preprocess_fn,
        key_file=key_file,
        sample_rate=sample_rate,
    )

    # utt2category grouping requires per-utterance batching.
    kwargs.update(batch_size=1 if dataset.apply_utt2category else batch_size)

    cls.check_task_requirements(
        dataset,
        allow_variable_data_keys,
        train=False,
        inference=inference,
    )

    return DataLoader(
        dataset=dataset,
        pin_memory=ngpu > 0,
        num_workers=num_workers,
        **kwargs,
    )
| | | |
| | | # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ |
| | | @classmethod |
| | | def build_model_from_file( |