zhifu gao
2024-04-10 112c8e6eb7c7b5dd150787848a5c1d6111688ed7
funasr/models/sense_voice/whisper_lib/decoding.py
@@ -58,18 +58,20 @@
    # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    if x is None:
        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]
    logits = model.logits(x[:,:-1], mel)[:, -1]
    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    mask[tokenizer.no_speech] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
            for j, c in zip(list(tokenizer.all_language_tokens) + [tokenizer.no_speech], list(tokenizer.all_language_codes) + ["nospeech"])
        }
        for i in range(n_audio)
    ]