From 48693b45c021a842ea964c9dc99479b61eac062f Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 02 四月 2024 10:33:27 +0800
Subject: [PATCH] Dev gzf new (#1574)

---
 examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml |    2 
 examples/industrial_data_pretraining/paraformer-zh-spk/demo.py           |    2 
 examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml     |    8 
 examples/aishell/paraformer/README.md                                    |   24 ++
 examples/aishell/transformer/README.md                                   |   16 +
 examples/aishell/branchformer/run.sh                                     |    4 
 examples/aishell/conformer/README.md                                     |   16 +
 examples/aishell/e_branchformer/run.sh                                   |    4 
 examples/industrial_data_pretraining/sense_voice/demo.py                 |   16 +
 funasr/models/sense_voice/model.py                                       |   24 +-
 examples/aishell/branchformer/README.md                                  |   14 +
 examples/aishell/e_branchformer/README.md                                |   14 +
 funasr/models/sense_voice/whisper_lib/model.py                           |  381 +++++++++++++++++++++++++++++++++---------
 examples/aishell/transformer/run.sh                                      |    2 
 14 files changed, 428 insertions(+), 99 deletions(-)

diff --git a/examples/aishell/branchformer/README.md b/examples/aishell/branchformer/README.md
new file mode 100644
index 0000000..930e58f
--- /dev/null
+++ b/examples/aishell/branchformer/README.md
@@ -0,0 +1,14 @@
+# Branchformer Result
+
+## Training Config
+- Feature info: using raw speech, extracting 80 dims fbank online, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 0.001, batch_size 10000, 4 gpu(Tesla V100), acc_grad 1, 180 epochs
+- Train config: conf/train_asr_branchformer.yaml
+- LM config: LM was not used
+
+## Results (CER)
+
+|   testset   | CER(%)  |
+|:-----------:|:-------:|
+|     dev     |  4.15   |
+|    test     |  4.51   |
\ No newline at end of file
diff --git a/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml b/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml
index aefd2b9..acb7946 100644
--- a/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml
+++ b/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml
@@ -79,8 +79,9 @@
 train_conf:
   accum_grad: 1
   grad_clip: 5
-  max_epoch: 150
+  max_epoch: 180
   keep_nbest_models: 10
+  avg_keep_nbest_models_type: acc
   log_interval: 50
 
 optim: adam
@@ -96,7 +97,7 @@
     index_ds: IndexDSJsonl
     batch_sampler: EspnetStyleBatchSampler
     batch_type: length # example or length
-    batch_size: 25000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    batch_size: 10000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
     buffer_size: 1024
     shuffle: True
@@ -116,3 +117,6 @@
     reduce: true
     ignore_nan_grad: true
 normalize: null
+
+beam_size: 10
+decoding_ctc_weight: 0.4
\ No newline at end of file
diff --git a/examples/aishell/branchformer/run.sh b/examples/aishell/branchformer/run.sh
index f7dda1c..918aa9b 100755
--- a/examples/aishell/branchformer/run.sh
+++ b/examples/aishell/branchformer/run.sh
@@ -1,7 +1,7 @@
 #!/usr/bin/env bash
 
 
-CUDA_VISIBLE_DEVICES="0,1"
+CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 # general configuration
 feats_dir="../DATA" #feature output dictionary
@@ -17,7 +17,7 @@
 inference_device="cuda" #"cpu"
 inference_checkpoint="model.pt.avg10"
 inference_scp="wav.scp"
-inference_batch_size=32
+inference_batch_size=1
 
 # data
 raw_data=../raw_data
diff --git a/examples/aishell/conformer/README.md b/examples/aishell/conformer/README.md
new file mode 100644
index 0000000..003cbac
--- /dev/null
+++ b/examples/aishell/conformer/README.md
@@ -0,0 +1,16 @@
+
+# Conformer Result
+
+## Training Config
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 5e-4, batch_size 25000, 2 gpu(Tesla V100), acc_grad 1, 50 epochs
+- Train config: conf/train_asr_transformer.yaml
+- LM config: LM was not used
+- Model size: 46M
+
+## Results (CER)
+
+|   testset   | CER(%)  |
+|:-----------:|:-------:|
+|     dev     |  4.42   |
+|    test     |  4.87   |
\ No newline at end of file
diff --git a/examples/aishell/e_branchformer/README.md b/examples/aishell/e_branchformer/README.md
new file mode 100644
index 0000000..ac0aabb
--- /dev/null
+++ b/examples/aishell/e_branchformer/README.md
@@ -0,0 +1,14 @@
+# E-Branchformer Result
+
+## Training Config
+- Feature info: using raw speech, extracting 80 dims fbank online, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 0.001, batch_size 10000, 4 gpu(Tesla V100), acc_grad 1, 180 epochs
+- Train config: conf/train_asr_e_branchformer.yaml
+- LM config: LM was not used
+
+## Results (CER)
+
+|   testset   | CER(%)  |
+|:-----------:|:-------:|
+|     dev     |  4.10   |
+|    test     |  4.52   |
\ No newline at end of file
diff --git a/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml b/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml
index 28d8e94..6438ae1 100644
--- a/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml
+++ b/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml
@@ -96,7 +96,7 @@
     index_ds: IndexDSJsonl
     batch_sampler: EspnetStyleBatchSampler
     batch_type: length # example or length
-    batch_size: 25000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    batch_size: 10000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
     max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
     buffer_size: 1024
     shuffle: True
diff --git a/examples/aishell/e_branchformer/run.sh b/examples/aishell/e_branchformer/run.sh
index bc78b5f..be18599 100755
--- a/examples/aishell/e_branchformer/run.sh
+++ b/examples/aishell/e_branchformer/run.sh
@@ -1,7 +1,7 @@
 #!/usr/bin/env bash
 
 
-CUDA_VISIBLE_DEVICES="0,1"
+CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 # general configuration
 feats_dir="../DATA" #feature output dictionary
@@ -17,7 +17,7 @@
 inference_device="cuda" #"cpu"
 inference_checkpoint="model.pt.avg10"
 inference_scp="wav.scp"
-inference_batch_size=32
+inference_batch_size=1
 
 # data
 raw_data=../raw_data
diff --git a/examples/aishell/paraformer/README.md b/examples/aishell/paraformer/README.md
new file mode 100644
index 0000000..c0385db
--- /dev/null
+++ b/examples/aishell/paraformer/README.md
@@ -0,0 +1,24 @@
+# Paraformer
+pretrained model in [ModelScope](https://www.modelscope.cn/home)锛歔speech_paraformer_asr_nat-aishell1-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary)
+
+## Training Config
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 5e-4, batch_size 25000, 2 gpu(Tesla V100), acc_grad 1, 50 epochs
+- Train config: conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
+- LM config: LM was not used
+
+## Results (CER)
+
+- Decode config: conf/decode_asr_transformer_noctc_1best.yaml (ctc weight:0.0)
+
+|   testset   | CER(%)  |
+|:-----------:|:-------:|
+|     dev     |  4.66   |
+|    test     |  5.11   |
+
+- Decode config: conf/decode_asr_transformer.yaml (ctc weight:0.5)
+
+|   testset   | CER(%)  |
+|:-----------:|:-------:|
+|     dev     |  4.52   |
+|    test     |  4.94   |
\ No newline at end of file
diff --git a/examples/aishell/transformer/README.md b/examples/aishell/transformer/README.md
new file mode 100644
index 0000000..2435b55
--- /dev/null
+++ b/examples/aishell/transformer/README.md
@@ -0,0 +1,16 @@
+
+# Conformer Result
+
+## Training Config
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 5e-4, batch_size 25000, 2 gpu(Tesla V100), acc_grad 1, 50 epochs
+- Train config: conf/train_asr_transformer.yaml
+- LM config: LM was not used
+- Model size: 46M
+
+## Results (CER)
+
+|   testset   | CER(%) |
+|:-----------:|:------:|
+|     dev     |  4.97  |
+|    test     |  5.37  |
\ No newline at end of file
diff --git a/examples/aishell/transformer/run.sh b/examples/aishell/transformer/run.sh
index a5ff7ff..98c2829 100755
--- a/examples/aishell/transformer/run.sh
+++ b/examples/aishell/transformer/run.sh
@@ -17,7 +17,7 @@
 inference_device="cuda" #"cpu"
 inference_checkpoint="model.pt.avg10"
 inference_scp="wav.scp"
-inference_batch_size=32
+inference_batch_size=1
 
 # data
 raw_data=../raw_data
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
index 2a83509..cbc9d8b 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -8,7 +8,7 @@
 model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
+                  spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                   )
 
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
diff --git a/examples/industrial_data_pretraining/sense_voice/demo.py b/examples/industrial_data_pretraining/sense_voice/demo.py
new file mode 100644
index 0000000..506b069
--- /dev/null
+++ b/examples/industrial_data_pretraining/sense_voice/demo.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoice",
+                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+				  vad_kwargs={"max_single_segment_time": 30000},
+                  )
+task = "ASR"
+language = None
+input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
+res = model.generate(task=task, language=language, input=input_wav, batch_size_s=0,)
+print(res)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 2822fc7..d6552a6 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -73,18 +73,24 @@
 
         speech = speech.to(device=kwargs["device"])[0, :, :]
         speech_lengths = speech_lengths.to(device=kwargs["device"])
