| | |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | ): |
| | | |
| | | audio_mask = kwargs.get("audio_mask", None) |
| | | audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None |
| | |
| | | model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None) |
| | | preds = torch.argmax(model_outputs.logits, -1) |
| | | text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True) |
| | | text = text[0].split(': \n')[-1] |
| | | |
| | | text = text[0].split(': ')[-1] |
| | | text = text.strip() |
| | | |
| | | # preds = torch.argmax(model_outputs.logits, -1) |
| | | |
| | | ibest_writer = None |