from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import time

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

import whisper

from funasr.models.whisper.utils.decoding import (
    decode as decode_function,
    detect_language as detect_language_function,
)
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
        )


class Conv1d(nn.Conv1d):
    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
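

# A minimal shape sketch (base-model encoder sizes, n_ctx=1500 and n_state=512,
# used here only for illustration): the first half of the channels are sines,
# the second half cosines, so row 0 is [0, ..., 0, 1, ..., 1]:
#
#   pe = sinusoids(1500, 512)                              # -> torch.Size([1500, 512])
#   assert torch.allclose(pe[0, 256:], torch.ones(256))    # cos(0) == 1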


class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
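

# A minimal shape sketch for the attention module (illustrative sizes,
# n_state=512 split across n_head=8 heads of size 64):
#
#   mha = MultiHeadAttention(n_state=512, n_head=8)
#   wv, qk = mha(torch.randn(1, 10, 512))     # self-attention, no cache
#   assert wv.shape == (1, 10, 512)           # re-flattened head outputs
#   assert qk.shape == (1, 8, 10, 10)         # per-head attention logits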


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x


@tables.register("encoder_classes", "WhisperEncoder")
class AudioEncoder(nn.Module):
    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x
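

# A minimal encoder sketch (base-model sizes, for illustration): a 30 s mel
# spectrogram has 3000 frames, and the stride-2 conv halves that to the
# n_ctx=1500 positions the positional embedding expects:
#
#   enc = AudioEncoder(n_mels=80, n_ctx=1500, n_state=512, n_head=8, n_layer=6)
#   feats = enc(torch.randn(1, 80, 3000))     # -> torch.Size([1, 1500, 512])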


@tables.register("decoder_classes", "WhisperDecoder")
class TextDecoder(nn.Module):
    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()

        return logits
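

# A minimal decoder sketch (toy vocabulary size, for illustration): the output
# logits are tied to the token embedding, so the last axis matches n_vocab.
# Note that `positional_embedding` is created with torch.empty and is expected
# to be filled from a checkpoint:
#
#   dec = TextDecoder(n_vocab=1000, n_ctx=448, n_state=512, n_head=8, n_layer=2)
#   xa = torch.randn(1, 1500, 512)                 # encoder output
#   logits = dec(torch.tensor([[0, 1, 2]]), xa)    # -> torch.Size([1, 3, 1000])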


@tables.register("model_classes", "Whisper")
class Whisper(nn.Module):
    def __init__(self, dims: dict):
        super().__init__()
        dims = ModelDimensions(**dims)
        self.dims = dims
        self.sos = 1
        self.eos = 1
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865
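
    # A minimal construction sketch (multilingual base-sized dimensions, chosen
    # here for illustration):
    #
    #   dims = dict(n_mels=80, n_audio_ctx=1500, n_audio_state=512,
    #               n_audio_head=8, n_audio_layer=6, n_vocab=51865,
    #               n_text_ctx=448, n_text_state=512, n_text_head=8,
    #               n_text_layer=6)
    #   model = Whisper(dims)                      # model.is_multilingual -> True
    #   logits = model(torch.randn(1, 80, 3000), torch.tensor([[0, 1, 2]]))
    #   # logits.shape -> torch.Size([1, 3, 51865])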

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to their cached tensors
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks from being called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
                cache[module] = output  # save as-is, for the first token or cross attention
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks
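
    # A minimal sketch of incremental decoding with the cache (hypothetical
    # driver code; `tokens`/`next_token` are LongTensors, `audio_features` is
    # the encoder output):
    #
    #   cache, hooks = model.install_kv_cache_hooks()
    #   logits = model.decoder(tokens, audio_features, kv_cache=cache)      # full prefix
    #   logits = model.decoder(next_token, audio_features, kv_cache=cache)  # one new token
    #   for hook in hooks:
    #       hook.remove()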

    detect_language = detect_language_function
    decode = decode_function


@tables.register("model_classes", "WhisperWarp")
class WhisperWarp(nn.Module):
    def __init__(self, whisper_dims: dict = None, **kwargs):
        super().__init__()
        hub = kwargs.get("hub", "funasr")
        if hub == "openai":
            # load pretrained weights through the openai-whisper package
            init_param_path = kwargs.get("init_param_path", "large-v3")
            model = whisper.load_model(init_param_path)
        else:
            # build an (uninitialized) model from the given dimension dict
            dims = whisper.model.ModelDimensions(**whisper_dims)
            model = whisper.model.Whisper(dims=dims)
        self.model = model

    def forward(self):
        pass

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":
            # the input is already an fbank feature tensor
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = torch.tensor([speech.shape[1]])
        else:
            # load the audio and extract fbank features
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # detect the spoken language
        _, probs = self.model.detect_language(speech)
        print(f"Detected language: {max(probs, key=probs.get)}")

        # decode the audio
        options = whisper.DecodingOptions(language=kwargs.get("language", None), fp16=False)
        result = whisper.decode(self.model, speech, options)

        results = [{"key": key[0], "text": result.text}]
        return results, meta_data
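

# A minimal end-to-end sketch of the warp class (an illustration, assuming the
# openai-whisper "base" checkpoint can be downloaded and that `frontend` is a
# FunASR frontend exposing `.fs`; "example.wav" is a placeholder path):
#
#   model = WhisperWarp(hub="openai", init_param_path="base")
#   results, meta_data = model.inference(
#       "example.wav",
#       key=["example"],
#       frontend=frontend,
#       device="cpu",
#   )
#   print(results[0]["text"])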