-
+        
+        task = kwargs.get("task", "ASR")
+        if isinstance(task, str):
+            task = [task]
+        task = "".join([f"<|{x}|>" for x in task])
+        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
         language = kwargs.get("language", None)
-        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
-        # # detect the spoken language
-        # _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt)
-        # print(f"Detected language: {max(probs, key=probs.get)}")
-        # language = max(probs, key=probs.get)
-        # language = language if kwargs.get("language", None) is None else kwargs.get("language")
+        language = None if language == "auto" else language
+        # if language is None:
+        #     # detect the spoken language
+        #     _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt)
+        #     print(f"Detected language: {max(probs, key=probs.get)}")
+        #     language = max(probs, key=probs.get)
+        #     language = language if kwargs.get("language", None) is None else kwargs.get("language")
         
         # decode the audio
-        prompt = ""
-        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
+        
+        # initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
         options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
         result = whisper.decode(self.model, speech, options)
 
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 2822fc7..0e8f09b 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -1,97 +1,316 @@
+import base64
+import gzip
 from dataclasses import dataclass
-from typing import Dict
-from typing import Iterable, Optional
-import time
+from typing import Dict, Iterable, Optional
+
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch import Tensor
-from torch import nn
-from . import whisper_lib as whisper
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from torch import Tensor, nn
 
