游雁
2024-01-15 2a0b2c795b161a0bd56e026c53eb605fea9e142c
funasr/datasets/audio_datasets/datasets.py
@@ -8,11 +8,14 @@
import time
import logging
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
@tables.register("dataset_classes", "AudioDataset")
class AudioDataset(torch.utils.data.Dataset):
   """
   AudioDataset
   """
   def __init__(self,
                path,
                index_ds: str = None,
@@ -22,16 +25,16 @@
                float_pad_value: float = 0.0,
                 **kwargs):
      super().__init__()
      index_ds_class = tables.index_ds_classes.get(index_ds.lower())
      index_ds_class = tables.index_ds_classes.get(index_ds)
      self.index_ds = index_ds_class(path)
      preprocessor_speech = kwargs.get("preprocessor_speech", None)
      if preprocessor_speech:
         preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
         preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
         preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
      self.preprocessor_speech = preprocessor_speech
      preprocessor_text = kwargs.get("preprocessor_text", None)
      if preprocessor_text:
         preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text.lower())
         preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
         preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
      self.preprocessor_text = preprocessor_text