游雁
2023-12-21 c8bae0ec85eee25d66de6b1e4502eff74d750b24
funasr/datasets/audio_datasets/datasets.py
@@ -24,6 +24,17 @@
      super().__init__()
      index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower())
      self.index_ds = index_ds_class(path)
      preprocessor_speech = kwargs.get("preprocessor_speech", None)
      if preprocessor_speech:
         preprocessor_speech_class = registry_tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
         preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
      self.preprocessor_speech = preprocessor_speech
      preprocessor_text = kwargs.get("preprocessor_text", None)
      if preprocessor_text:
         preprocessor_text_class = registry_tables.preprocessor_text_classes.get(preprocessor_text.lower())
         preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
      self.preprocessor_text = preprocessor_text
      self.frontend = frontend
      self.fs = 16000 if frontend is None else frontend.fs
      self.data_type = "sound"
@@ -49,8 +60,13 @@
      # pdb.set_trace()
      source = item["source"]
      data_src = load_audio(source, fs=self.fs)
      if self.preprocessor_speech:
         data_src = self.preprocessor_speech(data_src)
      speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
      target = item["target"]
      if self.preprocessor_text:
         target = self.preprocessor_text(target)
      ids = self.tokenizer.encode(target)
      ids_lengths = len(ids)
      text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)