zhifu gao
2024-05-14 2f27b165559cd53afab52047309ebe4ac838ebb8
funasr/datasets/sense_voice_datasets/datasets.py
@@ -2,7 +2,7 @@
import torch
import random
import traceback
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
@@ -73,15 +73,17 @@
            if idx == 0:
                index_cur = index
            else:
                if index <= self.retry:
                    index_cur = index + idx
                else:
                    index_cur = torch.randint(0, index, ()).item()
                index_cur = torch.randint(0, len(self.index_ds), ()).item()
            item = self.index_ds[index_cur]
            source = item["source"]
            data_src = load_audio_text_image_video(source, fs=self.fs)
            try:
                data_src = load_audio_text_image_video(source, fs=self.fs)
            except Exception as e:
                logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}")
                continue
            if self.preprocessor_speech:
                data_src = self.preprocessor_speech(data_src, fs=self.fs)
            speech, speech_lengths = extract_fbank(
@@ -110,7 +112,7 @@
            eos = self.tokenizer.encode(self.eos, allowed_special="all")  # [eos]
            ids = prompt_ids + target_ids + eos
            ids = prompt_ids + target_ids + eos  # [sos, task, lid, text, eos]
            ids_lengths = len(ids)
            text = torch.tensor(ids, dtype=torch.int64)
@@ -186,7 +188,7 @@
                )
        if self.batch_type != "example":
            for i in range(3):
            for i in range(10):
                outputs = self._filter_badcase(outputs, i=i)
        return outputs