-from funasr.register import tables
+from .decoding import decode as decode_function
+from .decoding import detect_language as detect_language_function
+from .transcribe import transcribe as transcribe_function
 
 
-@tables.register("model_classes", "SenseVoice")
-class SenseVoice(nn.Module):
-    def __init__(self, *args, **kwargs):
+@dataclass
+class ModelDimensions:
+    n_mels: int
+    n_audio_ctx: int
+    n_audio_state: int
+    n_audio_head: int
+    n_audio_layer: int
+    n_vocab: int
+    n_text_ctx: int
+    n_text_state: int
+    n_text_head: int
+    n_text_layer: int
+
+
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x: Tensor) -> Tensor:
+        return super().forward(x.float()).type(x.dtype)
+
+
+class Linear(nn.Linear):
+    def forward(self, x: Tensor) -> Tensor:
+        return F.linear(
+            x,
+            self.weight.to(x.dtype),
+            None if self.bias is None else self.bias.to(x.dtype),
+        )
+
+
+class Conv1d(nn.Conv1d):
+    def _conv_forward(
+        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
+    ) -> Tensor:
+        return super()._conv_forward(
+            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
+        )
+
+
+def sinusoids(length, channels, max_timescale=10000):
+    """Returns sinusoids for positional embedding"""
+    assert channels % 2 == 0
+    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, n_state: int, n_head: int):
         super().__init__()
-        hub = kwargs.get("hub", "funasr")
+        self.n_head = n_head
+        self.query = Linear(n_state, n_state)
+        self.key = Linear(n_state, n_state, bias=False)
+        self.value = Linear(n_state, n_state)
+        self.out = Linear(n_state, n_state)
 
