# SenseVoice patches to whisper's decoding.py; the FIX(funasr) / FIX(gzf)
# comments mark deviations from upstream openai/whisper.
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from .tokenizer import Tokenizer, get_tokenizer

# FIX(funasr): sense voice: detect_language gains initial_prompt and x parameters
@torch.no_grad()
def detect_language(
    model: "Whisper",
    mel: Tensor,
    tokenizer: Tokenizer = None,
    initial_prompt: Optional[str] = None,
    x: Optional[Tensor] = None,
) -> Tuple[Tensor, List[dict]]:
| | | """ |
| | | Detect the spoken language in the audio, and return them as list of strings, along with the ids |
| | |
    mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    # FIX(funasr): sense voice: the encoded length is not fixed to
    # n_audio_ctx here, so only the feature dimension is checked
    # if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
    if mel.shape[-1] != model.dims.n_audio_state:
        mel = model.encoder(mel)

    # forward pass using the tokens of the initial prompt (upstream used a
    # single startoftranscript token)
    n_audio = mel.shape[0]
    # FIX(funasr): sense voice
    # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    if x is None:
        x = torch.tensor(
            [tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio
        ).to(mel.device)  # [n_audio, n_prompt_tokens]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens

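# A minimal usage sketch, not part of the patch: how the modified
# detect_language might be driven. get_tokenizer is the upstream
# openai/whisper helper; the prompt string is a hypothetical placeholder.
def _example_detect_language(model: "Whisper", mel: Tensor, prompt: str) -> List[str]:
    tokenizer = get_tokenizer(multilingual=True)
    # with x=None, the patched function encodes `prompt` itself
    _, lang_probs = detect_language(model, mel, tokenizer, initial_prompt=prompt)
    return [max(probs, key=probs.get) for probs in lang_probs]
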
@dataclass(frozen=True)
class DecodingOptions:
    ...

    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation

    # FIX(funasr): sense voice
    initial_prompt: Optional[str] = None

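# A minimal sketch, assuming the remaining DecodingOptions fields match
# upstream openai/whisper; the prompt text below is a placeholder, not a
# value taken from this patch.
def _example_options() -> DecodingOptions:
    return DecodingOptions(
        task="transcribe",
        language=None,  # leave unset so language detection runs
        fp16=False,     # e.g. when decoding on CPU
        initial_prompt="<|startoftranscript|>",  # hypothetical prompt text
    )
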
@dataclass(frozen=True)
class DecodingResult:
    ...

            tokens = (
                [self.tokenizer.sot_prev]
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )
        # FIX(gzf): sense voice: replace the initial sequence with the
        # encoded initial prompt
        if initial_prompt := self.options.initial_prompt:
            tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
            if self.options.language is None:
                tokens += [0]  # placeholder; overwritten by _detect_language

        return tuple(tokens)

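# A sketch of what the override above yields: the whole initial sequence is
# replaced by the encoded prompt, plus one placeholder id (0) that
# _detect_language later overwrites with the language token. The tokenizer
# is assumed to be whisper's tiktoken-backed Tokenizer.
def _example_initial_tokens(
    tokenizer: Tokenizer, initial_prompt: str, language: Optional[str]
) -> Tuple[int, ...]:
    tokens = tokenizer.encode(initial_prompt, allowed_special="all")
    if language is None:
        tokens += [0]  # placeholder slot, filled in by _detect_language
    return tuple(tokens)
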

    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            # FIX(funasr): sense voice: pass the initial-prompt tokens in via
            # x instead of letting detect_language build a bare SOT input
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer, x=tokens
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            # FIX(funasr): sense voice
            # if self.options.language is None:
            #     tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
            if self.options.language is None:
                # re-encode the detected languages as "<|lang|>" tags and write
                # them into the placeholder slot reserved by _get_initial_tokens
                languages = "".join([f"<|{language}|>" for language in languages])
                n_audio = audio_features.shape[0]
                lang_tokens = torch.tensor(
                    [self.tokenizer.encode(languages, allowed_special="all")] * n_audio
                ).to(audio_features.device)  # [n_audio, 1]
                tokens[:, -1:] = lang_tokens[:, :]
                languages = [languages]

        return languages, lang_probs

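# Sketch of the tag construction above with assumed example probabilities:
# each detected language becomes a "<|lang|>" tag; the joined string is
# re-encoded and written into the placeholder slot (tokens[:, -1:]).
def _example_language_tag(lang_probs: List[dict]) -> str:
    # e.g. [{"zh": 0.92, "en": 0.08}] -> "<|zh|>"
    return "".join(f"<|{max(probs, key=probs.get)}|>" for probs in lang_probs)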