From 9afcf0ea7d2877ddbbafec5b1a77f5cf025dab17 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 12 Jun 2024 17:17:03 +0800
Subject: [PATCH] decoding
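
Replace the thin SenseVoice wrapper in whisper_lib/model.py with a vendored
copy of the Whisper model definition (ModelDimensions, LayerNorm, Linear,
Conv1d, MultiHeadAttention, ResidualAttentionBlock, AudioEncoder, TextDecoder,
Whisper). Compared with upstream, MultiHeadAttention and ResidualAttentionBlock
accept optional boolean padding masks (is_pad_mask / is_pad_memory_mask)
alongside the additive causal mask, and AudioEncoder slices its positional
embedding so variable-length inputs are accepted.

Minimal usage sketch (the dimension values below are the usual base-model
sizes, assumed here only for illustration; mel and tokens are placeholder
tensors):

    import torch
    dims = ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=512,
                           n_audio_head=8, n_audio_layer=6, n_vocab=51865,
                           n_text_ctx=448, n_text_state=512, n_text_head=8,
                           n_text_layer=6)
    model = Whisper(dims)
    mel = torch.randn(1, 80, 3000)            # (batch, n_mels, frames)
    tokens = torch.randint(0, 51865, (1, 8))  # (batch, <= n_text_ctx)
    logits = model(mel, tokens)               # (1, 8, n_vocab)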
---
funasr/models/sense_voice/whisper_lib/model.py | 396 +++++++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 316 insertions(+), 80 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 2822fc7..3d0d6a8 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -1,97 +1,333 @@
+import base64
+import gzip
from dataclasses import dataclass
-from typing import Dict
-from typing import Iterable, Optional
-import time
+from typing import Dict, Iterable, Optional
+
import numpy as np
import torch
import torch.nn.functional as F
-from torch import Tensor
-from torch import nn
-from . import whisper_lib as whisper
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from torch import Tensor, nn
-from funasr.register import tables
+from .decoding import decode as decode_function
+from .decoding import detect_language as detect_language_function
+from .transcribe import transcribe as transcribe_function
-@tables.register("model_classes", "SenseVoice")
-class SenseVoice(nn.Module):
+@dataclass
+class ModelDimensions:
+ n_mels: int
+ n_audio_ctx: int
+ n_audio_state: int
+ n_audio_head: int
+ n_audio_layer: int
+ n_vocab: int
+ n_text_ctx: int
+ n_text_state: int
+ n_text_head: int
+ n_text_layer: int
+
+
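+# LayerNorm that runs the normalization in float32 (input, weight, and bias
+# upcast) and casts the result back to the input dtype, for mixed-precision stability.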
+class LayerNorm(nn.LayerNorm):
def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.layer_norm(
+ input.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
+
+
+class Linear(nn.Linear):
+ def forward(self, x: Tensor) -> Tensor:
+ return F.linear(
+ x,
+ self.weight.to(x.dtype),
+ None if self.bias is None else self.bias.to(x.dtype),
+ )
+
+
+class Conv1d(nn.Conv1d):
+ def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
+ return super()._conv_forward(
+ x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
+ )
+
+
+def sinusoids(length, channels, max_timescale=10000):
+ """Returns sinusoids for positional embedding"""
+ assert channels % 2 == 0
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, n_state: int, n_head: int):
super().__init__()
- hub = kwargs.get("hub", "funasr")
+ self.n_head = n_head
+ self.query = Linear(n_state, n_state)
+ self.key = Linear(n_state, n_state, bias=False)
+ self.value = Linear(n_state, n_state)
+ self.out = Linear(n_state, n_state)
- dims = kwargs.get("dims", {})
- dims = whisper.model.ModelDimensions(**dims)
- model = whisper.model.Whisper(dims=dims)
-
- self.model = model
-
- self.encoder_output_size = self.model.dims.n_audio_state
-
- def forward(self, ):
- pass
-
- def inference(self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
- if kwargs.get("batch_size", 1) > 1:
- raise NotImplementedError("batch decoding is not implemented")
+ def forward(
+ self,
+ x: Tensor,
+ xa: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ kv_cache: Optional[dict] = None,
+ **kwargs,
+ ):
+ is_pad_mask = kwargs.get("is_pad_mask", False)
- if frontend is None and not hasattr(self, "frontend"):
- frontend_class = tables.frontend_classes.get("WhisperFrontend")
- frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
- self.frontend = frontend
+ q = self.query(x)
+
+ if kv_cache is None or xa is None or self.key not in kv_cache:
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
+ k = self.key(x if xa is None else xa)
+ v = self.value(x if xa is None else xa)
else:
- frontend = frontend if frontend is not None else self.frontend
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+ k = kv_cache[self.key]
+ v = kv_cache[self.value]
- meta_data = {}
- if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer)
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
- frontend=frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
- lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
+ wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
+ return self.out(wv), qk
- speech = speech.to(device=kwargs["device"])[0, :, :]
- speech_lengths = speech_lengths.to(device=kwargs["device"])
+ def qkv_attention(
+ self,
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ mask: Optional[Tensor] = None,
+ **kwargs,
+ ):
+ is_pad_mask = kwargs.get("is_pad_mask", False)
+ n_batch, n_ctx, n_state = q.shape
+ scale = (n_state // self.n_head) ** -0.25
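+ # the d_head**-0.25 scale is applied to q and k separately, so the
+ # qk product below is scaled by 1/sqrt(d_head) overall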
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
- language = kwargs.get("language", None)
- initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
- # # detect the spoken language
- # _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt)
- # print(f"Detected language: {max(probs, key=probs.get)}")
- # language = max(probs, key=probs.get)
- # language = language if kwargs.get("language", None) is None else kwargs.get("language")
-
- # decode the audio
- prompt = ""
- initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
- options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
- result = whisper.decode(self.model, speech, options)
+ qk = q @ k
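+ # two mask conventions: an additive causal mask (float, -inf above the
+ # diagonal) or a boolean padding mask selected via is_pad_mask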
+ if mask is not None:
+ if not is_pad_mask:
+ qk = qk + mask[:n_ctx, :n_ctx]
+ else:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = -float("inf")  # alternative: float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
+ qk = qk.masked_fill(mask, min_value)
- results = []
- result_i = {"key": key[0], "text": result.text}
+ qk = qk.float()
- results.append(result_i)
-
- return results, meta_data
-
\ No newline at end of file
+ w = F.softmax(qk, dim=-1).to(q.dtype)
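+ # re-zero the weights at padded key positions; a fully padded row would
+ # otherwise carry NaNs out of the softmax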
+ if mask is not None and is_pad_mask:
+ w = w.masked_fill(mask, 0.0)
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+ super().__init__()
+
+ self.attn = MultiHeadAttention(n_state, n_head)
+ self.attn_ln = LayerNorm(n_state)
+
+ self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+ n_mlp = n_state * 4
+ self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
+ self.mlp_ln = LayerNorm(n_state)
+
+ def forward(
+ self,
+ x: Tensor,
+ xa: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ kv_cache: Optional[dict] = None,
+ **kwargs,
+ ):
+ is_pad_mask = kwargs.get("is_pad_mask", False)
+ is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+ if self.cross_attn:
+ x = (
+ x
+ + self.cross_attn(
+ self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
+ )[0]
+ )
+ x = x + self.mlp(self.mlp_ln(x))
+ return x
+
+
+class AudioEncoder(nn.Module):
+ def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+ super().__init__()
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
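+ # NOTE: both convolutions use stride 2 here (4x downsampling); upstream
+ # Whisper strides only conv2 (2x)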
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+ )
+ self.ln_post = LayerNorm(n_state)
+
+ def forward(self, x: Tensor):
+ """
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+ the mel spectrogram of the audio
+ """
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+ x = x.permute(0, 2, 1)
+
+ # upstream Whisper asserts a fixed input length here; this variant slices the
+ # positional embedding so shorter inputs (after the stride-2 convs) still work
+ x = (x + self.positional_embedding[: x.size(1), :]).to(x.dtype)
+
+ for block in self.blocks:
+ x = block(x)
+
+ x = self.ln_post(x)
+ return x
+
+
+class TextDecoder(nn.Module):
+ def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
+ super().__init__()
+
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+ [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
+ )
+ self.ln = LayerNorm(n_state)
+
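+ # additive causal mask: -inf above the diagonal blocks attention to future positions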
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+ self.register_buffer("mask", mask, persistent=False)
+
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+ """
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+ the text tokens
+ xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+ the encoded audio features to be attended on
+ """
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+ x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
+ x = x.to(xa.dtype)
+
+ for block in self.blocks:
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+
+ x = self.ln(x)
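+ # the output projection ties weights with the token embedding matrix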
+ logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
+
+ return logits
+
+
+class Whisper(nn.Module):
+ def __init__(self, dims: ModelDimensions):
+ super().__init__()
+ self.dims = dims
+ self.encoder = AudioEncoder(
+ self.dims.n_mels,
+ self.dims.n_audio_ctx,
+ self.dims.n_audio_state,
+ self.dims.n_audio_head,
+ self.dims.n_audio_layer,
+ )
+ self.decoder = TextDecoder(
+ self.dims.n_vocab,
+ self.dims.n_text_ctx,
+ self.dims.n_text_state,
+ self.dims.n_text_head,
+ self.dims.n_text_layer,
+ )
+ # use the last half among the decoder layers for time alignment by default;
+ # to use a specific set of heads, see `set_alignment_heads()` below.
+ all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
+ all_heads[self.dims.n_text_layer // 2 :] = True
+ # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+ # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
+ # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
+
+ def set_alignment_heads(self, dump: bytes):
+ array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
+ mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
+ self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
+
+ def embed_audio(self, mel: torch.Tensor):
+ return self.encoder(mel)
+
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+ return self.decoder(tokens, audio_features)
+
+ def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
+ return self.decoder(tokens, self.encoder(mel))
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def is_multilingual(self):
+ return self.dims.n_vocab >= 51865
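+ # multilingual Whisper vocabularies have >= 51865 tokens; the English-only vocab is 51864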
+
+ @property
+ def num_languages(self):
+ return self.dims.n_vocab - 51765 - int(self.is_multilingual)
+
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+ """
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+ tensors calculated for the previous positions. This method returns a dictionary that stores
+ all caches, and the necessary hooks for the key and value projection modules that save the
+ intermediate tensors to be reused during later calculations.
+
+ Returns
+ -------
+ cache : Dict[nn.Module, torch.Tensor]
+ A dictionary mapping each key/value projection module to its cached tensor
+ hooks : List[RemovableHandle]
+ List of PyTorch RemovableHandle objects for removing the installed hooks
+ """
+ cache = {**cache} if cache is not None else {}
+ hooks = []
+
+ def save_to_cache(module, _, output):
+ if module not in cache or output.shape[1] > self.dims.n_text_ctx:
+ # save as-is, for the first token or cross attention
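+ # (cross-attention k/v come from the audio features, whose length exceeds
+ # n_text_ctx, so they are re-saved rather than concatenated)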
+ cache[module] = output
+ else:
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
+ return cache[module]
+
+ def install_hooks(layer: nn.Module):
+ if isinstance(layer, MultiHeadAttention):
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+ self.decoder.apply(install_hooks)
+ return cache, hooks
+
+ detect_language = detect_language_function
+ transcribe = transcribe_function
+ decode = decode_function
--
Gitblit v1.9.1