-        dims = kwargs.get("dims", {})
-        dims = whisper.model.ModelDimensions(**dims)
-        model = whisper.model.Whisper(dims=dims)
-        
-        self.model = model
-        
-        self.encoder_output_size = self.model.dims.n_audio_state
-        
-    def forward(self, ):
-        pass
-    
-    def inference(self,
-                  data_in,
-                  data_lengths=None,
-                  key: list = None,
-                  tokenizer=None,
-                  frontend=None,
-                  **kwargs,
-                  ):
-        if kwargs.get("batch_size", 1) > 1:
-            raise NotImplementedError("batch decoding is not implemented")
+    def forward(
+        self,
+        x: Tensor,
+        xa: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[dict] = None,
+    ):
+        q = self.query(x)
 
-        if frontend is None and not hasattr(self, "frontend"):
-            frontend_class = tables.frontend_classes.get("WhisperFrontend")
-            frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
-            self.frontend = frontend
+        if kv_cache is None or xa is None or self.key not in kv_cache:
+            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+            # otherwise, perform key/value projections for self- or cross-attention as usual.
+            k = self.key(x if xa is None else xa)
+            v = self.value(x if xa is None else xa)
         else:
-            frontend = frontend if frontend is not None else self.frontend
+            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+            k = kv_cache[self.key]
+            v = kv_cache[self.value]
 
-        meta_data = {}
-        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
-            speech, speech_lengths = data_in, data_lengths
-            if len(speech.shape) < 3:
-                speech = speech[None, :, :]
-            if speech_lengths is None:
-                speech_lengths = speech.shape[1]
-        else:
-            # extract fbank feats
-            time1 = time.perf_counter()
-            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000),
-                                                            data_type=kwargs.get("data_type", "sound"),
-                                                            tokenizer=tokenizer)
-            time2 = time.perf_counter()
-            meta_data["load_data"] = f"{time2 - time1:0.3f}"
-            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend)
-            time3 = time.perf_counter()
-            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
-            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
-            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
+        wv, qk = self.qkv_attention(q, k, v, mask)
+        return self.out(wv), qk
 
