| | |
| | | |
| | | import torch |
| | | import random |
| | | |
| | | import traceback |
| | | from funasr.register import tables |
| | | from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video |
| | | |
| | |
| | | if idx == 0: |
| | | index_cur = index |
| | | else: |
| | | if index <= self.retry: |
| | | index_cur = index + idx |
| | | else: |
| | | index_cur = torch.randint(0, index, ()).item() |
| | | index_cur = torch.randint(0, len(self.index_ds), ()).item() |
| | | |
| | | item = self.index_ds[index_cur] |
| | | |
| | | source = item["source"] |
| | | data_src = load_audio_text_image_video(source, fs=self.fs) |
| | | try: |
| | | data_src = load_audio_text_image_video(source, fs=self.fs) |
| | | except Exception as e: |
| | | logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}") |
| | | continue |
| | | |
| | | if self.preprocessor_speech: |
| | | data_src = self.preprocessor_speech(data_src, fs=self.fs) |
| | | speech, speech_lengths = extract_fbank( |
| | |
| | | |
| | | eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] |
| | | |
| | | ids = prompt_ids + target_ids + eos |
| | | ids = prompt_ids + target_ids + eos # [sos, task, lid, text, eos] |
| | | ids_lengths = len(ids) |
| | | |
| | | text = torch.tensor(ids, dtype=torch.int64) |
| | |
| | | ) |
| | | |
| | | if self.batch_type != "example": |
| | | for i in range(3): |
| | | for i in range(10): |
| | | outputs = self._filter_badcase(outputs, i=i) |
| | | |
| | | return outputs |