From 48693b45c021a842ea964c9dc99479b61eac062f Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 02 四月 2024 10:33:27 +0800
Subject: [PATCH] Dev gzf new (#1574)
---
examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml | 2
examples/industrial_data_pretraining/paraformer-zh-spk/demo.py | 2
examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml | 8
examples/aishell/paraformer/README.md | 24 ++
examples/aishell/transformer/README.md | 16 +
examples/aishell/branchformer/run.sh | 4
examples/aishell/conformer/README.md | 16 +
examples/aishell/e_branchformer/run.sh | 4
examples/industrial_data_pretraining/sense_voice/demo.py | 16 +
funasr/models/sense_voice/model.py | 24 +-
examples/aishell/branchformer/README.md | 14 +
examples/aishell/e_branchformer/README.md | 14 +
funasr/models/sense_voice/whisper_lib/model.py | 381 +++++++++++++++++++++++++++++++++---------
examples/aishell/transformer/run.sh | 2
14 files changed, 428 insertions(+), 99 deletions(-)
diff --git a/examples/aishell/branchformer/README.md b/examples/aishell/branchformer/README.md
new file mode 100644
index 0000000..930e58f
--- /dev/null
+++ b/examples/aishell/branchformer/README.md
@@ -0,0 +1,14 @@
+# Branchformer Result
+
+## Training Config
+- Feature info: using raw speech, extracting 80 dims fbank online, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 0.001, batch_size 10000, 4 gpu(Tesla V100), acc_grad 1, 180 epochs
+- Train config: conf/train_asr_branchformer.yaml
+- LM config: LM was not used
+
+## Results (CER)
+
+| testset | CER(%) |
+|:-----------:|:-------:|
+| dev | 4.15 |
+| test | 4.51 |
\ No newline at end of file
diff --git a/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml b/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml
index aefd2b9..acb7946 100644
--- a/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml
+++ b/examples/aishell/branchformer/conf/branchformer_12e_6d_2048_256.yaml
@@ -79,8 +79,9 @@
train_conf:
accum_grad: 1
grad_clip: 5
- max_epoch: 150
+ max_epoch: 180
keep_nbest_models: 10
+ avg_keep_nbest_models_type: acc
log_interval: 50
optim: adam
@@ -96,7 +97,7 @@
index_ds: IndexDSJsonl
batch_sampler: EspnetStyleBatchSampler
batch_type: length # example or length
- batch_size: 25000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+ batch_size: 10000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
buffer_size: 1024
shuffle: True
@@ -116,3 +117,6 @@
reduce: true
ignore_nan_grad: true
normalize: null
+
+beam_size: 10
+decoding_ctc_weight: 0.4
\ No newline at end of file
diff --git a/examples/aishell/branchformer/run.sh b/examples/aishell/branchformer/run.sh
index f7dda1c..918aa9b 100755
--- a/examples/aishell/branchformer/run.sh
+++ b/examples/aishell/branchformer/run.sh
@@ -1,7 +1,7 @@
#!/usr/bin/env bash
-CUDA_VISIBLE_DEVICES="0,1"
+CUDA_VISIBLE_DEVICES="0,1,2,3"
# general configuration
feats_dir="../DATA" #feature output dictionary
@@ -17,7 +17,7 @@
inference_device="cuda" #"cpu"
inference_checkpoint="model.pt.avg10"
inference_scp="wav.scp"
-inference_batch_size=32
+inference_batch_size=1
# data
raw_data=../raw_data
diff --git a/examples/aishell/conformer/README.md b/examples/aishell/conformer/README.md
new file mode 100644
index 0000000..003cbac
--- /dev/null
+++ b/examples/aishell/conformer/README.md
@@ -0,0 +1,16 @@
+
+# Conformer Result
+
+## Training Config
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 5e-4, batch_size 25000, 2 gpu(Tesla V100), acc_grad 1, 50 epochs
+- Train config: conf/train_asr_transformer.yaml
+- LM config: LM was not used
+- Model size: 46M
+
+## Results (CER)
+
+| testset | CER(%) |
+|:-----------:|:-------:|
+| dev | 4.42 |
+| test | 4.87 |
\ No newline at end of file
diff --git a/examples/aishell/e_branchformer/README.md b/examples/aishell/e_branchformer/README.md
new file mode 100644
index 0000000..ac0aabb
--- /dev/null
+++ b/examples/aishell/e_branchformer/README.md
@@ -0,0 +1,14 @@
+# E-Branchformer Result
+
+## Training Config
+- Feature info: using raw speech, extracting 80 dims fbank online, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 0.001, batch_size 10000, 4 gpu(Tesla V100), acc_grad 1, 180 epochs
+- Train config: conf/train_asr_e_branchformer.yaml
+- LM config: LM was not used
+
+## Results (CER)
+
+| testset | CER(%) |
+|:-----------:|:-------:|
+| dev | 4.10 |
+| test | 4.52 |
\ No newline at end of file
diff --git a/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml b/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml
index 28d8e94..6438ae1 100644
--- a/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml
+++ b/examples/aishell/e_branchformer/conf/e_branchformer_12e_6d_2048_256.yaml
@@ -96,7 +96,7 @@
index_ds: IndexDSJsonl
batch_sampler: EspnetStyleBatchSampler
batch_type: length # example or length
- batch_size: 25000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+ batch_size: 10000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
buffer_size: 1024
shuffle: True
diff --git a/examples/aishell/e_branchformer/run.sh b/examples/aishell/e_branchformer/run.sh
index bc78b5f..be18599 100755
--- a/examples/aishell/e_branchformer/run.sh
+++ b/examples/aishell/e_branchformer/run.sh
@@ -1,7 +1,7 @@
#!/usr/bin/env bash
-CUDA_VISIBLE_DEVICES="0,1"
+CUDA_VISIBLE_DEVICES="0,1,2,3"
# general configuration
feats_dir="../DATA" #feature output dictionary
@@ -17,7 +17,7 @@
inference_device="cuda" #"cpu"
inference_checkpoint="model.pt.avg10"
inference_scp="wav.scp"
-inference_batch_size=32
+inference_batch_size=1
# data
raw_data=../raw_data
diff --git a/examples/aishell/paraformer/README.md b/examples/aishell/paraformer/README.md
new file mode 100644
index 0000000..c0385db
--- /dev/null
+++ b/examples/aishell/paraformer/README.md
@@ -0,0 +1,24 @@
+# Paraformer
+pretrained model in [ModelScope](https://www.modelscope.cn/home)锛歔speech_paraformer_asr_nat-aishell1-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary)
+
+## Training Config
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 5e-4, batch_size 25000, 2 gpu(Tesla V100), acc_grad 1, 50 epochs
+- Train config: conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
+- LM config: LM was not used
+
+## Results (CER)
+
+- Decode config: conf/decode_asr_transformer_noctc_1best.yaml (ctc weight:0.0)
+
+| testset | CER(%) |
+|:-----------:|:-------:|
+| dev | 4.66 |
+| test | 5.11 |
+
+- Decode config: conf/decode_asr_transformer.yaml (ctc weight:0.5)
+
+| testset | CER(%) |
+|:-----------:|:-------:|
+| dev | 4.52 |
+| test | 4.94 |
\ No newline at end of file
diff --git a/examples/aishell/transformer/README.md b/examples/aishell/transformer/README.md
new file mode 100644
index 0000000..2435b55
--- /dev/null
+++ b/examples/aishell/transformer/README.md
@@ -0,0 +1,16 @@
+
+# Conformer Result
+
+## Training Config
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 5e-4, batch_size 25000, 2 gpu(Tesla V100), acc_grad 1, 50 epochs
+- Train config: conf/train_asr_transformer.yaml
+- LM config: LM was not used
+- Model size: 46M
+
+## Results (CER)
+
+| testset | CER(%) |
+|:-----------:|:------:|
+| dev | 4.97 |
+| test | 5.37 |
\ No newline at end of file
diff --git a/examples/aishell/transformer/run.sh b/examples/aishell/transformer/run.sh
index a5ff7ff..98c2829 100755
--- a/examples/aishell/transformer/run.sh
+++ b/examples/aishell/transformer/run.sh
@@ -17,7 +17,7 @@
inference_device="cuda" #"cpu"
inference_checkpoint="model.pt.avg10"
inference_scp="wav.scp"
-inference_batch_size=32
+inference_batch_size=1
# data
raw_data=../raw_data
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
index 2a83509..cbc9d8b 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -8,7 +8,7 @@
model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
+ spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
diff --git a/examples/industrial_data_pretraining/sense_voice/demo.py b/examples/industrial_data_pretraining/sense_voice/demo.py
new file mode 100644
index 0000000..506b069
--- /dev/null
+++ b/examples/industrial_data_pretraining/sense_voice/demo.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/SenseVoice",
+ vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+ vad_kwargs={"max_single_segment_time": 30000},
+ )
+task = "ASR"
+language = None
+input_wav = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
+res = model.generate(task=task, language=language, input=input_wav, batch_size_s=0,)
+print(res)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 2822fc7..d6552a6 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -73,18 +73,24 @@
speech = speech.to(device=kwargs["device"])[0, :, :]
speech_lengths = speech_lengths.to(device=kwargs["device"])
-
+
+ task = kwargs.get("task", "ASR")
+ if isinstance(task, str):
+ task = [task]
+ task = "".join([f"<|{x}|>" for x in task])
+ initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
language = kwargs.get("language", None)
- initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
- # # detect the spoken language
- # _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt)
- # print(f"Detected language: {max(probs, key=probs.get)}")
- # language = max(probs, key=probs.get)
- # language = language if kwargs.get("language", None) is None else kwargs.get("language")
+ language = None if language == "auto" else language
+ # if language is None:
+ # # detect the spoken language
+ # _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt)
+ # print(f"Detected language: {max(probs, key=probs.get)}")
+ # language = max(probs, key=probs.get)
+ # language = language if kwargs.get("language", None) is None else kwargs.get("language")
# decode the audio
- prompt = ""
- initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
+
+ # initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
result = whisper.decode(self.model, speech, options)
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 2822fc7..0e8f09b 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -1,97 +1,316 @@
+import base64
+import gzip
from dataclasses import dataclass
-from typing import Dict
-from typing import Iterable, Optional
-import time
+from typing import Dict, Iterable, Optional
+
import numpy as np
import torch
import torch.nn.functional as F
-from torch import Tensor
-from torch import nn
-from . import whisper_lib as whisper
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from torch import Tensor, nn
-from funasr.register import tables
+from .decoding import decode as decode_function
+from .decoding import detect_language as detect_language_function
+from .transcribe import transcribe as transcribe_function
-@tables.register("model_classes", "SenseVoice")
-class SenseVoice(nn.Module):
- def __init__(self, *args, **kwargs):
+@dataclass
+class ModelDimensions:
+ n_mels: int
+ n_audio_ctx: int
+ n_audio_state: int
+ n_audio_head: int
+ n_audio_layer: int
+ n_vocab: int
+ n_text_ctx: int
+ n_text_state: int
+ n_text_head: int
+ n_text_layer: int
+
+
+class LayerNorm(nn.LayerNorm):
+ def forward(self, x: Tensor) -> Tensor:
+ return super().forward(x.float()).type(x.dtype)
+
+
+class Linear(nn.Linear):
+ def forward(self, x: Tensor) -> Tensor:
+ return F.linear(
+ x,
+ self.weight.to(x.dtype),
+ None if self.bias is None else self.bias.to(x.dtype),
+ )
+
+
+class Conv1d(nn.Conv1d):
+ def _conv_forward(
+ self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
+ ) -> Tensor:
+ return super()._conv_forward(
+ x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
+ )
+
+
+def sinusoids(length, channels, max_timescale=10000):
+ """Returns sinusoids for positional embedding"""
+ assert channels % 2 == 0
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, n_state: int, n_head: int):
super().__init__()
- hub = kwargs.get("hub", "funasr")
+ self.n_head = n_head
+ self.query = Linear(n_state, n_state)
+ self.key = Linear(n_state, n_state, bias=False)
+ self.value = Linear(n_state, n_state)
+ self.out = Linear(n_state, n_state)
- dims = kwargs.get("dims", {})
- dims = whisper.model.ModelDimensions(**dims)
- model = whisper.model.Whisper(dims=dims)
-
- self.model = model
-
- self.encoder_output_size = self.model.dims.n_audio_state
-
- def forward(self, ):
- pass
-
- def inference(self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
- if kwargs.get("batch_size", 1) > 1:
- raise NotImplementedError("batch decoding is not implemented")
+ def forward(
+ self,
+ x: Tensor,
+ xa: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ kv_cache: Optional[dict] = None,
+ ):
+ q = self.query(x)
- if frontend is None and not hasattr(self, "frontend"):
- frontend_class = tables.frontend_classes.get("WhisperFrontend")
- frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
- self.frontend = frontend
+ if kv_cache is None or xa is None or self.key not in kv_cache:
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
+ k = self.key(x if xa is None else xa)
+ v = self.value(x if xa is None else xa)
else:
- frontend = frontend if frontend is not None else self.frontend
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
+ k = kv_cache[self.key]
+ v = kv_cache[self.value]
- meta_data = {}
- if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer)
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
- frontend=frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
- lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
+ wv, qk = self.qkv_attention(q, k, v, mask)
+ return self.out(wv), qk
- speech = speech.to(device=kwargs["device"])[0, :, :]
- speech_lengths = speech_lengths.to(device=kwargs["device"])
+ def qkv_attention(
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
+ ):
+ n_batch, n_ctx, n_state = q.shape
+ scale = (n_state // self.n_head) ** -0.25
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
- language = kwargs.get("language", None)
- initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
- # # detect the spoken language
- # _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt)
- # print(f"Detected language: {max(probs, key=probs.get)}")
- # language = max(probs, key=probs.get)
- # language = language if kwargs.get("language", None) is None else kwargs.get("language")
-
- # decode the audio
- prompt = ""
- initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
- options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
- result = whisper.decode(self.model, speech, options)
+ qk = q @ k
+ if mask is not None:
+ qk = qk + mask[:n_ctx, :n_ctx]
+ qk = qk.float()
- results = []
- result_i = {"key": key[0], "text": result.text}
+ w = F.softmax(qk, dim=-1).to(q.dtype)
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
- results.append(result_i)
-
- return results, meta_data
-
\ No newline at end of file
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
+ super().__init__()
+
+ self.attn = MultiHeadAttention(n_state, n_head)
+ self.attn_ln = LayerNorm(n_state)
+
+ self.cross_attn = (
+ MultiHeadAttention(n_state, n_head) if cross_attention else None
+ )
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+
+ n_mlp = n_state * 4
+ self.mlp = nn.Sequential(
+ Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+ )
+ self.mlp_ln = LayerNorm(n_state)
+
+ def forward(
+ self,
+ x: Tensor,
+ xa: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ kv_cache: Optional[dict] = None,
+ ):
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
+ if self.cross_attn:
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
+ x = x + self.mlp(self.mlp_ln(x))
+ return x
+
+
+class AudioEncoder(nn.Module):
+ def __init__(
+ self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+ ):
+ super().__init__()
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
+ )
+ self.ln_post = LayerNorm(n_state)
+
+ def forward(self, x: Tensor):
+ """
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
+ the mel spectrogram of the audio
+ """
+ x = F.gelu(self.conv1(x))
+ x = F.gelu(self.conv2(x))
+ x = x.permute(0, 2, 1)
+
+ # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
+ # x = (x + self.positional_embedding).to(x.dtype)
+ x = (x + self.positional_embedding[: x.size(1), :]).to(x.dtype)
+
+
+ for block in self.blocks:
+ x = block(x)
+
+ x = self.ln_post(x)
+ return x
+
+
+class TextDecoder(nn.Module):
+ def __init__(
+ self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
+ ):
+ super().__init__()
+
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
+ [
+ ResidualAttentionBlock(n_state, n_head, cross_attention=True)
+ for _ in range(n_layer)
+ ]
+ )
+ self.ln = LayerNorm(n_state)
+
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+ self.register_buffer("mask", mask, persistent=False)
+
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+ """
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
+ the text tokens
+ xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
+ the encoded audio features to be attended on
+ """
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+ x = (
+ self.token_embedding(x)
+ + self.positional_embedding[offset : offset + x.shape[-1]]
+ )
+ x = x.to(xa.dtype)
+
+ for block in self.blocks:
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+
+ x = self.ln(x)
+ logits = (
+ x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+ ).float()
+
+ return logits
+
+
+class Whisper(nn.Module):
+ def __init__(self, dims: ModelDimensions):
+ super().__init__()
+ self.dims = dims
+ self.encoder = AudioEncoder(
+ self.dims.n_mels,
+ self.dims.n_audio_ctx,
+ self.dims.n_audio_state,
+ self.dims.n_audio_head,
+ self.dims.n_audio_layer,
+ )
+ self.decoder = TextDecoder(
+ self.dims.n_vocab,
+ self.dims.n_text_ctx,
+ self.dims.n_text_state,
+ self.dims.n_text_head,
+ self.dims.n_text_layer,
+ )
+ # use the last half among the decoder layers for time alignment by default;
+ # to use a specific set of heads, see `set_alignment_heads()` below.
+ all_heads = torch.zeros(
+ self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
+ )
+ all_heads[self.dims.n_text_layer // 2 :] = True
+ self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
+
+ def set_alignment_heads(self, dump: bytes):
+ array = np.frombuffer(
+ gzip.decompress(base64.b85decode(dump)), dtype=bool
+ ).copy()
+ mask = torch.from_numpy(array).reshape(
+ self.dims.n_text_layer, self.dims.n_text_head
+ )
+ self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
+
+ def embed_audio(self, mel: torch.Tensor):
+ return self.encoder(mel)
+
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+ return self.decoder(tokens, audio_features)
+
+ def forward(
+ self, mel: torch.Tensor, tokens: torch.Tensor
+ ) -> Dict[str, torch.Tensor]:
+ return self.decoder(tokens, self.encoder(mel))
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ @property
+ def is_multilingual(self):
+ return self.dims.n_vocab >= 51865
+
+ @property
+ def num_languages(self):
+ return self.dims.n_vocab - 51765 - int(self.is_multilingual)
+
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+ """
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
+ tensors calculated for the previous positions. This method returns a dictionary that stores
+ all caches, and the necessary hooks for the key and value projection modules that save the
+ intermediate tensors to be reused during later calculations.
+
+ Returns
+ -------
+ cache : Dict[nn.Module, torch.Tensor]
+ A dictionary object mapping the key/value projection modules to its cache
+ hooks : List[RemovableHandle]
+ List of PyTorch RemovableHandle objects to stop the hooks to be called
+ """
+ cache = {**cache} if cache is not None else {}
+ hooks = []
+
+ def save_to_cache(module, _, output):
+ if module not in cache or output.shape[1] > self.dims.n_text_ctx:
+ # save as-is, for the first token or cross attention
+ cache[module] = output
+ else:
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
+ return cache[module]
+
+ def install_hooks(layer: nn.Module):
+ if isinstance(layer, MultiHeadAttention):
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+ self.decoder.apply(install_hooks)
+ return cache, hooks
+
+ detect_language = detect_language_function
+ transcribe = transcribe_function
+ decode = decode_function
\ No newline at end of file
--
Gitblit v1.9.1