-        speech = speech.to(device=kwargs["device"])[0, :, :]
-        speech_lengths = speech_lengths.to(device=kwargs["device"])
+    def qkv_attention(
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+    ):
+        n_batch, n_ctx, n_state = q.shape
+        scale = (n_state // self.n_head) ** -0.25
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
 
-        language = kwargs.get("language", None)
-        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
-        # # detect the spoken language
-        # _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt)
-        # print(f"Detected language: {max(probs, key=probs.get)}")
-        # language = max(probs, key=probs.get)
-        # language = language if kwargs.get("language", None) is None else kwargs.get("language")
-        
-        # decode the audio
-        prompt = ""
-        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
-        options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
-        result = whisper.decode(self.model, speech, options)
+        qk = q @ k
+        if mask is not None:
+            qk = qk + mask[:n_ctx, :n_ctx]
+        qk = qk.float()
 
-        results = []
-        result_i = {"key": key[0], "text": result.text}
+        w = F.softmax(qk, dim=-1).to(q.dtype)
+        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
 
-        results.append(result_i)
-    
-        return results, meta_data
-    
\ No newline at end of file
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+        super().__init__()
+
+        self.attn = MultiHeadAttention(n_state, n_head)
+        self.attn_ln = LayerNorm(n_state)
+
+        self.cross_attn = (
+            MultiHeadAttention(n_state, n_head) if cross_attention else None
+        )
+        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+        n_mlp = n_state * 4
+        self.mlp = nn.Sequential(
+            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+        )
+        self.mlp_ln = LayerNorm(n_state)
+
+    def forward(
+        self,
+        x: Tensor,
+        xa: Optional[Tensor] = None,
+        mask: Optional[Tensor] = None,
+        kv_cache: Optional[dict] = None,
+    ):
+        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
+        if self.cross_attn:
+            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
+        x = x + self.mlp(self.mlp_ln(x))
+        return x
+
+
+class AudioEncoder(nn.Module):
+    def __init__(
+        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
+        super().__init__()
+        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
+        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+        )
+        self.ln_post = LayerNorm(n_state)
+
+    def forward(self, x: Tensor):
+        """
+        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+            the mel spectrogram of the audio
+        """
+        x = F.gelu(self.conv1(x))
+        x = F.gelu(self.conv2(x))
+        x = x.permute(0, 2, 1)
+
+        # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+        # x = (x + self.positional_embedding).to(x.dtype)
+        x = (x + self.positional_embedding[: x.size(1), :]).to(x.dtype)
+
+
+        for block in self.blocks:
+            x = block(x)
+
+        x = self.ln_post(x)
+        return x
+
+
+class TextDecoder(nn.Module):
+    def __init__(
+        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+    ):
+        super().__init__()
+
+        self.token_embedding = nn.Embedding(n_vocab, n_state)
+        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+            [
+                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+                for _ in range(n_layer)
+            ]
+        )
+        self.ln = LayerNorm(n_state)
+
+        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+        self.register_buffer("mask", mask, persistent=False)
+
+    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+        """
+        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+            the text tokens
+        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+            the encoded audio features to be attended on
+        """
+        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+        x = (
+            self.token_embedding(x)
+            + self.positional_embedding[offset : offset + x.shape[-1]]
+        )
+        x = x.to(xa.dtype)
+
+        for block in self.blocks:
+            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+
+        x = self.ln(x)
+        logits = (
+            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+        ).float()
+
+        return logits
+
+
+class Whisper(nn.Module):
+    def __init__(self, dims: ModelDimensions):
+        super().__init__()
+        self.dims = dims
+        self.encoder = AudioEncoder(
+            self.dims.n_mels,
+            self.dims.n_audio_ctx,
+            self.dims.n_audio_state,
+            self.dims.n_audio_head,
+            self.dims.n_audio_layer,
+        )
+        self.decoder = TextDecoder(
+            self.dims.n_vocab,
+            self.dims.n_text_ctx,
+            self.dims.n_text_state,
+            self.dims.n_text_head,
+            self.dims.n_text_layer,
+        )
+        # use the last half among the decoder layers for time alignment by default;
+        # to use a specific set of heads, see `set_alignment_heads()` below.
+        all_heads = torch.zeros(
+            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
+        )
+        all_heads[self.dims.n_text_layer // 2 :] = True
+        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+    def set_alignment_heads(self, dump: bytes):
+        array = np.frombuffer(
+            gzip.decompress(base64.b85decode(dump)), dtype=bool
+        ).copy()
+        mask = torch.from_numpy(array).reshape(
+            self.dims.n_text_layer, self.dims.n_text_head
+        )
+        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
+
+    def embed_audio(self, mel: torch.Tensor):
+        return self.encoder(mel)
+
+    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+        return self.decoder(tokens, audio_features)
+
+    def forward(
+        self, mel: torch.Tensor, tokens: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
+        return self.decoder(tokens, self.encoder(mel))
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @property
+    def is_multilingual(self):
+        return self.dims.n_vocab >= 51865
+
+    @property
+    def num_languages(self):
+        return self.dims.n_vocab - 51765 - int(self.is_multilingual)
+
+    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+        """
+        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+        tensors calculated for the previous positions. This method returns a dictionary that stores
+        all caches, and the necessary hooks for the key and value projection modules that save the
+        intermediate tensors to be reused during later calculations.
+
+        Returns
+        -------
+        cache : Dict[nn.Module, torch.Tensor]
+            A dictionary object mapping the key/value projection modules to its cache
+        hooks : List[RemovableHandle]
+            List of PyTorch RemovableHandle objects to stop the hooks to be called
+        """
+        cache = {**cache} if cache is not None else {}
+        hooks = []
+
+        def save_to_cache(module, _, output):
+            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
+                # save as-is, for the first token or cross attention
+                cache[module] = output
+            else:
+                cache[module] = torch.cat([cache[module], output], dim=1).detach()
+            return cache[module]
+
+        def install_hooks(layer: nn.Module):
+            if isinstance(layer, MultiHeadAttention):
+                hooks.append(layer.key.register_forward_hook(save_to_cache))
+                hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+        self.decoder.apply(install_hooks)
+        return cache, hooks
+
+    detect_language = detect_language_function
+    transcribe = transcribe_function
+    decode = decode_function
\ No newline at end of file

--
Gitblit v1.9.1