| | |
| | | import logging |
| | | |
| | | import re |
| | | import torch |
| | | import random |
| | | |
| | | import traceback |
| | | from funasr.register import tables |
| | | from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video |
| | | |
| | |
| | | self.prompt_ids_len = 0 |
| | | self.retry = kwargs.get("retry", 5) |
| | | |
| | | self.permute = False |
| | | from funasr.frontends.whisper_frontend import WhisperFrontend |
| | | |
| | | if isinstance(self.frontend, WhisperFrontend): |
| | | self.permute = True |
| | | |
def get_source_len(self, index):
    """Return the source (input) length of the sample at *index*.

    The lookup and the length computation are both delegated to the
    index dataset: fetch the raw item, then ask ``index_ds`` how long
    its source is.
    """
    dataset = self.index_ds
    return dataset.get_source_len(dataset[index])
| | |
| | | return len(self.index_ds) |
| | | |
| | | def __getitem__(self, index): |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | |
| | | output = None |
| | | for idx in range(self.retry): |
| | | if idx == 0: |
| | | index_cur = index |
| | | else: |
| | | if index <= self.retry: |
| | | index_cur = index + idx |
| | | else: |
| | | index_cur = torch.randint(0, index, ()).item() |
| | | index_cur = torch.randint(0, len(self.index_ds), ()).item() |
| | | |
| | | item = self.index_ds[index_cur] |
| | | |
| | | source = item["source"] |
| | | data_src = load_audio_text_image_video(source, fs=self.fs) |
| | | try: |
| | | data_src = load_audio_text_image_video(source, fs=self.fs) |
| | | except Exception as e: |
| | | logging.error(f"Loading wav failed! {str(e)}, {traceback.format_exc()}") |
| | | continue |
| | | |
| | | if self.preprocessor_speech: |
| | | data_src = self.preprocessor_speech(data_src, fs=self.fs) |
| | | speech, speech_lengths = extract_fbank( |
| | |
| | | |
| | | if speech_lengths > self.batch_size: |
| | | continue |
| | | speech = speech.permute(0, 2, 1) |
| | | if self.permute: |
| | | speech = speech.permute(0, 2, 1) |
| | | target = item["target"] |
| | | if self.preprocessor_text: |
| | | target = self.preprocessor_text(target) |
| | |
| | | task = item.get("prompt", "<|ASR|>") |
| | | text_language = item.get("text_language", "<|zh|>") |
| | | |
| | | prompt = f"{self.sos}{task}{text_language}" |
| | | prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") |
| | | if isinstance(self.sos, str): |
| | | prompt = f"{self.sos}{task}{text_language}" |
| | | prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") |
| | | else: |
| | | prompt = f"{task}{text_language}" |
| | | prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") |
| | | prompt_ids = [self.sos] + prompt_ids |
| | | |
| | | prompt_ids_len = len(prompt_ids) - 1 # [sos, task] |
| | | self.prompt_ids_len = prompt_ids_len |
| | | |
| | |
| | | if target_ids_len > 200: |
| | | continue |
| | | |
| | | eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] |
| | | if isinstance(self.eos, str): |
| | | eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] |
| | | else: |
| | | eos = [self.eos] |
| | | |
| | | ids = prompt_ids + target_ids + eos |
| | | ids = prompt_ids + target_ids + eos # [sos, task, lid, text, eos] |
| | | ids_lengths = len(ids) |
| | | |
| | | text = torch.tensor(ids, dtype=torch.int64) |