zhifu gao
2024-03-30 702b9b540c3c1524748cd975a10ce33f0fa53912
sense voice (#1568)

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* train

* whisper_lib for sense voice

* aishell recipe

* sense voice
6个文件已修改
2个文件已添加
106652 ■■■■ 已修改文件
funasr/__init__.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/model.py 97 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/template.yaml 45 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/whisper_lib/assets/multilingual.tiktoken 106011 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/whisper_lib/decoding.py 33 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/whisper_lib/model.py 379 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/whisper_lib/tokenizer.py 79 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
setup.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/__init__.py
@@ -15,7 +15,12 @@
def import_submodules(package, recursive=True):
    if isinstance(package, str):
        package = importlib.import_module(package)
        try:
            package = importlib.import_module(package)
        except Exception as e:
            # 如果想要看到导入错误的具体信息,可以取消注释下面的行
            # print(f"Failed to import {name}: {e}")
            pass
    results = {}
    for loader, name, is_pkg in pkgutil.walk_packages(package.__path__, package.__name__ + '.'):
        try:
funasr/models/sense_voice/model.py
New file
@@ -0,0 +1,97 @@
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from . import whisper_lib as whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
@tables.register("model_classes", "SenseVoice")
class SenseVoice(nn.Module):
    """Whisper-architecture model exposed through FunASR's model registry.

    The constructor builds ``whisper.model.Whisper`` from the ``dims`` keyword
    argument; ``inference`` decodes a single utterance with ``whisper.decode``.
    """

    def __init__(self, *args, **kwargs):
        """Build the underlying Whisper model.

        Args:
            **kwargs: must contain ``dims``, a mapping with the keyword
                arguments of ``whisper.model.ModelDimensions``.
        """
        super().__init__()
        dims = whisper.model.ModelDimensions(**kwargs.get("dims", {}))
        self.model = whisper.model.Whisper(dims=dims)
        self.encoder_output_size = self.model.dims.n_audio_state

    def forward(self, ):
        """Training forward pass; intentionally left unimplemented here."""
        pass

    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        """Decode one utterance and return its transcription.

        Args:
            data_in: either a feature tensor (when ``data_type == "fbank"``)
                or anything ``load_audio_text_image_video`` accepts.
            data_lengths: optional lengths for pre-computed features.
            key: list of utterance ids; only the first entry is used.
            tokenizer: forwarded to ``load_audio_text_image_video``.
            frontend: feature extractor; when omitted a ``WhisperFrontend``
                is created once and cached on the instance.
            **kwargs: recognised keys are ``batch_size``, ``device``,
                ``data_type``, ``fs``, ``do_pad_trim``, ``language`` and
                ``initial_prompt``.

        Returns:
            ``(results, meta_data)`` where ``results`` is a one-element list
            of ``{"key": ..., "text": ...}`` and ``meta_data`` holds timing
            information collected while loading/extracting features.

        Raises:
            NotImplementedError: if ``batch_size`` is greater than 1.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        if frontend is None:
            if not hasattr(self, "frontend"):
                frontend_class = tables.frontend_classes.get("WhisperFrontend")
                self.frontend = frontend_class(
                    n_mels=self.model.dims.n_mels,
                    do_pad_trim=kwargs.get("do_pad_trim", True),
                )
            frontend = self.frontend

        # Fall back to the device the weights live on when the caller did not
        # say where to run (previously a missing "device" raised KeyError).
        device = kwargs.get("device")
        if device is None:
            device = next(self.parameters()).device

        data_type = kwargs.get("data_type", "sound")
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and data_type == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # Must be a tensor: it is moved to ``device`` below. The
                # original stored a bare int here, which broke ``.to()``.
                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int32)
            elif not isinstance(speech_lengths, torch.Tensor):
                speech_lengths = torch.as_tensor(speech_lengths)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=data_type,
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=data_type,
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000

        # Only the first item of the batch is decoded (batch_size > 1 is rejected above).
        speech = speech.to(device=device)[0, :, :]
        speech_lengths = speech_lengths.to(device=device)

        # decode the audio
        language = kwargs.get("language", None)
        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
        options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
        result = whisper.decode(self.model, speech, options)

        # ``key`` defaults to None; do not crash on the lookup in that case.
        utt_key = key[0] if key else None
        results = [{"key": utt_key, "text": result.text}]
        return results, meta_data
funasr/models/sense_voice/template.yaml
New file
@@ -0,0 +1,45 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: SenseVoice
model_conf:
    lsm_weight: 0.1
    length_normalized_loss: true
    hub: funasr
# dims is only used when hub == funasr;
#  if hub == openai, dims are downloaded automatically
dims:
    n_mels: 128
    n_vocab: 51866
    n_audio_ctx: 1500
    n_audio_state: 1280
    n_audio_head: 20
    n_audio_layer: 32
    n_text_ctx: 448
    n_text_state: 1280
    n_text_head: 20
    n_text_layer: 32
# frontend related
frontend: WhisperFrontend
frontend_conf:
    fs: 16000
    n_mels: ${dims.n_mels}
    do_pad_trim: true
tokenizer: WhisperTokenizer
tokenizer_conf:
  language: null
  task: transcribe
  is_multilingual: true
  num_languages: 100
scope_map: [none, "model."]
funasr/models/sense_voice/whisper_lib/assets/multilingual.tiktoken
Diff too large
funasr/models/sense_voice/whisper_lib/decoding.py
@@ -17,7 +17,7 @@
@torch.no_grad()
def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None, initial_prompt = None, x = None,
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return them as list of strings, along with the ids
@@ -48,12 +48,16 @@
        mel = mel.unsqueeze(0)
    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
    # FIX(funasr): sense vocie
    if mel.shape[-1] != model.dims.n_audio_state:
        mel = model.encoder(mel)
    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    # FIX(funasr): sense vocie
    # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    if x is None:
        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]
    # collect detected languages; suppress all non-language tokens
@@ -112,6 +116,9 @@
    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation
    # FIX(funasr): sense vocie
    initial_prompt: str = None
@dataclass(frozen=True)
@@ -609,6 +616,12 @@
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )
        #FIX(gzf): sense vocie
        if initial_prompt := self.options.initial_prompt:
            tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
            if self.options.language is None:
                tokens += [0]
        return tuple(tokens)
@@ -669,11 +682,21 @@
        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer
                audio_features, self.tokenizer, x=tokens
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            # FIX(funasr): sense vocie
            # if self.options.language is None:
                # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
                # tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
                languages = "".join([f"<|{language}|>" for language in languages])
                n_audio = audio_features.shape[0]
                lang_tokens = torch.tensor([self.tokenizer.encode(languages, allowed_special="all")] * n_audio).to(
                    audio_features.device)  # [n_audio, 1]
                tokens[:, -1:] = lang_tokens[:, :]
                languages = [languages]
        return languages, lang_probs
funasr/models/sense_voice/whisper_lib/model.py
@@ -1,314 +1,97 @@
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
from typing import Dict
from typing import Iterable, Optional
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch import Tensor
from torch import nn
from . import whisper_lib as whisper
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function
from funasr.register import tables
@dataclass
class ModelDimensions:
    """Size hyper-parameters consumed when constructing a ``Whisper`` model.

    The declaration order below is also the positional-argument order of the
    generated constructor, so it must not be rearranged.
    """

    # --- audio encoder sizes ---
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int

    # --- text decoder sizes ---
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
class LayerNorm(nn.LayerNorm):
    """LayerNorm that always normalises in float32 and returns the input's dtype."""

    def forward(self, x: Tensor) -> Tensor:
        """Normalise ``x`` in float32, then cast the result back to ``x.dtype``."""
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
class Linear(nn.Linear):
    """Linear layer whose parameters are cast to the input's dtype on each call."""

    def forward(self, x: Tensor) -> Tensor:
        """Apply the affine map using weight (and bias, if any) cast to ``x.dtype``."""
        weight = self.weight.to(x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return F.linear(x, weight, bias)
class Conv1d(nn.Conv1d):
    """1-D convolution whose parameters are cast to the input's dtype on each call."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        """Run the parent convolution with ``weight``/``bias`` cast to ``x.dtype``."""
        target_dtype = x.dtype
        cast_bias = bias.to(target_dtype) if bias is not None else None
        return super()._conv_forward(x, weight.to(target_dtype), cast_bias)
def sinusoids(length, channels, max_timescale=10000):
    """Build a ``(length, channels)`` table of sinusoidal positional embeddings.

    The first ``channels // 2`` columns hold sines and the remaining columns
    hold cosines of ``position * inverse_timescale``. ``channels`` must be even.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_step = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_step * torch.arange(half))
    positions = torch.arange(length).unsqueeze(1)
    angles = positions * inv_timescales.unsqueeze(0)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
@tables.register("model_classes", "SenseVoice")
class SenseVoice(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)
        hub = kwargs.get("hub", "funasr")
    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)
        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)
        self.model = model
        self.encoder_output_size = self.model.dims.n_audio_state
    def forward(self, ):
        pass
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
            self.frontend = frontend
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
            frontend = frontend if frontend is not None else self.frontend
        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
            lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()
        language = kwargs.get("language", None)
        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
        # # detect the spoken language
        # _, probs = self.model.detect_language(speech, initial_prompt=initial_prompt)
        # print(f"Detected language: {max(probs, key=probs.get)}")
        # language = max(probs, key=probs.get)
        # language = language if kwargs.get("language", None) is None else kwargs.get("language")
        # decode the audio
        prompt = ""
        initial_prompt = kwargs.get("initial_prompt", "<|startoftranscript|><|ASR|>")
        options = whisper.DecodingOptions(language=language, fp16=False, without_timestamps=True, initial_prompt=initial_prompt)
        result = whisper.decode(self.model, speech, options)
        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
        results = []
        result_i = {"key": key[0], "text": result.text}
