sense voice (#1568)
* train (20 identical commits)
* whisper_lib for sense voice
* aishell recipe
* sense voice
| | |
| | |
| | | def import_submodules(package, recursive=True): |
| | | if isinstance(package, str): |
| | | package = importlib.import_module(package) |
| | | try: |
| | | package = importlib.import_module(package) |
| | | except Exception as e: |
| | | # To see the details of a failed import, uncomment the line below
| | | # print(f"Failed to import {name}: {e}") |
| | | pass |
| | | results = {} |
| | | for loader, name, is_pkg in pkgutil.walk_packages(package.__path__, package.__name__ + '.'): |
| | | try: |
| New file |
| | |
| | | from dataclasses import dataclass |
| | | from typing import Dict |
| | | from typing import Iterable, Optional |
| | | import time |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from torch import Tensor |
| | | from torch import nn |
| | | from . import whisper_lib as whisper |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("model_classes", "SenseVoice") |
| | | class SenseVoice(nn.Module): |
| | | def __init__(self, *args, **kwargs): |
| | | super().__init__() |
| | | hub = kwargs.get("hub", "funasr") |
| | | |
| | | dims = kwargs.get("dims", {}) |
| | | dims = whisper.model.ModelDimensions(**dims) |
| | | model = whisper.model.Whisper(dims=dims) |
| | | |
| | | self.model = model |
| | | |
| | | self.encoder_output_size = self.model.dims.n_audio_state |
| | | |
| | | def forward(self, ): |
| | | pass |
| | | |
| | | def inference(self, |
| | | data_in, |
| | | data_lengths=None, |
| | | key: list = None, |
| | | tokenizer=None, |
| | | frontend=None, |
| | | **kwargs, |
| | | ): |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | if frontend is None and not hasattr(self, "frontend"): |
| | | frontend_class = tables.frontend_classes.get("WhisperFrontend") |
| | | frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)) |
| | | self.frontend = frontend |
| | | else: |
| | | frontend = frontend if frontend is not None else self.frontend |
| | | |
| | | meta_data = {} |
| | | if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank |
| | | speech, speech_lengths = data_in, data_lengths |
| | | if len(speech.shape) < 3: |
| | | speech = speech[None, :, :] |
| | | if speech_lengths is None: |
| | | speech_lengths = speech.shape[1] |
| | | else: |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000), |
| | | data_type=kwargs.get("data_type", "sound"), |
| | | tokenizer=tokenizer) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | | frontend=frontend) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10 |
| | | lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1 |
| | | meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000 |
| | | |
| | | speech = speech.to(device=kwargs["device"])[0, :, :] |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | |
| | | language = kwargs.get("language", None) |
| | | initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>") |
| | | # # detect the spoken language |
| | | # _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt) |
| | | # print(f"Detected language: {max(probs, key=probs.get)}") |
| | | # language = max(probs, key=probs.get) |
| | | # language = language if kwargs.get("language", None) is None else kwargs.get("language") |
| | | |
| | | # decode the audio |
| | | prompt = "" |
| | | initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>") |
| | | options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt) |
| | | result = whisper.decode(self.model, speech, options) |
| | | |
| | | results = [] |
| | | result_i = {"key": key[0], "text": result.text} |
| | | |
| | | results.append(result_i) |
| | | |
| | | return results, meta_data |
| | | |
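
For orientation, here is a minimal, hypothetical usage sketch of the model file above, driven through FunASR's `AutoModel` wrapper; the checkpoint directory is a placeholder, and only `language` and `initial_prompt` are actually read by `inference()`:

```python
# Hypothetical usage sketch, not part of this PR. Assumes a local directory
# containing the config.yaml below plus SenseVoice weights.
from funasr import AutoModel

model = AutoModel(model="./sense_voice_model_dir")    # placeholder path
res = model.generate(
    input="example.wav",                              # 16 kHz audio
    language=None,                                    # let the model detect it
    initial_prompt="<|startoftranscript|><|ASR|>",    # default used by inference()
)
print(res[0]["text"])
```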
| New file |
| | |
| | | # This is an example that demonstrates how to configure a model file. |
| | | # You can modify the configuration according to your own requirements. |
| | | |
| | | # to print the register_table: |
| | | # from funasr.register import tables |
| | | # tables.print() |
| | | |
| | | # network architecture |
| | | model: SenseVoice |
| | | model_conf: |
| | | lsm_weight: 0.1 |
| | | length_normalized_loss: true |
| | | hub: funasr |
| | | |
| | | |
| | | |
| | | # only used when hub == funasr;
| | | # if hub == openai, dims is downloaded automatically
| | | dims: |
| | | n_mels: 128 |
| | | n_vocab: 51866 |
| | | n_audio_ctx: 1500 |
| | | n_audio_state: 1280 |
| | | n_audio_head: 20 |
| | | n_audio_layer: 32 |
| | | n_text_ctx: 448 |
| | | n_text_state: 1280 |
| | | n_text_head: 20 |
| | | n_text_layer: 32 |
| | | |
| | | # frontend related |
| | | frontend: WhisperFrontend |
| | | frontend_conf: |
| | | fs: 16000 |
| | | n_mels: ${dims.n_mels} |
| | | do_pad_trim: true |
| | | |
| | | tokenizer: WhisperTokenizer |
| | | tokenizer_conf: |
| | | language: null |
| | | task: transcribe |
| | | is_multilingual: true |
| | | num_languages: 100 |
| | | |
| | | scope_map: [none, "model."] |
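
The `${dims.n_mels}` reference above implies OmegaConf-style interpolation (which FunASR uses for its configs); a small sketch of how the single `dims` block feeds both the Whisper `ModelDimensions` and the frontend (file name is a placeholder):

```python
# Sketch assuming OmegaConf-style resolution of the YAML above.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")                   # placeholder path
resolved = OmegaConf.to_container(cfg, resolve=True)

print(resolved["dims"]["n_mels"])                     # 128
print(resolved["frontend_conf"]["n_mels"])            # 128, via ${dims.n_mels}
```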
| | |
| | | |
| | | @torch.no_grad() |
| | | def detect_language( |
| | | model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None |
| | | model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, initial_prompt = None, x = None, |
| | | ) -> Tuple[Tensor, List[dict]]: |
| | | """ |
| | | Detect the spoken language in the audio, and return them as list of strings, along with the ids |
| | |
| | | mel = mel.unsqueeze(0) |
| | | |
| | | # skip encoder forward pass if already-encoded audio features were given |
| | | if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): |
| | | # FIX(funasr): sense voice
| | | if mel.shape[-1] != model.dims.n_audio_state: |
| | | mel = model.encoder(mel) |
| | | |
| | | # forward pass using a single token, startoftranscript |
| | | n_audio = mel.shape[0] |
| | | x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] |
| | | # FIX(funasr): sense voice
| | | # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] |
| | | if x is None: |
| | | x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device) # [n_audio, 1] |
| | | logits = model.logits(x, mel)[:, 0] |
| | | |
| | | # collect detected languages; suppress all non-language tokens |
| | |
| | | |
| | | # implementation details |
| | | fp16: bool = True # use fp16 for most of the calculation |
| | | |
| | | # FIX(funasr): sense voice
| | | initial_prompt: str = None |
| | | |
| | | |
| | | @dataclass(frozen=True) |
| | |
| | | + prompt_tokens[-(self.n_ctx // 2 - 1) :] |
| | | + tokens |
| | | ) |
| | | # FIX(gzf): sense voice
| | | if initial_prompt := self.options.initial_prompt: |
| | | tokens = self.tokenizer.encode(initial_prompt, allowed_special="all") |
| | | if self.options.language is None: |
| | | tokens += [0] |
| | | |
| | | |
| | | return tuple(tokens) |
| | | |
| | |
| | | |
| | | if self.options.language is None or self.options.task == "lang_id": |
| | | lang_tokens, lang_probs = self.model.detect_language( |
| | | audio_features, self.tokenizer |
| | | audio_features, self.tokenizer, x=tokens |
| | | ) |
| | | languages = [max(probs, key=probs.get) for probs in lang_probs] |
| | | # FIX(funasr): sense voice
| | | # if self.options.language is None: |
| | | # tokens[:, self.sot_index + 1] = lang_tokens # write language tokens |
| | | if self.options.language is None: |
| | | tokens[:, self.sot_index + 1] = lang_tokens # write language tokens |
| | | # tokens[:, self.sot_index + 1] = lang_tokens # write language tokens |
| | | languages = "".join([f"<|{language}|>" for language in languages]) |
| | | n_audio = audio_features.shape[0] |
| | | lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to( |
| | | audio_features.device) # [n_audio, 1] |
| | | |
| | | tokens[:, -1:] = lang_tokens[:, :] |
| | | languages = [languages] |
| | | |
| | | return languages, lang_probs |
| | | |
| | |
| | | import base64
| | | import gzip
| | | from dataclasses import dataclass
| | | from typing import Dict, Iterable, Optional
| | | 
| | | import numpy as np
| | | import torch
| | | import torch.nn.functional as F
| | | from torch import Tensor, nn
| | | 
| | | from .decoding import decode as decode_function
| | | from .decoding import detect_language as detect_language_function
| | | from .transcribe import transcribe as transcribe_function
| | | |
| | | |
| | | @dataclass |
| | | class ModelDimensions: |
| | | n_mels: int |
| | | n_audio_ctx: int |
| | | n_audio_state: int |
| | | n_audio_head: int |
| | | n_audio_layer: int |
| | | n_vocab: int |
| | | n_text_ctx: int |
| | | n_text_state: int |
| | | n_text_head: int |
| | | n_text_layer: int |
| | | |
| | | |
| | | class LayerNorm(nn.LayerNorm): |
| | | def forward(self, x: Tensor) -> Tensor: |
| | | return super().forward(x.float()).type(x.dtype) |
| | | |
| | | |
| | | class Linear(nn.Linear): |
| | | def forward(self, x: Tensor) -> Tensor: |
| | | return F.linear( |
| | | x, |
| | | self.weight.to(x.dtype), |
| | | None if self.bias is None else self.bias.to(x.dtype), |
| | | ) |
| | | |
| | | |
| | | class Conv1d(nn.Conv1d): |
| | | def _conv_forward( |
| | | self, x: Tensor, weight: Tensor, bias: Optional[Tensor] |
| | | ) -> Tensor: |
| | | return super()._conv_forward( |
| | | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) |
| | | ) |
| | | |
| | | |
| | | def sinusoids(length, channels, max_timescale=10000): |
| | | """Returns sinusoids for positional embedding""" |
| | | assert channels % 2 == 0 |
| | | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) |
| | | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) |
| | | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] |
| | | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) |
| | | |
| | | |
| | | class MultiHeadAttention(nn.Module):
| | |     def __init__(self, n_state: int, n_head: int):
| | |         super().__init__()
| | |         self.n_head = n_head
| | |         self.query = Linear(n_state, n_state)
| | |         self.key = Linear(n_state, n_state, bias=False)
| | |         self.value = Linear(n_state, n_state)
| | |         self.out = Linear(n_state, n_state)
| | | 
| | |     def forward(
| | |         self,
| | |         x: Tensor,
| | |         xa: Optional[Tensor] = None,
| | |         mask: Optional[Tensor] = None,
| | |         kv_cache: Optional[dict] = None,
| | |     ):
| | |         q = self.query(x)
| | | 
| | |         if kv_cache is None or xa is None or self.key not in kv_cache:
| | |             # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
| | |             # otherwise, perform key/value projections for self- or cross-attention as usual.
| | |             k = self.key(x if xa is None else xa)
| | |             v = self.value(x if xa is None else xa)
| | |         else:
| | |             # for cross-attention, calculate keys and values once and reuse in subsequent calls.
| | |             k = kv_cache[self.key]
| | |             v = kv_cache[self.value]
| | | 
| | |         wv, qk = self.qkv_attention(q, k, v, mask)
| | |         return self.out(wv), qk
| | | 
| | |     def qkv_attention(
| | |         self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
| | |     ):
| | |         n_batch, n_ctx, n_state = q.shape
| | |         scale = (n_state // self.n_head) ** -0.25
| | |         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
| | |         k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
| | |         v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
| | | 
| | |         qk = q @ k
| | |         if mask is not None:
| | |             qk = qk + mask[:n_ctx, :n_ctx]
| | |         qk = qk.float()
| | | 
| | |         w = F.softmax(qk, dim=-1).to(q.dtype)
| | |         return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
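
A brief note on the scaling in `qkv_attention` above: q and k are each scaled by `(n_state // n_head) ** -0.25`, which is algebraically the same as the usual `1/sqrt(d_head)` attention scaling, just applied symmetrically (presumably to keep intermediate magnitudes small in fp16). A quick check:

```python
# Verifies that the symmetric **-0.25 scaling equals the standard 1/sqrt(d) scaling.
import torch

d = 64
q, k = torch.randn(2, 10, d), torch.randn(2, 10, d)
lhs = (q * d ** -0.25) @ (k * d ** -0.25).transpose(-1, -2)
rhs = (q @ k.transpose(-1, -2)) / d ** 0.5
assert torch.allclose(lhs, rhs, atol=1e-5)
```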
| | | |
| | | |
| | | class ResidualAttentionBlock(nn.Module): |
| | | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): |
| | | super().__init__() |
| | | |
| | | self.attn = MultiHeadAttention(n_state, n_head) |
| | | self.attn_ln = LayerNorm(n_state) |
| | | |
| | | self.cross_attn = ( |
| | | MultiHeadAttention(n_state, n_head) if cross_attention else None |
| | | ) |
| | | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None |
| | | |
| | | n_mlp = n_state * 4 |
| | | self.mlp = nn.Sequential( |
| | | Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) |
| | | ) |
| | | self.mlp_ln = LayerNorm(n_state) |
| | | |
| | | def forward( |
| | | self, |
| | | x: Tensor, |
| | | xa: Optional[Tensor] = None, |
| | | mask: Optional[Tensor] = None, |
| | | kv_cache: Optional[dict] = None, |
| | | ): |
| | | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] |
| | | if self.cross_attn: |
| | | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] |
| | | x = x + self.mlp(self.mlp_ln(x)) |
| | | return x |
| | | |
| | | |
| | | class AudioEncoder(nn.Module): |
| | | def __init__( |
| | | self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int |
| | | ): |
| | | super().__init__() |
| | | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) |
| | | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) |
| | | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) |
| | | |
| | | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( |
| | | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] |
| | | ) |
| | | self.ln_post = LayerNorm(n_state) |
| | | |
| | | def forward(self, x: Tensor): |
| | | """ |
| | | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) |
| | | the mel spectrogram of the audio |
| | | """ |
| | | x = F.gelu(self.conv1(x)) |
| | | x = F.gelu(self.conv2(x)) |
| | | x = x.permute(0, 2, 1) |
| | | |
| | | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" |
| | | x = (x + self.positional_embedding).to(x.dtype) |
| | | |
| | | for block in self.blocks: |
| | | x = block(x) |
| | | |
| | | x = self.ln_post(x) |
| | | return x |
| | | |
| | | |
| | | class TextDecoder(nn.Module): |
| | | def __init__( |
| | | self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int |
| | | ): |
| | | super().__init__() |
| | | |
| | | self.token_embedding = nn.Embedding(n_vocab, n_state) |
| | | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) |
| | | |
| | | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( |
| | | [ |
| | | ResidualAttentionBlock(n_state, n_head, cross_attention=True) |
| | | for _ in range(n_layer) |
| | | ] |
| | | ) |
| | | self.ln = LayerNorm(n_state) |
| | | |
| | | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) |
| | | self.register_buffer("mask", mask, persistent=False) |
| | | |
| | | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): |
| | | """ |
| | | x : torch.LongTensor, shape = (batch_size, <= n_ctx) |
| | | the text tokens |
| | | xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) |
| | | the encoded audio features to be attended on |
| | | """ |
| | | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 |
| | | x = ( |
| | | self.token_embedding(x) |
| | | + self.positional_embedding[offset : offset + x.shape[-1]] |
| | | ) |
| | | x = x.to(xa.dtype) |
| | | |
| | | for block in self.blocks: |
| | | x = block(x, xa, mask=self.mask, kv_cache=kv_cache) |
| | | |
| | | x = self.ln(x) |
| | | logits = ( |
| | | x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) |
| | | ).float() |
| | | |
| | | return logits |
| | | |
| | | |
| | | class Whisper(nn.Module): |
| | | def __init__(self, dims: ModelDimensions): |
| | | super().__init__() |
| | | self.dims = dims |
| | | self.encoder = AudioEncoder( |
| | | self.dims.n_mels, |
| | | self.dims.n_audio_ctx, |
| | | self.dims.n_audio_state, |
| | | self.dims.n_audio_head, |
| | | self.dims.n_audio_layer, |
| | | ) |
| | | self.decoder = TextDecoder( |
| | | self.dims.n_vocab, |
| | | self.dims.n_text_ctx, |
| | | self.dims.n_text_state, |
| | | self.dims.n_text_head, |
| | | self.dims.n_text_layer, |
| | | ) |
| | | # use the last half among the decoder layers for time alignment by default; |
| | | # to use a specific set of heads, see `set_alignment_heads()` below. |
| | | all_heads = torch.zeros( |
| | | self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool |
| | | ) |
| | | all_heads[self.dims.n_text_layer // 2 :] = True |
| | | self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) |
| | | |
| | | def set_alignment_heads(self, dump: bytes): |
| | | array = np.frombuffer( |
| | | gzip.decompress(base64.b85decode(dump)), dtype=bool |
| | | ).copy() |
| | | mask = torch.from_numpy(array).reshape( |
| | | self.dims.n_text_layer, self.dims.n_text_head |
| | | ) |
| | | self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) |
| | | |
| | | def embed_audio(self, mel: torch.Tensor): |
| | | return self.encoder(mel) |
| | | |
| | | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): |
| | | return self.decoder(tokens, audio_features) |
| | | |
| | | def forward( |
| | | self, mel: torch.Tensor, tokens: torch.Tensor |
| | | ) -> Dict[str, torch.Tensor]: |
| | | return self.decoder(tokens, self.encoder(mel)) |
| | | |
| | | @property |
| | | def device(self): |
| | | return next(self.parameters()).device |
| | | |
| | | @property |
| | | def is_multilingual(self): |
| | | return self.dims.n_vocab >= 51865 |
| | | |
| | | @property |
| | | def num_languages(self): |
| | | return self.dims.n_vocab - 51765 - int(self.is_multilingual) |
| | | |
| | | def install_kv_cache_hooks(self, cache: Optional[dict] = None): |
| | | """ |
| | | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value |
| | | tensors calculated for the previous positions. This method returns a dictionary that stores |
| | | all caches, and the necessary hooks for the key and value projection modules that save the |
| | | intermediate tensors to be reused during later calculations. |
| | | |
| | | Returns |
| | | ------- |
| | | cache : Dict[nn.Module, torch.Tensor] |
| | | A dictionary object mapping the key/value projection modules to its cache |
| | | hooks : List[RemovableHandle] |
| | | List of PyTorch RemovableHandle objects to stop the hooks to be called |
| | | """ |
| | | cache = {**cache} if cache is not None else {} |
| | | hooks = [] |
| | | |
| | | def save_to_cache(module, _, output): |
| | | if module not in cache or output.shape[1] > self.dims.n_text_ctx: |
| | | # save as-is, for the first token or cross attention |
| | | cache[module] = output |
| | | else: |
| | | cache[module] = torch.cat([cache[module], output], dim=1).detach() |
| | | return cache[module] |
| | | |
| | | def install_hooks(layer: nn.Module): |
| | | if isinstance(layer, MultiHeadAttention): |
| | | hooks.append(layer.key.register_forward_hook(save_to_cache)) |
| | | hooks.append(layer.value.register_forward_hook(save_to_cache)) |
| | | |
| | | self.decoder.apply(install_hooks) |
| | | return cache, hooks |
| | | |
| | | detect_language = detect_language_function |
| | | transcribe = transcribe_function |
| | | decode = decode_function |
| | | |
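
As a side note, the `install_kv_cache_hooks()` docstring above describes the caching mechanism; a hedged sketch of how it is typically used for incremental decoding (assumes a constructed `model: Whisper` and `audio_features = model.encoder(mel)`; token ids are placeholders):

```python
# Illustrative sketch only; `model` and `audio_features` are assumed to exist.
import torch

cache, hooks = model.install_kv_cache_hooks()
tokens = torch.tensor([[50258]])                     # e.g. <|startoftranscript|>
for _ in range(3):
    # after the first pass the hooks have cached k/v, so only the newest token is fed
    x = tokens[:, -1:] if cache else tokens
    logits = model.decoder(x, audio_features, kv_cache=cache)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_token], dim=-1)

for hook in hooks:                                   # detach the forward hooks when done
    hook.remove()
```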
| | |
| | | |
| | | import tiktoken |
| | | |
| | | # FIX(funasr): sense voice
| | | LANGUAGES = { |
| | | "en": "english", |
| | | "zh": "chinese", |
| | |
| | | "jw": "javanese", |
| | | "su": "sundanese", |
| | | "yue": "cantonese", |
| | | "minnan": "minnan", |
| | | "wuyu": "wuyu", |
| | | "dialect": "dialect", |
| | | "zh/en": "zh/en", |
| | | "en/zh": "en/zh", |
| | | } |
| | | |
| | | # language code lookup by name, with a few language aliases |
| | |
| | | "sinhalese": "si", |
| | | "castilian": "es", |
| | | "mandarin": "zh", |
| | | } |
| | | |
| | | # FIX(funasr): sense voice
| | | AUDIO_EVENT = { |
| | | "ASR": "ASR", |
| | | "AED": "AED", |
| | | "SER": "SER", |
| | | "Speech": "Speech", |
| | | "/Speech": "/Speech", |
| | | "BGM": "BGM", |
| | | "/BGM": "/BGM", |
| | | "Laughter": "Laughter", |
| | | "/Laughter": "/Laughter", |
| | | "Applause": "Applause", |
| | | "/Applause": "/Applause", |
| | | } |
| | | |
| | | EMOTION = { |
| | | "HAPPY": "HAPPY", |
| | | "SAD": "SAD", |
| | | "ANGRY": "ANGRY", |
| | | "NEUTRAL": "NEUTRAL", |
| | | } |
| | | |
| | | |
| | |
| | | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". |
| | | """ |
| | | return self.encoding.decode(token_ids, **kwargs) |
| | | |
| | | def get_vocab_size(self) -> int: |
| | | return self.encoding.n_vocab |
| | | |
| | | @cached_property |
| | | def eot(self) -> int: |
| | |
| | | |
| | | @cached_property |
| | | def sot(self) -> int: |
| | | return self.special_tokens["<|startoftranscript|>"] |
| | | |
| | | @cached_property |
| | | def sot_sense(self) -> int: |
| | | return self.special_tokens["<|startoftranscript|>"] |
| | | |
| | | @cached_property |
| | |
| | | n_vocab = len(ranks) |
| | | special_tokens = {} |
| | | |
| | | specials = [ |
| | | "<|endoftext|>", |
| | | "<|startoftranscript|>", |
| | | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], |
| | | "<|translate|>", |
| | | "<|transcribe|>", |
| | | "<|startoflm|>", |
| | | "<|startofprev|>", |
| | | "<|nospeech|>", |
| | | "<|notimestamps|>", |
| | | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], |
| | | ] |
| | | if False: #name == "gpt2" or name == "multilingual": |
| | | specials = [ |
| | | "<|endoftext|>", |
| | | "<|startoftranscript|>", |
| | | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], |
| | | "<|translate|>", |
| | | "<|transcribe|>", |
| | | "<|startoflm|>", |
| | | "<|startofprev|>", |
| | | "<|nospeech|>", |
| | | "<|notimestamps|>", |
| | | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], |
| | | ] |
| | | else: |
| | | specials = [ |
| | | "<|endoftext|>", |
| | | "<|startoftranscript|>", |
| | | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], |
| | | *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())], |
| | | *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())], |
| | | "<|translate|>", |
| | | "<|transcribe|>", |
| | | "<|startoflm|>", |
| | | "<|startofprev|>", |
| | | "<|nospeech|>", |
| | | "<|notimestamps|>", |
| | | *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 51)], |
| | | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], |
| | | ] |
| | | |
| | | for token in specials: |
| | | special_tokens[token] = n_vocab |
| | |
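
For illustration, a hypothetical sketch of what the extra specials such as `<|ASR|>` and `<|HAPPY|>` look like at the tiktoken level; the base encoding, ids, and use of private attributes here are assumptions for demonstration only, the real vocabulary comes from the ranks file loaded by `get_encoding()`:

```python
# Hypothetical demo: extending a tiktoken encoding with SenseVoice-style specials.
import tiktoken

base = tiktoken.get_encoding("gpt2")                 # stand-in for the real ranks file
enc = tiktoken.Encoding(
    name="sense_voice_demo",
    pat_str=base._pat_str,
    mergeable_ranks=base._mergeable_ranks,
    special_tokens={
        **base._special_tokens,
        "<|startoftranscript|>": 50257,              # placeholder ids
        "<|ASR|>": 50258,
        "<|HAPPY|>": 50259,
    },
)
# allowed_special="all" mirrors how the SenseVoice code encodes initial_prompt
print(enc.encode("<|startoftranscript|><|ASR|>", allowed_special="all"))  # [50257, 50258]
```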
| | | num_languages: int = 99, |
| | | language: Optional[str] = None, |
| | | task: Optional[str] = None, # Literal["transcribe", "translate", None] |
| | | encoding_path: Optional[str] = None, |
| | | ) -> Tokenizer: |
| | | if language is not None: |
| | | language = language.lower() |
| | |
| | | encoding_name = "gpt2" |
| | | language = None |
| | | task = None |
| | | if encoding_path is not None: |
| | | encoding_name = encoding_path |
| | | |
| | | encoding = get_encoding(name=encoding_name, num_languages=num_languages) |
| | | |
| | |
| | | "hydra-core>=1.3.2", |
| | | "tensorboardX", |
| | | "rotary_embedding_torch", |
| | | "openai-whisper", |
| | | ], |
| | | # train: The modules invoked when training only. |
| | | "train": [ |