| | |
| | | import time |
| | | import logging |
| | | |
| | | from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.register import tables |
| | | |
| | | @tables.register("dataset_classes", "AudioDataset") |
| | | class AudioDataset(torch.utils.data.Dataset): |
| | | """ |
| | | AudioDataset |
| | | """ |
| | | def __init__(self, |
| | | path, |
| | | index_ds: str = None, |
| | |
| | | float_pad_value: float = 0.0, |
| | | **kwargs): |
| | | super().__init__() |
| | | index_ds_class = tables.index_ds_classes.get(index_ds.lower()) |
| | | index_ds_class = tables.index_ds_classes.get(index_ds) |
| | | self.index_ds = index_ds_class(path) |
| | | preprocessor_speech = kwargs.get("preprocessor_speech", None) |
| | | if preprocessor_speech: |
| | | preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech.lower()) |
| | | preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech) |
| | | preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf")) |
| | | self.preprocessor_speech = preprocessor_speech |
| | | preprocessor_text = kwargs.get("preprocessor_text", None) |
| | | if preprocessor_text: |
| | | preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text.lower()) |
| | | preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text) |
| | | preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf")) |
| | | self.preprocessor_text = preprocessor_text |
| | | |