class ResidualAttentionBlock(nn.Module):
    """Residual block: self-attention, optional cross-attention, then an MLP.

    Each sub-layer is applied to a layer-normalised copy of the stream and its
    output is added back to the stream.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        """Create the sub-layers; cross-attention exists only if ``cross_attention``."""
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, hidden),
            nn.GELU(),
            Linear(hidden, n_state),
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Run the block on ``x``; ``xa`` is what the cross-attention attends to."""
        # The attention modules return a tuple; only its first item is added back.
        self_out = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        x = x + self_out
        if self.cross_attn:
            cross_out = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
            x = x + cross_out
        mlp_out = self.mlp(self.mlp_ln(x))
        return x + mlp_out
class AudioEncoder(nn.Module):
    """Two GELU-activated 1-D convolutions followed by a stack of attention blocks."""

    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """Create the convolutions, the fixed sinusoidal positions and ``n_layer`` blocks."""
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        layers = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(layers)
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_frames)
            the mel spectrogram of the audio; after both convolutions each item
            must have exactly the shape of the positional embedding
            (n_ctx, n_state), otherwise the assertion below fails
        """
        for conv in (self.conv1, self.conv2):
            x = F.gelu(conv(x))
        # (batch, channels, time) -> (batch, time, channels)
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        for layer in self.blocks:
            x = layer(x)

        return self.ln_post(x)
class TextDecoder(nn.Module):
    """Token decoder: embeddings, cross-attending blocks and tied output projection."""

    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """Create embeddings, ``n_layer`` cross-attention blocks and the causal mask."""
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        layers = [
            ResidualAttentionBlock(n_state, n_head, cross_attention=True)
            for _ in range(n_layer)
        ]
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(layers)
        self.ln = LayerNorm(n_state)

        # -inf strictly above the diagonal, 0 elsewhere; not saved in the state dict.
        causal_mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", causal_mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        # With a non-empty cache, positions continue after the cached length.
        if kv_cache:
            offset = next(iter(kv_cache.values())).shape[1]
        else:
            offset = 0

        positions = self.positional_embedding[offset : offset + x.shape[-1]]
        x = (self.token_embedding(x) + positions).to(xa.dtype)

        for layer in self.blocks:
            x = layer(x, xa, mask=self.mask, kv_cache=kv_cache)
        x = self.ln(x)

        # Output projection reuses the token embedding matrix (transposed).
        projection = torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        return (x @ projection).float()
class Whisper(nn.Module):
    """Encoder-decoder model assembled from ``AudioEncoder`` and ``TextDecoder``.

    Sizes come from a ``ModelDimensions`` instance. The module-level functions
    ``detect_language_function``, ``transcribe_function`` and ``decode_function``
    are bound as methods at the bottom of the class.
    """

    def __init__(self, dims: ModelDimensions):
        """Build encoder and decoder from ``dims`` and set default alignment heads."""
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        # Stored sparse and non-persistent, i.e. not written to the state dict.
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        """Replace ``alignment_heads`` with a mask decoded from ``dump``.

        ``dump`` is base85-decoded, gzip-decompressed and read as a boolean
        array that is reshaped to ``(n_text_layer, n_text_head)``.
        """
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        """Run only the audio encoder on ``mel``."""
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        """Run only the decoder on ``tokens`` given already-encoded ``audio_features``."""
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Encode ``mel`` and decode ``tokens`` against the result.

        NOTE(review): the return annotation says ``Dict``, but the value
        returned is whatever ``self.decoder`` returns (a logits tensor in
        ``TextDecoder.forward``) — confirm and fix the annotation.
        """
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        """Device of the first parameter of the module."""
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        """True when the vocabulary has at least 51865 entries."""
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        """Language count derived from the vocabulary size.

        Computed as ``n_vocab - 51765``, minus one more when
        ``is_multilingual`` is true.
        """
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Parameters
        ----------
        cache : Optional[dict]
            An existing cache to start from; it is shallow-copied, so the dictionary passed in
            is not mutated. When omitted, an empty cache is created.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        # Forward hook: its return value is what the hooked projection hands back
        # to the attention module in place of its own output.
        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                # later steps: append the new positions along dim 1
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        # Visitor for ``Module.apply``: hooks the key and value projections of
        # every ``MultiHeadAttention`` found inside the decoder.
        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    # Bind the module-level implementations as methods of the model.
    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
        results.append(result_i)
        return results, meta_data
