| | |
| | | import logging |
| | | |
| | | import re |
| | | import torch |
| | | import random |
| | | import traceback |
| | |
| | | self.prompt_ids_len = 0 |
| | | self.retry = kwargs.get("retry", 5) |
| | | |
| | | self.permute = False |
| | | from funasr.frontends.whisper_frontend import WhisperFrontend |
| | | |
| | | if isinstance(self.frontend, WhisperFrontend): |
| | | self.permute = True |
| | | |
def get_source_len(self, index):
    """Return the source length reported by the index dataset for one sample.

    The raw entry at ``index`` is fetched from ``self.index_ds`` and then
    handed back to ``self.index_ds.get_source_len`` to compute the length.

    Args:
        index: Position of the sample within ``self.index_ds``.

    Returns:
        Whatever ``self.index_ds.get_source_len`` returns for that entry.
    """
    entry = self.index_ds[index]
    source_len = self.index_ds.get_source_len(entry)
    return source_len
| | |
| | | return len(self.index_ds) |
| | | |
| | | def __getitem__(self, index): |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | |
| | | output = None |
| | | for idx in range(self.retry): |
| | |
| | | |
| | | if speech_lengths > self.batch_size: |
| | | continue |
| | | speech = speech.permute(0, 2, 1) |
| | | if self.permute: |
| | | speech = speech.permute(0, 2, 1) |
| | | target = item["target"] |
| | | if self.preprocessor_text: |
| | | target = self.preprocessor_text(target) |
| | |
| | | task = item.get("prompt", "<|ASR|>") |
| | | text_language = item.get("text_language", "<|zh|>") |
| | | |
| | | prompt = f"{self.sos}{task}{text_language}" |
| | | prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") |
| | | if isinstance(self.sos, str): |
| | | prompt = f"{self.sos}{task}{text_language}" |
| | | prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") |
| | | else: |
| | | prompt = f"{task}{text_language}" |
| | | prompt_ids = self.tokenizer.encode(prompt, allowed_special="all") |
| | | prompt_ids = [self.sos] + prompt_ids |
| | | |
| | | prompt_ids_len = len(prompt_ids) - 1 # [sos, task] |
| | | self.prompt_ids_len = prompt_ids_len |
| | | |
| | |
| | | if target_ids_len > 200: |
| | | continue |
| | | |
| | | eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] |
| | | if isinstance(self.eos, str): |
| | | eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos] |
| | | else: |
| | | eos = [self.eos] |
| | | |
| | | ids = prompt_ids + target_ids + eos |
| | | ids = prompt_ids + target_ids + eos # [sos, task, lid, text, eos] |
| | | ids_lengths = len(ids) |
| | | |
| | | text = torch.tensor(ids, dtype=torch.int64) |