zhifu gao
2024-04-02 48693b45c021a842ea964c9dc99479b61eac062f
Dev gzf new (#1574)

* train

* whisper_lib for sense voice

* aishell recipe

* sense voice

* docs
8 files modified
6 files added
513 lines changed
examples/aishell/branchformer/README.md 14
examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml 8
examples/aishell/branchformer/run.sh 4
examples/aishell/conformer/README.md 16
examples/aishell/e_branchformer/README.md 14
examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml 2
examples/aishell/e_branchformer/run.sh 4
examples/aishell/paraformer/README.md 24
examples/aishell/transformer/README.md 16
examples/aishell/transformer/run.sh 2
examples/industrial_data_pretraining/paraformer-zh-spk/demo.py 2
examples/industrial_data_pretraining/sense_voice/demo.py 16
funasr/models/sense_voice/model.py 12
funasr/models/sense_voice/whisper_lib/model.py 379
examples/aishell/branchformer/README.md
New file
@@ -0,0 +1,14 @@
# Branchformer Result
## Training Config
- Feature info: using raw speech, extracting 80 dims fbank online, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
- Train info: lr 0.001, batch_size 10000, 4 gpu(Tesla V100), acc_grad 1, 180 epochs
- Train config: conf/train_asr_branchformer.yaml
- LM config: LM was not used
## Results (CER)
|   testset   | CER(%)  |
|:-----------:|:-------:|
|     dev     |  4.15   |
|    test     |  4.51   |
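The README tables report character error rate (CER). As a minimal sketch, and not FunASR's own scoring tool, corpus-level CER can be computed as the total Levenshtein distance (substitutions, deletions, insertions) over all utterances divided by the total number of reference characters; the helper names `edit_distance` and `cer` are hypothetical.

```python
def edit_distance(ref: str, hyp: str) -> int:
    """Levenshtein distance (substitutions + deletions + insertions) between two strings."""
    # prev[j] holds the distance between the ref prefix processed so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(
                prev[j] + 1,             # delete r from the reference
                cur[j - 1] + 1,          # insert h from the hypothesis
                prev[j - 1] + (r != h),  # substitute (cost 0 on a match)
            )
        prev = cur
    return prev[-1]


def cer(refs: list[str], hyps: list[str]) -> float:
    """Corpus-level CER in percent: total edits / total reference characters * 100."""
    errors = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    total = sum(len(r) for r in refs)
    return 100.0 * errors / total


# One substitution (气 -> 汽) and one deletion (很) against 6 reference characters.
print(f"{cer(['今天天气很好'], ['今天天汽好']):.2f}")  # 33.33
```

The example strings are made up; real scoring scripts may also normalize text (for example, strip spaces or punctuation) before comparing.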
examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml
@@ -79,8 +79,9 @@
 train_conf:
   accum_grad: 1
   grad_clip: 5
-  max_epoch: 150
+  max_epoch: 180
   keep_nbest_models: 10
+  avg_keep_nbest_models_type: acc
   log_interval: 50
 optim: adam
@@ -96,7 +97,7 @@
     index_ds: IndexDSJsonl
     batch_sampler: EspnetStyleBatchSampler
     batch_type: length # example or length
-    batch_size: 25000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    batch_size: 10000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
     buffer_size: 1024
     shuffle: True
@@ -116,3 +117,6 @@
     reduce: true
     ignore_nan_grad: true
 normalize: null
+beam_size: 10
+decoding_ctc_weight: 0.4
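The `batch_type: length` comments above say that `batch_size` bounds `source_token_len+target_token_len` and that `max_token_length` filters out long samples. The following is a hypothetical sketch of that idea, assuming the budget is a plain sum over the samples in a batch; it is not the actual `EspnetStyleBatchSampler`, which may count padded lengths instead and also applies `buffer_size` and `shuffle`. `Sample` and `length_batches` are illustrative names.

```python
from dataclasses import dataclass


@dataclass
class Sample:
    key: str
    source_len: int  # e.g. number of fbank frames (assumed unit)
    target_len: int  # number of output tokens


def length_batches(samples, batch_size=10000, max_token_length=2048):
    """Greedily pack samples so each batch's summed (source_len + target_len) <= batch_size."""
    # Drop samples whose own length exceeds max_token_length.
    kept = [s for s in samples if s.source_len + s.target_len <= max_token_length]
    # Sort by length so that similarly sized samples share a batch.
    kept.sort(key=lambda s: s.source_len + s.target_len)
    batches, current, current_tokens = [], [], 0
    for s in kept:
        n = s.source_len + s.target_len
        if current and current_tokens + n > batch_size:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(s)
        current_tokens += n
    if current:
        batches.append(current)
    return batches


samples = [Sample("a", 900, 20), Sample("b", 1000, 30), Sample("c", 3000, 40), Sample("d", 500, 10)]
for batch in length_batches(samples, batch_size=2000):
    print([s.key for s in batch])
# ['d', 'a']
# ['b']
```

Sample `c` (3040 tokens) is filtered by `max_token_length`; the rest are packed in length order until the 2000-token budget would be exceeded.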
examples/aishell/branchformer/run.sh
@@ -1,7 +1,7 @@
 #!/usr/bin/env bash
-CUDA_VISIBLE_DEVICES="0,1"
+CUDA_VISIBLE_DEVICES="0,1,2,3"
 # general configuration
 feats_dir="../DATA" #feature output dictionary
@@ -17,7 +17,7 @@
 inference_device="cuda" #"cpu"
 inference_checkpoint="model.pt.avg10"
 inference_scp="wav.scp"
-inference_batch_size=32
+inference_batch_size=1
 # data
 raw_data=../raw_data
examples/aishell/conformer/README.md
New file
@@ -0,0 +1,16 @@
# Conformer Result
## Training Config
- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
- Train info: lr 5e-4, batch_size 25000, 2 gpu(Tesla V100), acc_grad 1, 50 epochs
- Train config: conf/train_asr_transformer.yaml
- LM config: LM was not used
- Model size: 46M
## Results (CER)
|   testset   | CER(%)  |
|:-----------:|:-------:|
|     dev     |  4.42   |
|    test     |  4.87   |
examples/aishell/e_branchformer/README.md
New file
@@ -0,0 +1,14 @@
# E-Branchformer Result
## Training Config
- Feature info: using raw speech, extracting 80 dims fbank online, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
- Train info: lr 0.001, batch_size 10000, 4 gpu(Tesla V100), acc_grad 1, 180 epochs
- Train config: conf/train_asr_e_branchformer.yaml
- LM config: LM was not used
## Results (CER)
|   testset   | CER(%)  |
|:-----------:|:-------:|
|     dev     |  4.10   |
|    test     |  4.52   |
examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml
@@ -96,7 +96,7 @@
     index_ds: IndexDSJsonl
     batch_sampler: EspnetStyleBatchSampler
     batch_type: length # example or length
-    batch_size: 25000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    batch_size: 10000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
     buffer_size: 1024
     shuffle: True
examples/aishell/e_branchformer/run.sh
@@ -1,7 +1,7 @@
 #!/usr/bin/env bash
-CUDA_VISIBLE_DEVICES="0,1"
+CUDA_VISIBLE_DEVICES="0,1,2,3"
 # general configuration
 feats_dir="../DATA" #feature output dictionary
@@ -17,7 +17,7 @@
 inference_device="cuda" #"cpu"
 inference_checkpoint="model.pt.avg10"
 inference_scp="wav.scp"
-inference_batch_size=32
+inference_batch_size=1
 # data
 raw_data=../raw_data
examples/aishell/paraformer/README.md
New file
@@ -0,0 +1,24 @@
# Paraformer
pretrained model in [ModelScope](https://www.modelscope.cn/home):[speech_paraformer_asr_nat-aishell1-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary)
## Training Config
- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
- Train info: lr 5e-4, batch_size 25000, 2 gpu(Tesla V100), acc_grad 1, 50 epochs
- Train config: conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
- LM config: LM was not used
## Results (CER)
- Decode config: conf/decode_asr_transformer_noctc_1best.yaml (ctc weight:0.0)
|   testset   | CER(%)  |
|:-----------:|:-------:|
|     dev     |  4.66   |
|    test     |  5.11   |
- Decode config: conf/decode_asr_transformer.yaml (ctc weight:0.5)
|   testset   | CER(%)  |
|:-----------:|:-------:|
|     dev     |  4.52   |
|    test     |  4.94   |
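The two Paraformer decode settings differ only in CTC weight (0.0 vs 0.5), and the Branchformer config sets `decoding_ctc_weight: 0.4`. Assuming ESPnet-style joint CTC/attention scoring, each hypothesis's score is a linear interpolation of the attention-decoder and CTC log-probabilities, weighted by `ctc_weight`. The sketch below uses made-up scores only to show how the weight can change which hypothesis wins; `joint_score` and `best_hypothesis` are hypothetical helpers.

```python
def joint_score(att_logp: float, ctc_logp: float, ctc_weight: float) -> float:
    """Interpolate attention and CTC log-probabilities (assumed ESPnet-style weighting)."""
    if not 0.0 <= ctc_weight <= 1.0:
        raise ValueError("ctc_weight must lie in [0, 1]")
    return (1.0 - ctc_weight) * att_logp + ctc_weight * ctc_logp


def best_hypothesis(hyps: dict[str, tuple[float, float]], ctc_weight: float) -> str:
    """Return the hypothesis with the highest joint score; values are (att_logp, ctc_logp)."""
    return max(hyps, key=lambda h: joint_score(*hyps[h], ctc_weight))


# Made-up scores: A is preferred by the attention decoder, B by CTC.
hyps = {"A": (-1.0, -5.0), "B": (-1.5, -2.0)}
print(best_hypothesis(hyps, ctc_weight=0.0))  # A
print(best_hypothesis(hyps, ctc_weight=0.5))  # B
```

With `ctc_weight=0.0` the CTC branch is ignored entirely; at 0.5, A scores -3.0 and B scores -1.75, so B wins.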
examples/aishell/transformer/README.md
New file
@@ -0,0 +1,16 @@
# Conformer Result
## Training Config
- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
- Train info: lr 5e-4, batch_size 25000, 2 gpu(Tesla V100), acc_grad 1, 50 epochs
- Train config: conf/train_asr_transformer.yaml
- LM config: LM was not used
- Model size: 46M
## Results (CER)
|   testset   | CER(%) |
|:-----------:|:------:|
|     dev     |  4.97  |
|    test     |  5.37  |
examples/aishell/transformer/run.sh
@@ -17,7 +17,7 @@
 inference_device="cuda" #"cpu"
 inference_checkpoint="model.pt.avg10"
 inference_scp="wav.scp"
-inference_batch_size=32
+inference_batch_size=1
 # data
 raw_data=../raw_data
examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -8,7 +8,7 @@
 model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
+                  spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                   )
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
examples/industrial_data_pretraining/sense_voice/demo.py
New file
@@ -0,0 +1,16 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoice",
                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  vad_kwargs={"max_single_segment_time": 30000},
                  )
task = "ASR"
language = None
input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
res = model.generate(task=task, language=language, input=input_wav, batch_size_s=0,)
print(res)
funasr/models/sense_voice/model.py
@@ -74,8 +74,14 @@
         speech = speech.to(device=kwargs["device"])[0, :, :]
         speech_lengths = speech_lengths.to(device=kwargs["device"])
+        task = kwargs.get("task", "ASR")
+        if isinstance(task, str):
+            task = [task]
+        task = "".join([f"<|{x}|>" for x in task])
+        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
         language = kwargs.get("language", None)
-        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
+        language = None if language == "auto" else language
         # if language is None:
         # # detect the spoken language
         # _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt)
         # print(f"Detected language: {max(probs, key=probs.get)}")
@@ -83,8 +89,8 @@
         # language = language if kwargs.get("language", None) is None else kwargs.get("language")
         
         # decode the audio
         prompt = ""
-        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
+        # initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
         options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
         result = whisper.decode(self.model, speech, options)
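The new `inference` code builds the default prompt from a `task` argument (a string or a list of strings) instead of hard-coding `<|ASR|>`, and maps `language="auto"` to `None`. The standalone function below replays those lines on plain keyword arguments so the defaults can be checked in isolation; the function name is hypothetical and `"X"` is a placeholder, not a real SenseVoice task token.

```python
def build_prompt_and_language(**kwargs):
    """Replay the prompt/language defaults from SenseVoice.inference on plain kwargs."""
    task = kwargs.get("task", "ASR")
    if isinstance(task, str):
        task = [task]
    # Each task name becomes a special token, e.g. "ASR" -> "<|ASR|>".
    task = "".join([f"<|{x}|>" for x in task])
    # An explicit initial_prompt still takes precedence over the task-derived default.
    initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
    language = kwargs.get("language", None)
    language = None if language == "auto" else language
    return initial_prompt, language


print(build_prompt_and_language())
# ('<|startoftranscript|><|ASR|>', None)
print(build_prompt_and_language(task=["ASR", "X"], language="auto"))  # "X" is a placeholder
# ('<|startoftranscript|><|ASR|><|X|>', None)
```

With the default `task="ASR"` the result equals the previously hard-coded `<|startoftranscript|><|ASR|>`, so the demo's `task = "ASR"` call keeps the old behavior.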
funasr/models/sense_voice/whisper_lib/model.py
@@ -1,97 +1,316 @@
import base64
import gzip
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
import time
from typing import Dict, Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from . import whisper_lib as whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from torch import Tensor, nn
from funasr.register import tables
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function
@tables.register("model_classes", "SenseVoice")
class SenseVoice(nn.Module):
    def __init__(self, *args, **kwargs):
@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)
class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )
class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        hub = kwargs.get("hub", "funasr")
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)
        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)
        self.model = model
        self.encoder_output_size = self.model.dims.n_audio_state
    def forward(self, ):
        pass
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
                  ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        q = self.query(x)
        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
            self.frontend = frontend
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            frontend = frontend if frontend is not None else self.frontend
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk
    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()
        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()
        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)
        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)
    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x
class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)
    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)
        # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        # x = (x + self.positional_embedding).to(x.dtype)
        x = (x + self.positional_embedding[: x.size(1), :]).to(x.dtype)
        for block in self.blocks:
            x = block(x)
        x = self.ln_post(x)
        return x
class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)
        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()
        return logits
class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)
    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)
    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))
    @property
    def device(self):
        return next(self.parameters()).device
    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865
    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.
        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []
        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]
        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))
        language = kwargs.get("language", None)
        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
        # # detect the spoken language
        # _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt)
        # print(f"Detected language: {max(probs, key=probs.get)}")
        # language = max(probs, key=probs.get)
        # language = language if kwargs.get("language", None) is None else kwargs.get("language")
        self.decoder.apply(install_hooks)
        return cache, hooks
        
        # decode the audio
        prompt = ""
        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
        options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
        result = whisper.decode(self.model, speech, options)
        results = []
        result_i = {"key": key[0], "text": result.text}
        results.append(result_i)
        return results, meta_data
    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
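In `AudioEncoder` above, both `conv1` and `conv2` use `kernel_size=3, stride=2, padding=1`, whereas the original OpenAI Whisper encoder strides only `conv2`. The time axis is therefore downsampled about 4x, and `positional_embedding[: x.size(1), :]` is sliced to the actual length instead of asserting a fixed shape. The sketch below computes the encoder length from the standard `Conv1d` output-length formula and cross-checks it against real PyTorch layers; `n_mels=80`, `n_state=8`, and the test lengths are arbitrary choices.

```python
import torch
from torch import nn


def conv_out_len(length: int, kernel_size: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of a 1-D convolution with dilation 1 (standard Conv1d formula)."""
    return (length + 2 * padding - kernel_size) // stride + 1


def encoder_frames(n_mel_frames: int) -> int:
    """Sequence length after conv1 and conv2 of AudioEncoder (both stride 2)."""
    return conv_out_len(conv_out_len(n_mel_frames))


# Cross-check against actual Conv1d layers with the same hyperparameters.
n_mels, n_state = 80, 8
conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
with torch.no_grad():
    for t in (3000, 1001, 7):
        x = torch.zeros(1, n_mels, t)
        assert conv2(conv1(x)).shape[-1] == encoder_frames(t)

print(encoder_frames(3000))  # 750
```

Under the usual assumption of a 10 ms mel hop, 30 s of audio (3000 frames) maps to 750 encoder positions, so `n_audio_ctx` must be at least that large for the positional-embedding slice to cover such inputs.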
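`qkv_attention` multiplies both `q` and `k` by `(n_state // n_head) ** -0.25` rather than dividing `q @ k^T` by the square root of the head size once. In exact arithmetic the two give the same logits, since d^-0.25 · d^-0.25 = d^-0.5. A commonly cited motivation for the split is keeping intermediate magnitudes smaller in half precision; that rationale is an assumption here, not something this diff states. The check below uses random tensors and arbitrary sizes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_batch, n_ctx, n_head, d_head = 2, 5, 4, 16  # arbitrary sizes
n_state = n_head * d_head
q = torch.randn(n_batch, n_ctx, n_state)
k = torch.randn(n_batch, n_ctx, n_state)


def split_heads_q(x):
    # (batch, ctx, state) -> (batch, head, ctx, d_head)
    return x.view(*x.shape[:2], n_head, -1).permute(0, 2, 1, 3)


def split_heads_k(x):
    # (batch, ctx, state) -> (batch, head, d_head, ctx), i.e. already transposed
    return x.view(*x.shape[:2], n_head, -1).permute(0, 2, 3, 1)


# Split scaling, as in qkv_attention: q and k each get d_head ** -0.25.
scale = (n_state // n_head) ** -0.25
split = F.softmax((split_heads_q(q) * scale) @ (split_heads_k(k) * scale), dim=-1)

# Conventional scaling: divide the raw logits by sqrt(d_head) once.
single = F.softmax((split_heads_q(q) @ split_heads_k(k)) / d_head ** 0.5, dim=-1)

print(torch.allclose(split, single, atol=1e-5))  # True
```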
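`set_alignment_heads` expects a base85-encoded, gzip-compressed buffer of one-byte booleans that reshapes to `(n_text_layer, n_text_head)`. A minimal round-trip sketch, assuming a dump is produced by simply inverting those decoding steps; `encode_alignment_heads` and `decode_alignment_heads` are hypothetical helpers, and the 4x6 shape is arbitrary.

```python
import base64
import gzip

import numpy as np


def encode_alignment_heads(mask: np.ndarray) -> bytes:
    """Inverse of the decoding in set_alignment_heads: bool array -> gzip -> base85."""
    return base64.b85encode(gzip.compress(mask.astype(bool).tobytes()))


def decode_alignment_heads(dump: bytes, n_text_layer: int, n_text_head: int) -> np.ndarray:
    """Same steps as set_alignment_heads, minus the torch sparse buffer registration."""
    array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
    return array.reshape(n_text_layer, n_text_head)


# Hypothetical 4-layer, 6-head decoder; mimic the default "last half of the layers" choice.
n_layer, n_head = 4, 6
heads = np.zeros((n_layer, n_head), dtype=bool)
heads[n_layer // 2 :] = True
dump = encode_alignment_heads(heads)
restored = decode_alignment_heads(dump, n_layer, n_head)
print(np.array_equal(heads, restored))  # True
```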