funasr/models/sense_voice/whisper_lib/tokenizer.py
@@ -7,6 +7,7 @@
import tiktoken
# FIX(funasr): sense voice
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
@@ -108,6 +109,11 @@
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}
# language code lookup by name, with a few language aliases
@@ -125,6 +131,28 @@
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}
# FIX(funasr): sense voice
# Task and audio-event tag names; every name maps to itself. The order of the
# keys matters because they are turned into "<|name|>" special tokens in order.
AUDIO_EVENT = {
    tag: tag
    for tag in (
        "ASR",
        "AED",
        "SER",
        "Speech",
        "/Speech",
        "BGM",
        "/BGM",
        "Laughter",
        "/Laughter",
        "Applause",
        "/Applause",
    )
}
# Emotion tag names, likewise mapped to themselves in a fixed order.
EMOTION = {
    tag: tag
    for tag in (
        "HAPPY",
        "SAD",
        "ANGRY",
        "NEUTRAL",
    )
}
@@ -171,6 +199,9 @@
        This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
        """
        return self.encoding.decode(token_ids, **kwargs)
    def get_vocab_size(self) -> int:
        return self.encoding.n_vocab
    @cached_property
    def eot(self) -> int:
@@ -186,6 +217,10 @@
    @cached_property
    def sot(self) -> int:
        return self.special_tokens["<|startoftranscript|>"]
    @cached_property
    def sot_sense(self) -> int:
        return self.special_tokens["<|startoftranscript|>"]
    @cached_property
@@ -337,18 +372,35 @@
    n_vocab = len(ranks)
    special_tokens = {}
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]
    if False: #name == "gpt2" or name == "multilingual":
        specials = [
            "<|endoftext|>",
            "<|startoftranscript|>",
            *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
            "<|translate|>",
            "<|transcribe|>",
            "<|startoflm|>",
            "<|startofprev|>",
            "<|nospeech|>",
            "<|notimestamps|>",
            *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
        ]
    else:
        specials = [
            "<|endoftext|>",
            "<|startoftranscript|>",
            *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
            *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
            *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
            "<|translate|>",
            "<|transcribe|>",
            "<|startoflm|>",
            "<|startofprev|>",
            "<|nospeech|>",
            "<|notimestamps|>",
            *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 51)],
            *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
        ]
    for token in specials:
        special_tokens[token] = n_vocab
@@ -370,6 +422,7 @@
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    encoding_path: Optional[str] = None,
) -> Tokenizer:
    if language is not None:
        language = language.lower()
@@ -387,6 +440,8 @@
        encoding_name = "gpt2"
        language = None
        task = None
    if encoding_path is not None:
        encoding_name = encoding_path
    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
setup.py
@@ -40,6 +40,7 @@
        "hydra-core>=1.3.2",
        "tensorboardX",
        "rotary_embedding_torch",
        "openai-whisper",
    ],
    # train: The modules invoked when training only.
    "train": [