zhifu gao
2024-03-11 f2d8ded57f6403696001d39dd07a1396e5a03658
export onnx (#1455)

* qwenaudio qwenaudiochat

* qwenaudio qwenaudiochat

* whisper

* whisper

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* export onnx
13个文件已修改
1个文件已删除
12个文件已添加
1981 ■■■■■ 已修改文件
examples/industrial_data_pretraining/fsmn_vad_streaming/export.py 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear.yaml 94 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/export.py 13 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/export.sh 18 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/auto/auto_model.py 54 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/export.py 40 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/llm_datasets/preprocessor.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/llm_datasets_qwenaudio/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/llm_datasets_qwenaudio/datasets.py 150 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/llm_datasets_vicuna/preprocessor.py 37 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad_streaming/encoder.py 52 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/fsmn_vad_streaming/model.py 40 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/llm_asr_nar/model.py 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/cif_predictor.py 116 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/decoder.py 171 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer/model.py 77 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/qwen_audio/audio.py 446 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/qwen_audio/model.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/attention.py 58 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/encoder.py 132 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sanm/multihead_att.py 243 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/whisper/template.yaml 31 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/export_utils.py 74 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/misc.py 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/torch_function.py 80 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
New file
@@ -0,0 +1,12 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from funasr import AutoModel
# Sample recording handed to the exporter as its example input.
sample_wav_url = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"

# Build the FSMN-VAD model from the model hub with export mode requested.
vad_exporter = AutoModel(
    model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    model_revision="v2.0.4",
    export_model=True,
)

# Run the ONNX export (no quantization) and show what export() returned.
export_result = vad_exporter.export(input=sample_wav_url, type="onnx", quantize=False)
print(export_result)
examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
New file
@@ -0,0 +1,19 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# Export a model from the model hub to ONNX via funasr/bin/export.py

# Ask Hydra to print complete stack traces on failure.
export HYDRA_FULL_ERROR=1

model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
model_revision="v2.0.4"

# Variable expansions are quoted so each override is always passed as a
# single argument, even if a value ever contains spaces or glob characters.
python funasr/bin/export.py \
"++model=${model}" \
"++model_revision=${model_revision}" \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu" \
++debug=false
examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear.yaml
New file
@@ -0,0 +1,94 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: LLMASR
model_conf:
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: true
# encoder
audio_encoder: iic/Whisper-large-v2 #iic/Whisper-large-v3
audio_encoder_conf:
    hub: ms
    freeze: true
    init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
llm: Vicuna
llm_conf:
  hub: hf
  freeze: true
  init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
audio_adaptor: Linear
audio_adaptor_conf:
  downsample_rate: 5
  llm_dim: 4096
  encoder_dim: 512
# frontend related
frontend: WhisperFrontend
frontend_conf:
    fs: 16000
    whisper_model: large-v2
    do_pad_trim: true
    permute: true # true: [bs, frames, dims]; false: [bs, dims, frames]
specaug: SpecAugLFR
specaug_conf:
    apply_time_warp: false
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    lfr_rate: 6
    num_freq_mask: 1
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 12
    num_time_mask: 1
train_conf:
  accum_grad: 1
  grad_clip: 5
  max_epoch: 15
  keep_nbest_models: 10
  log_interval: 10
optim: adamw
optim_conf:
   lr: 0.0001
   weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 1000
dataset: AudioLLMQwenAudioDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: CustomDistributedBatchSampler
    batch_type: example # example or length
    batch_size: 4 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
    max_token_length: 3000 # filter out samples whose source_token_len+target_token_len > max_token_length
    shuffle: True
    num_workers: 4
    preprocessor_text: TextPreprocessRemovePunctuation
    audio_adaptor_downsample_rate: ${audio_adaptor_conf.downsample_rate}
    audio_encoder_downsample_rate: 2
    prompt: "<|startoftranscription|><|zh|><|transcribe|><|zh|><|notimestamps|><|wo_itn|>"
tokenizer: HuggingfaceTokenizer
tokenizer_conf:
  unk_symbol: <unk>
  init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
examples/industrial_data_pretraining/paraformer/export.py
New file
@@ -0,0 +1,13 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from funasr import AutoModel
# Sample recording handed to the exporter as its example input.
sample_wav_url = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"

# Build the Paraformer model from the model hub with export mode requested.
asr_exporter = AutoModel(
    model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    model_revision="v2.0.4",
    export_model=True,
)

# Run the ONNX export (no quantization) and show what export() returned.
export_result = asr_exporter.export(input=sample_wav_url, type="onnx", quantize=False)
print(export_result)
examples/industrial_data_pretraining/paraformer/export.sh
New file
@@ -0,0 +1,18 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
# Export a model from the model hub to ONNX via funasr/bin/export.py

# Ask Hydra to print complete stack traces on failure.
export HYDRA_FULL_ERROR=1

model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_revision="v2.0.4"

# Variable expansions are quoted so each override is always passed as a
# single argument, even if a value ever contains spaces or glob characters.
python funasr/bin/export.py \
"++model=${model}" \
"++model_revision=${model_revision}" \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
++type="onnx" \
++quantize=false \
++device="cpu" \
++debug=false
funasr/auto/auto_model.py
@@ -14,6 +14,7 @@
import numpy as np
from tqdm import tqdm
from funasr.utils.misc import deep_update
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
@@ -23,12 +24,13 @@
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.utils import export_utils
try:
    from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
    from funasr.models.campplus.cluster_backend import ClusterBackend
except:
    print("If you want to use the speaker diarization, please `pip install hdbscan`")
import pdb
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """
@@ -41,7 +43,7 @@
    """
    data_list = []
    key_list = []
    filelist = [".scp", ".txt", ".json", ".jsonl"]
    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
    
    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str) and data_in.startswith('http'): # url
@@ -98,7 +100,9 @@
    def __init__(self, **kwargs):
        if not kwargs.get("disable_log", True):
            tables.print()
        if kwargs.get("export_model", False):
            os.environ['EXPORTING_MODEL'] = 'TRUE'
        model, kwargs = self.build_model(**kwargs)
        
        # if vad_model is not None, build vad model else None
@@ -203,7 +207,7 @@
    
    def __call__(self, *args, **cfg):
        kwargs = self.kwargs
        kwargs.update(cfg)
        deep_update(kwargs, cfg)
        res = self.model(*args, kwargs)
        return res
@@ -216,7 +220,7 @@
        
    def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)
        deep_update(kwargs, cfg)
        model = self.model if model is None else model
        model.eval()
@@ -279,7 +283,7 @@
    def inference_with_vad(self, input, input_len=None, **cfg):
        kwargs = self.kwargs
        # step.1: compute the vad model
        self.vad_kwargs.update(cfg)
        deep_update(self.vad_kwargs, cfg)
        beg_vad = time.time()
        res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
        end_vad = time.time()
@@ -287,7 +291,7 @@
        # step.2 compute asr model
        model = self.model
        kwargs.update(cfg)
        deep_update(kwargs, cfg)
        batch_size = int(kwargs.get("batch_size_s", 300))*1000
        batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
        kwargs["batch_size"] = batch_size
@@ -399,7 +403,7 @@
                    if return_raw_text:
                        result['raw_text'] = ''
                else:
                    self.punc_kwargs.update(cfg)
                    deep_update(self.punc_kwargs, cfg)
                    punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
                    raw_text = copy.copy(result["text"])
                    if return_raw_text: result['raw_text'] = raw_text
@@ -467,3 +471,35 @@
        #                      f"time_escape_all: {time_escape_total_all_samples:0.3f}")
        return results_ret_list
    def export(self, input=None,
               type : str = "onnx",
               quantize: bool = False,
               fallback_num: int = 5,
               calib_num: int = 100,
               opset_version: int = 14,
               **cfg):
        os.environ['EXPORTING_MODEL'] = 'TRUE'
        kwargs = self.kwargs
        deep_update(kwargs, cfg)
        del kwargs["model"]
        model = self.model
        model.eval()
        batch_size = 1
        key_list, data_list = prepare_data_iterator(input, input_len=None, data_type=kwargs.get("data_type", None), key=None)
        with torch.no_grad():
            if type == "onnx":
                export_dir = export_utils.export_onnx(
                                        model=model,
                                        data_in=data_list,
                                        **kwargs)
            else:
                export_dir = export_utils.export_torchscripts(
                                        model=model,
                                        data_in=data_list,
                                        **kwargs)
        return export_dir
funasr/bin/export.py
New file
@@ -0,0 +1,40 @@
import os
import hydra
import logging
from omegaconf import DictConfig, OmegaConf, ListConfig
from funasr.auto.auto_model import AutoModel
@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    """Command-line entry point: build an AutoModel in export mode and export it.

    Args:
        cfg: Hydra/OmegaConf configuration assembled from the command-line
            overrides. It is converted to plain Python containers and passed
            to ``AutoModel`` as keyword arguments.
    """
    def to_plain_list(cfg_item):
        # Convert OmegaConf containers into plain Python lists/dicts.
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    kwargs = to_plain_list(cfg)
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb; pdb.set_trace()

    model = AutoModel(export_model=True, **kwargs)

    # The option is spelled "fallback_num" like the other underscore-style
    # options; the hyphenated spelling read previously is still honoured as a
    # fallback so existing invocations keep working.
    fallback_num = kwargs.get("fallback_num", kwargs.get("fallback-num", 5))

    res = model.export(input=kwargs.get("input", None),
                       type=kwargs.get("type", "onnx"),
                       quantize=kwargs.get("quantize", False),
                       fallback_num=fallback_num,
                       calib_num=kwargs.get("calib_num", 100),
                       opset_version=kwargs.get("opset_version", 14),
                       )
    print(res)
# Run the hydra-decorated entry point only when executed as a script.
if __name__ == '__main__':
    main_hydra()
funasr/datasets/llm_datasets/preprocessor.py
@@ -18,7 +18,7 @@
@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
class TextPreprocessSegDict(nn.Module):
class TextPreprocessRemovePunctuation(nn.Module):
    def __init__(self,
                 **kwargs):
        super().__init__()
funasr/datasets/llm_datasets_qwenaudio/__init__.py
funasr/datasets/llm_datasets_qwenaudio/datasets.py
New file
@@ -0,0 +1,150 @@
import torch
import copy
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
@tables.register("dataset_classes", "AudioLLMQwenAudioDataset")
class AudioLLMQwenAudioDataset(torch.utils.data.Dataset):
    """
    Dataset pairing audio features with prompt/target token sequences.

    Each sample's ``input_ids`` / ``labels_ids`` start with one ``-1``
    placeholder per (downsampled) audio frame, followed by the tokenized
    prompt and target text.
    """

    def __init__(self,
                 path,
                 index_ds: str = None,
                 frontend=None,
                 tokenizer=None,
                 int_pad_value: int = -1,
                 float_pad_value: float = 0.0,
                 **kwargs):
        super().__init__()

        # Index over the data at `path`, looked up by name in the register tables.
        self.index_ds = tables.index_ds_classes.get(index_ds)(path, **kwargs)

        # Optional preprocessors, also looked up by name.
        self.preprocessor_speech = self._make_preprocessor(kwargs, "preprocessor_speech")
        self.preprocessor_text = self._make_preprocessor(kwargs, "preprocessor_text")

        self.frontend = frontend
        self.tokenizer = tokenizer
        self.fs = frontend.fs if frontend is not None else 16000
        self.data_type = "sound"

        # Prompt handling.
        self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
        self.prompt_af = ""
        self.prompt_template = "{}"
        self.answer_template = "{}"

        # Padding: the `int_pad_value` argument is not used here; integer
        # tensors are always padded with IGNORE_INDEX.
        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
        self.int_pad_value = self.IGNORE_INDEX
        self.float_pad_value = float_pad_value

        # Downsampling factors used to size the audio placeholder span.
        self.audio_adaptor_downsample_rate = kwargs.get("audio_adaptor_downsample_rate", 5)
        self.audio_encoder_downsample_rate = kwargs.get("audio_encoder_downsample_rate", 2)

    @staticmethod
    def _make_preprocessor(options, key):
        """Instantiate the preprocessor named by ``options[key]``.

        A missing or falsy entry is returned unchanged; otherwise the class is
        fetched from the register tables and built with ``options[key + "_conf"]``.
        """
        spec = options.get(key, None)
        if not spec:
            return spec
        factory = tables.preprocessor_classes.get(spec)
        return factory(**options.get(key + "_conf", {}))

    def get_source_len(self, index):
        """Return the source length the index reports for sample `index`."""
        return self.index_ds.get_source_len(self.index_ds[index])

    def get_target_len(self, index):
        """Return the target length the index reports for sample `index`."""
        return self.index_ds.get_target_len(self.index_ds[index])

    def __len__(self):
        return len(self.index_ds)

    def __getitem__(self, index):
        """Build the feature/token tensors for one sample and return them as a dict."""
        record = self.index_ds[index]

        # ---- audio ----
        waveform = load_audio_text_image_video(record["source"], fs=self.fs)
        if self.preprocessor_speech:
            waveform = self.preprocessor_speech(waveform, fs=self.fs)
        speech, speech_lengths = extract_fbank(waveform, data_type=self.data_type, frontend=self.frontend,
                                               is_final=True)  # speech: [b, T, d]
        speech = speech.squeeze(0)

        # One -1 placeholder per audio position left after both downsampling steps.
        audio_pseudo_length = (speech.shape[0] + 1) // self.audio_adaptor_downsample_rate // self.audio_encoder_downsample_rate
        audio_pseudo = torch.full((audio_pseudo_length,), -1)

        # ---- text ----
        target = record["target"]
        if self.preprocessor_text:
            target = self.preprocessor_text(target)

        self.prompt_pre = self.prompt_template.format(self.prompt)
        prompt_pre_length = len(self.tokenizer.encode(self.prompt_pre))  # [bos, prompt]

        # ---- model input: [audio, bos, prompt, input, pad] ----
        input_text = self.answer_template.format(target.lower())
        input_token_ids = self.tokenizer.encode("{}{}".format(self.prompt_pre, input_text))
        input_token_ids = input_token_ids + [self.tokenizer.pad_token_id]
        input_ids_length = len(input_token_ids)
        input_ids = torch.cat((audio_pseudo, torch.tensor(input_token_ids, dtype=torch.int64)))
        attention_mask = input_ids.ge(-1)  # true wherever the id is >= -1

        # ---- labels: [audio, bos, prompt, answer, eos] ----
        answer_text = self.answer_template.format(target.lower())
        answer_token_ids = self.tokenizer.encode("{}{}".format(self.prompt_pre, answer_text))
        label_token_ids = copy.deepcopy(answer_token_ids) + [self.tokenizer.eos_token_id]
        labels_ids = torch.cat((audio_pseudo, torch.tensor(label_token_ids, dtype=torch.int64)))
        # Audio and prompt positions are excluded from the labels: first marked
        # -1, then rewritten to IGNORE_INDEX through the mask.
        labels_ids[:audio_pseudo_length + prompt_pre_length] = -1
        label_mask = labels_ids.ge(0)
        labels_ids[~label_mask] = self.IGNORE_INDEX

        # 1 on audio placeholder positions of input_ids, 0 on text positions.
        audio_mask = torch.tensor([1] * audio_pseudo_length + [0] * input_ids_length, dtype=torch.float32)

        # Plain tokenization of the target (differs from labels_ids).
        target_ids = self.tokenizer.encode(target)
        text = torch.tensor(target_ids, dtype=torch.int64)
        text_lengths = torch.tensor([len(target_ids)], dtype=torch.int32)

        return {"speech": speech,
                "speech_lengths": speech_lengths,
                "text": text,
                "text_lengths": text_lengths,
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels_ids": labels_ids,
                "label_mask": label_mask,
                "audio_mask": audio_mask,
                }

    def collator(self, samples: list = None):
        """Group per-sample fields by key and pad tensor fields into batches.

        int64 tensors are padded with ``self.int_pad_value``; every other
        tensor dtype is padded with ``self.float_pad_value``. Non-tensor
        fields stay as lists.
        """
        outputs = {}
        for sample in samples:
            for key, value in sample.items():
                outputs.setdefault(key, []).append(value)

        for key, data_list in outputs.items():
            first = data_list[0]
            if not isinstance(first, torch.Tensor):
                continue
            pad_value = self.int_pad_value if first.dtype == torch.int64 else self.float_pad_value
            outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
        return outputs
funasr/datasets/llm_datasets_vicuna/preprocessor.py
File was deleted
funasr/models/fsmn_vad_streaming/encoder.py
@@ -1,5 +1,6 @@
from typing import Tuple, Dict
import copy
import os
import numpy as np
import torch
@@ -134,6 +135,25 @@
        x3 = self.affine(x2)
        x4 = self.relu(x3)
        return x4
class BasicBlock_export(nn.Module):
    """Export variant of a basic block: the cache is an explicit tensor
    argument and the updated cache is returned alongside the output."""

    def __init__(self,
                 model,
                 ):
        super().__init__()
        # Reuse the wrapped block's sub-modules (same registration order).
        for attr in ("linear", "fsmn_block", "affine", "relu"):
            setattr(self, attr, getattr(model, attr))

    def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
        """Apply linear -> fsmn_block (with cache) -> affine -> relu.

        Returns:
            (output, out_cache) where out_cache is what fsmn_block returned.
        """
        projected = self.linear(input)  # presumably (B, T, D)
        memory_out, out_cache = self.fsmn_block(projected, in_cache)
        activated = self.relu(self.affine(memory_out))
        return activated, out_cache
class FsmnStack(nn.Sequential):
@@ -192,6 +212,12 @@
        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        self.softmax = nn.Softmax(dim=-1)
        # export onnx or torchscripts
        if "EXPORTING_MODEL" in os.environ and os.environ['EXPORTING_MODEL'] == 'TRUE':
            for i, d in enumerate(self.fsmn):
                if isinstance(d, BasicBlock):
                    self.fsmn[i] = BasicBlock_export(d)
    def fuse_modules(self):
        pass
@@ -218,6 +244,32 @@
        return x7
    def export_forward(
            self,
            input: torch.Tensor,
            *args,
    ):
        """
        Args:
            input (torch.Tensor): Input tensor (B, T, D)
            in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
            {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
        """
        x = self.in_linear1(input)
        x = self.in_linear2(x)
        x = self.relu(x)
        # x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
        out_caches = list()
        for i, d in enumerate(self.fsmn):
            in_cache = args[i]
            x, out_cache = d(x, in_cache)
            out_caches.append(out_cache)
        x = self.out_linear1(x)
        x = self.out_linear2(x)
        x = self.softmax(x)
        return x, out_caches
'''
one deep fsmn layer
funasr/models/fsmn_vad_streaming/model.py
@@ -284,6 +284,7 @@
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        self.encoder = encoder
        self.encoder_conf = encoder_conf
    
    def ResetDetection(self, cache: dict = {}):
        cache["stats"].continous_silence_frame_count = 0
@@ -642,6 +643,45 @@
        
        return results, meta_data
    
    def export(self, **kwargs):
        self.forward = self._export_forward
        return self
    def _export_forward(self, feats: torch.Tensor, *args, **kwargs):
        scores, out_caches = self.encoder.export_forward(feats, *args)
        return scores, out_caches
    def export_dummy_inputs(self, data_in=None, frame=30):
        if data_in is None:
            speech = torch.randn(1, frame, self.encoder_conf.get("input_dim"))
        else:
            speech = None # Undo
        cache_frames = self.encoder_conf.get("lorder") + self.encoder_conf.get("rorder") - 1
        in_cache0 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
        in_cache1 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
        in_cache2 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
        in_cache3 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
        return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
    def export_input_names(self):
        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
    def export_output_names(self):
        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
    def export_dynamic_axes(self):
        return {
            'speech': {
                1: 'feats_length'
            },
        }
    def export_name(self, ):
        return "model.onnx"
    def DetectCommonFrames(self, cache: dict = {}) -> int:
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
funasr/models/llm_asr_nar/model.py
@@ -218,6 +218,9 @@
    
        audio_mask = kwargs.get("audio_mask", None)
        audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
        text_token_int = kwargs.get("text_token_int", None)
        if audio_token_lengths is None:
            audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        enc, enc_lens = self.audio_encoder.encode(**batch)
@@ -259,7 +262,13 @@
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
                                                            tokenizer=None)
            if len(kwargs.get("data_type")) > 1:
                audio_sample_list, text_token_int_list = audio_sample_list
                text_token_int = text_token_int_list[0].replace(" ", "")
                text_token_int = tokenizer.encode(text_token_int)
            else:
                text_token_int = None
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
@@ -272,7 +281,7 @@
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text_token_int=text_token_int)
        # adaptor
        encoder_out = self.adaptor(encoder_out)
funasr/models/paraformer/cif_predictor.py
@@ -155,7 +155,7 @@
                 tf2torch_tensor_name_prefix_tf="seq2seq/cif",
                 tail_mask=True,
                 ):
        super(CifPredictorV2, self).__init__()
        super().__init__()
        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
@@ -375,6 +375,120 @@
        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
        return predictor_alignments.detach(), predictor_alignments_length.detach()
@tables.register("predictor_classes", "CifPredictorV2Export")
class CifPredictorV2(torch.nn.Module):
    """Export wrapper that reuses the layers and thresholds of an existing
    predictor and runs the CIF step through ``cif_export``.

    NOTE(review): the class is registered as "CifPredictorV2Export" but its
    Python name is CifPredictorV2, which appears to be the same name as the
    predictor class it wraps; if so this definition rebinds that module-level
    name - confirm whether it should be called CifPredictorV2Export.
    """

    def __init__(self, model, **kwargs):
        super().__init__()
        # Borrow sub-modules and scalar settings from the wrapped predictor.
        for attr in ("pad", "cif_conv1d", "cif_output",
                     "threshold", "smooth_factor", "noise_threshold", "tail_threshold"):
            setattr(self, attr, getattr(model, attr))

    def forward(self, hidden: torch.Tensor,
                mask: torch.Tensor,
                ):
        """Compute acoustic embeddings from encoder output.

        Returns:
            (acoustic_embeds, token_num, alphas, cif_peak), where token_num is
            the floored alpha sum from ``tail_process_fn`` and the first/last
            items come from ``cif_export``.
        """
        alphas, token_num = self.forward_cnn(hidden, mask)

        frame_mask = mask.transpose(-1, -2).float()
        frame_mask = frame_mask.squeeze(-1)

        hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=frame_mask)
        acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
        return acoustic_embeds, token_num, alphas, cif_peak

    def forward_cnn(self, hidden: torch.Tensor,
                    mask: torch.Tensor,
                    ):
        """Predict per-frame weights (alphas) and their per-sequence sum."""
        conv_in = self.pad(hidden.transpose(1, 2))
        conv_out = torch.relu(self.cif_conv1d(conv_in))
        logits = self.cif_output(conv_out.transpose(1, 2))

        alphas = torch.sigmoid(logits)
        # Scale and shift, clamping negatives to zero.
        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)

        # Zero the weights on masked-out frames.
        mask = mask.transpose(-1, -2).float()
        alphas = alphas * mask
        alphas = alphas.squeeze(-1)

        token_num = alphas.sum(-1)
        return alphas, token_num

    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
        """Append one extra frame and add the tail threshold to alphas.

        ``mask_2 - mask_1`` is non-zero only where the mask, shifted by one
        position, differs from itself, so ``tail_threshold`` is added at that
        position. ``hidden`` gets a matching all-zero frame appended. The
        incoming ``token_num`` is ignored and recomputed.

        Returns:
            (hidden, alphas, floor(sum of alphas)).
        """
        b, _, d = hidden.size()

        zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
        ones_t = torch.ones_like(zeros_t)

        mask_1 = torch.cat([mask, zeros_t], dim=1)
        mask_2 = torch.cat([ones_t, mask], dim=1)
        tail_position = mask_2 - mask_1
        tail_threshold = tail_position * self.tail_threshold

        alphas = torch.cat([alphas, zeros_t], dim=1)
        alphas = torch.add(alphas, tail_threshold)

        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)

        token_num_floor = torch.floor(alphas.sum(dim=-1))
        return hidden, alphas, token_num_floor
@torch.jit.script
def cif_export(hidden, alphas, threshold: float):
    """Integrate-and-fire over time, compiled with TorchScript.

    Walks the time axis accumulating ``alphas``; whenever the running sum
    reaches ``threshold`` a frame "fires" and the weighted sum of ``hidden``
    collected so far is emitted.

    Args:
        hidden: tensor unpacked as (batch, time, hidden_size).
        alphas: per-frame weights, indexed as alphas[:, t].
        threshold: firing threshold.

    Returns:
        (frame_fires, fires): ``frame_fires`` holds, per batch element, the
        fired frames left-aligned and zero-padded, truncated to the largest
        number of fired frames in the batch; ``fires`` is the running sum at
        every time step (recorded before the post-fire subtraction), stacked
        along dim 1.
    """
    batch_size, len_time, hidden_size = hidden.size()
    # Turn the float threshold into a tensor matching alphas' dtype/device.
    threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
    # loop vars
    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []
    for t in range(len_time):
        alpha = alphas[:, t]
        # Weight still missing to bring the accumulator up to 1 (taken before
        # this step's alpha is added).
        distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
        integrate += alpha
        list_fires.append(integrate)
        fire_place = integrate >= threshold
        # Where a fire happened, carry over the remainder (sum minus 1);
        # torch.where returns a new tensor, so the value appended above is kept.
        integrate = torch.where(fire_place,
                                integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
                                integrate)
        # Share of this frame's weight that goes to the frame being completed.
        cur = torch.where(fire_place,
                          distribution_completion,
                          alpha)
        # Leftover weight of this frame, which starts the next output frame.
        remainds = alpha - cur
        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        # After a fire, restart the accumulated frame from the leftover part.
        frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
                            remainds[:, None] * hidden[:, t, :],
                            frame)
    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    fire_idxs = fires >= threshold
    frame_fires = torch.zeros_like(hidden)
    # Track the largest number of fired frames across the batch.
    max_label_len = frames[0, fire_idxs[0]].size(0)
    for b in range(batch_size):
        # Keep only the time steps that fired and left-align them.
        frame_fire = frames[b, fire_idxs[b]]
        frame_len = frame_fire.size(0)
        frame_fires[b, :frame_len, :] = frame_fire
        if frame_len >= max_label_len:
            max_label_len = frame_len
    # Drop the all-zero tail beyond the longest fired sequence.
    frame_fires = frame_fires[:, :max_label_len, :]
    return frame_fires, fires
class mae_loss(torch.nn.Module):
funasr/models/paraformer/decoder.py
@@ -581,6 +581,177 @@
        return y, new_cache
class DecoderLayerSANMExport(torch.nn.Module):
    """Export-time view of one SANM decoder layer.

    The sub-modules of ``model`` are reused as they are; only the forward
    pass is rewritten so that it always returns the same 5-tuple.
    """

    def __init__(
        self,
        model
    ):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        # norm2 / norm3 are optional on the wrapped layer.
        self.norm2 = getattr(model, 'norm2', None)
        self.norm3 = getattr(model, 'norm3', None)
        self.size = model.size

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Run feed-forward, then optional self-attention and cross-attention.

        Returns ``(x, tgt_mask, memory, memory_mask, cache)`` so that layers
        can be chained by a container that forwards all outputs.
        """
        shortcut = tgt
        tgt = self.feed_forward(self.norm1(tgt))
        x = tgt

        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            attended, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            # The skip connection bypasses the feed-forward block as well.
            x = shortcut + attended

        if self.src_attn is not None:
            shortcut = x
            x = shortcut + self.src_attn(self.norm3(x), memory, memory_mask)

        return x, tgt_mask, memory, memory_mask, cache
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
class ParaformerSANMDecoderExport(torch.nn.Module):
    """Export wrapper around a Paraformer SANM decoder.

    Every layer of ``model.decoders``, ``model.decoders2`` (optional) and
    ``model.decoders3`` is replaced in place by ``DecoderLayerSANMExport``,
    and the SANM / cross attention modules are swapped for their export
    variants. ``model`` itself is therefore modified by the constructor.

    Args:
        model: decoder to wrap; must expose ``decoders``, ``decoders2``
            (may be ``None``), ``decoders3``, ``output_layer`` and
            ``after_norm``.
        max_seq_len: forwarded to the padding-mask builder.
        model_name: stored on the wrapper as ``model_name``.
        onnx: selects ``MakePadMask`` when true, ``sequence_mask`` otherwise.
    """

    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True, ):
        super().__init__()
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask

        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        from funasr.models.sanm.multihead_att import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.multihead_att import MultiHeadedAttentionCrossAttExport

        # First stack: self-attention and cross-attention are both swapped.
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANMExport(d)

        # Optional second stack: only the self-attention is swapped.
        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANMExport(d)

        # Third stack: layers are wrapped without touching their attention modules.
        for i, d in enumerate(self.model.decoders3):
            self.model.decoders3[i] = DecoderLayerSANMExport(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def _cache_num(self):
        """Return the number of layers in ``decoders`` plus ``decoders2``.

        ``decoders2`` is optional (``__init__`` and ``forward`` both accept
        ``None``), so it contributes zero layers when it is absent.
        """
        decoders2 = self.model.decoders2
        extra = 0 if decoders2 is None else len(decoders2)
        return len(self.model.decoders) + extra

    def prepare_mask(self, mask):
        """Derive a multiplicative and an additive mask from ``mask``.

        Returns:
            mask_3d_btd: ``mask`` with a trailing singleton axis.
            mask_4d_bhlt: ``(1 - mask) * -10000.0`` with broadcast axes
                inserted, i.e. a large negative value wherever ``mask`` is 0.

        Raises:
            ValueError: if ``mask`` is neither 2-D nor 3-D.
        """
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        else:
            raise ValueError(
                "mask must be 2-D or 3-D, got %d dimensions" % len(mask.shape)
            )
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Run the three decoder stacks and project to the output layer.

        Args:
            hs_pad: encoder memory attended to by the decoder.
            hlens: lengths used to build the memory mask.
            ys_in_pad: decoder input sequence.
            ys_in_lens: lengths used to build the target mask.

        Returns:
            The output of ``output_layer`` and ``ys_in_lens`` unchanged.
        """
        tgt = ys_in_pad
        # The target side only needs the multiplicative (3-D) mask ...
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)

        memory = hs_pad
        # ... while the memory side only needs the additive (4-D) mask.
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        """Build ``(tgt, memory, pre_acoustic_embeds, cache)`` placeholder inputs."""
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = self._cache_num()
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        """Return the fixed input names followed by one name per cache tensor."""
        cache_num = self._cache_num()
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]

    def get_output_names(self):
        """Return ``'y'`` followed by one name per output cache tensor."""
        cache_num = self._cache_num()
        return ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        """Return the dynamic-axis names for the inputs listed by ``get_input_names``."""
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = self._cache_num()
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret
@tables.register("decoder_classes", "ParaformerSANDecoder")
class ParaformerSANDecoder(BaseTransformerDecoder):
funasr/models/paraformer/model.py
@@ -549,3 +549,80 @@
                
        return results, meta_data
    def export(
        self,
        max_seq_len=512,
        **kwargs,
    ):
        onnx = kwargs.get("onnx", True)
        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
        self.encoder = encoder_class(self.encoder, onnx=onnx)
        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
        self.predictor = predictor_class(self.predictor, onnx=onnx)
        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
        self.decoder = decoder_class(self.decoder, onnx=onnx)
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        self.forward = self._export_forward
        return self
    def _export_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.floor().type(torch.int32)
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)
        return decoder_out, pre_token_length
    def export_dummy_inputs(self):
        speech = torch.randn(2, 30, 560)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)
    def export_input_names(self):
        return ['speech', 'speech_lengths']
    def export_output_names(self):
        return ['logits', 'token_num']
    def export_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
    def export_name(self, ):
        return "model.onnx"
funasr/models/qwen_audio/audio.py
New file
@@ -0,0 +1,446 @@
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, List
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from subprocess import CalledProcessError, run, Popen, PIPE
import os
from functools import lru_cache
from typing import Optional, Union
from funasr.register import tables
def exact_div(x, y):
    """Return ``x // y`` after asserting that ``y`` divides ``x`` evenly."""
    remainder = x % y
    assert remainder == 0
    return x // y
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000  # default target rate handed to ffmpeg ("-ar") when decoding audio
N_FFT = 400  # FFT size and Hann window length used for the STFT
N_MELS = 80  # the only filterbank size accepted by mel_filters()
HOP_LENGTH = 160  # STFT hop, in samples
CHUNK_LENGTH = 30  # chunk duration, in seconds
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have a combined stride of 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 100 frames per second, i.e. 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 50 tokens per second, i.e. 20ms per audio token
def get_T_after_cnn(L_in, dilation=1):
    """Return the sequence length left after the two stacked 1-D convolutions.

    Applies the usual convolution output-length formula
    ``1 + (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride``
    once per layer, feeding each result into the next layer.

    Args:
        L_in: input length along the time axis.
        dilation: dilation shared by both convolutions.

    Returns:
        The length after both convolutions.
    """
    # (padding, kernel_size, stride) per layer: a stride-1 convolution followed
    # by a stride-2 convolution, both with kernel size 3 and padding 1.
    conv_specs = ((1, 3, 1), (1, 3, 2))
    L_out = L_in
    for padding, kernel_size, stride in conv_specs:
        L_out = L_out + 2 * padding - dilation * (kernel_size - 1) - 1
        L_out = 1 + L_out // stride
    return L_out
def load_bytesio_audio(content, sr: int = SAMPLE_RATE):
    """Decode in-memory audio bytes to a mono float32 waveform.

    The bytes are piped through the ffmpeg CLI (which must be in PATH),
    down-mixed to one channel and resampled to ``sr``.

    Parameters
    ----------
    content: bytes
        The encoded audio, written to ffmpeg's standard input
    sr: int
        The sample rate to resample the audio to

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype,
    scaled from 16-bit PCM by dividing by 32768.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; the message carries
        ffmpeg's stderr output.
    """
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", "pipe:",
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "pipe:"
    ]
    # fmt: on
    try:
        # check=True surfaces decode failures instead of silently returning
        # whatever partial (often empty) output ffmpeg produced.
        out = run(cmd, input=content, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Decode an audio file into a mono waveform, resampling when needed

    Decoding is delegated to an ffmpeg subprocess, which down-mixes and
    resamples as necessary; the ffmpeg CLI must therefore be in PATH.

    Parameters
    ----------
    file: str
        The audio file to open
    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status.
    """
    ffmpeg_input = ["ffmpeg", "-nostdin", "-threads", "0", "-i", file]
    # Raw signed 16-bit little-endian PCM, one channel, written to stdout.
    ffmpeg_output = ["-f", "s16le", "-ac", "1", "-acodec", "pcm_s16le", "-ar", str(sr), "-"]

    try:
        completed = run(ffmpeg_input + ffmpeg_output, capture_output=True, check=True)
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    pcm = np.frombuffer(completed.stdout, np.int16)
    return pcm.flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Force ``array`` to exactly ``length`` entries along ``axis``.

    Longer inputs keep their first ``length`` entries, shorter inputs are
    zero padded at the end. Both torch tensors and NumPy arrays are handled.
    """
    is_tensor = torch.is_tensor(array)
    current = array.shape[axis]

    if current > length:
        # Too long: keep the leading `length` entries.
        if is_tensor:
            keep = torch.arange(length, device=array.device)
            return array.index_select(dim=axis, index=keep)
        return array.take(indices=range(length), axis=axis)

    if current < length:
        # Too short: pad only at the end of the chosen axis.
        widths = [(0, 0) for _ in range(array.ndim)]
        widths[axis] = (0, length - current)
        if is_tensor:
            # F.pad wants a flat list that starts from the last dimension.
            flat = [amount for pair in reversed(widths) for amount in pair]
            return F.pad(array, flat)
        return np.pad(array, widths)

    return array
def trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Truncate ``array`` to at most ``length`` entries along ``axis``.

    Unlike ``pad_or_trim`` this never pads: inputs that are already short
    enough are returned unchanged.
    """
    if torch.is_tensor(array):
        if array.shape[axis] <= length:
            return array
        keep = torch.arange(length, device=array.device)
        return array.index_select(dim=axis, index=keep)

    if array.shape[axis] <= length:
        return array
    return array.take(indices=range(length), axis=axis)
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
    """
    Load the mel filterbank matrix that projects an STFT onto a Mel spectrogram.

    The matrix is read from ``mel_filters.npz`` next to this module, which
    avoids a librosa dependency; the file was saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )

    Results are cached per ``(device, n_mels)`` pair.
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"

    filters_path = os.path.join(os.path.dirname(__file__), "mel_filters.npz")
    with np.load(filters_path) as archive:
        bank = archive[f"mel_{n_mels}"]
        return torch.from_numpy(bank).to(device)
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = N_MELS,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    waveform = audio
    if not torch.is_tensor(waveform):
        if isinstance(waveform, str):
            waveform = load_audio(waveform)
        waveform = torch.from_numpy(waveform)

    if device is not None:
        waveform = waveform.to(device)
    if padding > 0:
        waveform = F.pad(waveform, (0, padding))

    hann = torch.hann_window(N_FFT).to(waveform.device)
    spectrum = torch.stft(waveform, N_FFT, HOP_LENGTH, window=hann, return_complex=True)
    # Power spectrum; the last STFT frame is dropped.
    power = spectrum[..., :-1].abs() ** 2

    mel_basis = mel_filters(waveform.device, n_mels)
    mel_power = mel_basis @ power

    log_mel = torch.clamp(mel_power, min=1e-10).log10()
    # Limit the dynamic range to 8 (log10 units) below the peak, then rescale.
    floor = log_mel.max() - 8.0
    log_mel = torch.maximum(log_mel, floor)
    return (log_mel + 4.0) / 4.0
@dataclass
class ModelDimensions:
    """Bundle of integer size hyper-parameters for an audio/text model.

    NOTE(review): the meaning of each field is inferred from its name only;
    nothing in the visible part of this file reads these fields. Confirm
    against the code that constructs this dataclass.
    """
    # presumably audio-encoder sizes (mel bins, context, width, heads, layers) - TODO confirm
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    # presumably text-side sizes (vocabulary, context, width, heads, layers) - TODO confirm
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is cast back to the dtype of its input."""

    def forward(self, x: Tensor) -> Tensor:
        normed = nn.LayerNorm.forward(self, x)
        return normed.type(x.dtype)
class Linear(nn.Linear):
    """``nn.Linear`` that casts its weight and bias to the dtype of the input."""

    def forward(self, x: Tensor) -> Tensor:
        target_dtype = x.dtype
        weight = self.weight.to(target_dtype)
        if self.bias is None:
            bias = None
        else:
            bias = self.bias.to(target_dtype)
        return F.linear(x, weight, bias)
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` that casts its weight and bias to the dtype of the input."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        target_dtype = x.dtype
        cast_weight = weight.to(target_dtype)
        cast_bias = bias.to(target_dtype) if bias is not None else None
        return nn.Conv1d._conv_forward(self, x, cast_weight, cast_bias)
def sinusoids(length, channels, max_timescale=10000):
    """Build a (length, channels) table of sinusoidal positional embeddings.

    The first half of the channels holds sines and the second half cosines,
    with timescales spaced geometrically from 1 up to ``max_timescale``.
    ``channels`` must be even.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_timescale_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    positions = torch.arange(length)
    # Outer product of positions and inverse timescales: (length, half).
    scaled_time = positions.unsqueeze(1) * inv_timescales.unsqueeze(0)
    return torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=1)
class MultiHeadAttention(nn.Module):
    """Multi-head attention usable for self-attention or cross-attention.

    Passing ``xa`` makes keys and values come from ``xa`` instead of ``x``.
    The key projection has no bias; the other projections do.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Return the projected attention output and the detached attention logits."""
        q = self.query(x)

        reuse_cached = (
            kv_cache is not None and xa is not None and self.key in kv_cache
        )
        if reuse_cached:
            # Cross-attention keys/values were computed by an earlier call; reuse them.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            # Project keys/values as usual; hooks, if installed (i.e. kv_cache is
            # not None), will prepend the cached kv tensors.
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        attended, qk = self.qkv_attention(q, k, v, mask)
        return self.out(attended), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        """Scaled dot-product attention over ``n_head`` heads.

        Returns the head-merged output of shape (batch, query length, state)
        and the detached pre-softmax logits (with ``mask`` already added).
        """
        _, _, n_state = q.shape
        # Queries and keys each carry half of the usual 1/sqrt(head_dim) factor.
        scale = (n_state // self.n_head) ** -0.25

        q_heads = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k_heads_t = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v_heads = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = torch.matmul(q_heads, k_heads_t)
        if mask is not None:
            qk.add_(mask)

        weights = F.softmax(qk, dim=-1).to(q_heads.dtype)
        merged = torch.matmul(weights, v_heads).permute(0, 2, 1, 3).flatten(start_dim=2)
        return merged, qk.detach()
class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: self-attention, optional cross-attention, MLP."""

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        # Feed-forward part with a 4x wider hidden layer.
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Apply the sub-layers in order, each added back onto its input."""
        self_attended, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        x = x + self_attended

        if self.cross_attn:
            cross_attended, _ = self.cross_attn(
                self.cross_attn_ln(x), xa, kv_cache=kv_cache
            )
            x = x + cross_attended

        return x + self.mlp(self.mlp_ln(x))
@tables.register("encoder_classes", "QwenAudioEncoder")
class QwenAudioEncoder(nn.Module):
    """Audio encoder: two 1-D convolutions, a stack of self-attention blocks,
    optional average pooling and a final linear projection to ``output_dim``.

    When ``add_audio_bos_eos_token`` is true, a 2-row embedding provides the
    begin/end vectors that ``encode`` places around every audio span.
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(
            self,
            n_mels: int,
            n_ctx: int,
            n_state: int,
            n_head: int,
            n_layer: int,
            output_dim: int = 512,
            avg_pool: bool = True,
            add_audio_bos_eos_token: bool = True,
            **kwargs
    ):
        super().__init__()
        # conv1 keeps the time length; conv2 has stride 2.
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # Fixed sinusoidal table of shape (n_ctx, n_state), stored as a buffer.
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)
        if avg_pool:
            self.avg_pooler = nn.AvgPool1d(2, stride=2)
        else:
            self.avg_pooler = None
        self.proj = nn.Linear(n_state, output_dim)
        if add_audio_bos_eos_token:
            # Row 0 is used as the begin vector and row 1 as the end vector (see forward).
            self.audio_bos_eos_token = nn.Embedding(2, output_dim)
        else:
            self.audio_bos_eos_token = None
        self.output_dim = output_dim
        self.n_head = n_head
    def forward(self, x: Tensor, padding_mask: Tensor=None, audio_lengths: Tensor=None):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        padding_mask : optional tensor whose last axis indexes positions after
            the convolutions; non-zero entries mark padding
        audio_lengths : optional tensor; twice its first column is used as the
            per-item mel length when cropping the input

        Returns the projected features together with the begin and end
        vectors (both None when the bos/eos embedding is disabled).
        """
        x = x.to(dtype=self.conv1.weight.dtype,
                 device=self.conv1.weight.device)
        if audio_lengths is not None:
            # Crop the mel time axis to the longest item in the batch.
            input_mel_len = audio_lengths[:,0] * 2
            max_mel_len_in_batch = input_mel_len.max()
            x = x[:, :, :max_mel_len_in_batch]
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # B, L, D
        bsz = x.size(0)
        src_len = x.size(1)
        # NOTE: the sliced table is stored on the module, not kept as a local.
        self.input_positional_embedding = self.positional_embedding[:src_len]
        assert x.shape[1:] == self.input_positional_embedding.shape, f"incorrect audio shape: {x.shape[1:], self.input_positional_embedding.shape}"
        x = (x + self.input_positional_embedding).to(x.dtype)
        if padding_mask is not None:
            padding_mask = padding_mask.to(dtype=self.conv1.weight.dtype,
                     device=self.conv1.weight.device)
            batch_src_len = padding_mask.size(1)
            x = x[:, :batch_src_len, :]
            padding_mask = padding_mask.view(
                bsz, -1, batch_src_len
            )
            # Boolean (batch, position) mask: True where every entry along dim 1 is non-zero.
            padding_mask_ = padding_mask.all(1)
            # Zero the features at padded positions (in-place write).
            x[padding_mask_] = 0
            key_padding_mask = padding_mask_.view(bsz, 1, 1, batch_src_len). \
                expand(-1, self.n_head, -1, -1).reshape(bsz, self.n_head, 1, batch_src_len)
            # Additive attention mask: 0 at valid keys, -inf at padded keys.
            new_padding_mask = torch.zeros_like(key_padding_mask, dtype=x.dtype)
            padding_mask = new_padding_mask.masked_fill(key_padding_mask, float("-inf"))
        for block in self.blocks:
            x = block(x, mask=padding_mask)
        if self.avg_pooler:
            # AvgPool1d pools the last axis, so time is moved there and back.
            x = x.permute(0, 2, 1)
            x = self.avg_pooler(x)
            x = x.permute(0, 2, 1)
        x = self.ln_post(x)
        x = self.proj(x)
        if self.audio_bos_eos_token is not None:
            bos = self.audio_bos_eos_token.weight[0][None, :]
            eos = self.audio_bos_eos_token.weight[1][None, :]
        else:
            bos, eos = None, None
        return x, bos, eos
    def encode(self, input_audios: Tensor, input_audio_lengths: Tensor, audio_span_tokens: List):
        """Encode a batch and return one feature tensor per audio span.

        A padding mask is built from the first column of
        ``input_audio_lengths`` (1 marks padding, 0 marks valid positions),
        the batch is run through ``forward``, and for item ``i`` the first
        ``audio_span_tokens[i] - 2`` frames are kept and wrapped with the
        begin/end vectors.

        NOTE(review): when the bos/eos embedding is disabled nothing is
        concatenated, so the kept slice is at most ``audio_span - 2`` long and
        the length assertion below cannot hold - confirm whether that
        configuration is meant to be supported here.
        """
        real_input_audio_lens = input_audio_lengths[:, 0].tolist()
        max_len_in_batch = max(real_input_audio_lens)
        padding_mask = torch.ones([input_audios.size(0), max_len_in_batch]).to(dtype=self.conv1.weight.dtype,
                                                                               device=self.conv1.weight.device)
        for index in range(len(input_audios)):
            # Clear the mask over the valid prefix of every item.
            padding_mask[index, :input_audio_lengths[index][0].item()] = 0
        x, bos, eos = self(input_audios, padding_mask,input_audio_lengths)
        output_audios = []
        for i in range(len(audio_span_tokens)):
            audio_span = audio_span_tokens[i]
            # Two slots of the span are reserved for the begin/end vectors.
            audio = x[i][:audio_span-2]
            if bos is not None:
                audio = torch.concat([bos, audio, eos])
            assert len(audio) == audio_span
            output_audios.append(audio)
        return output_audios
funasr/models/qwen_audio/model.py
@@ -60,7 +60,7 @@
        inputs = self.tokenizer(query, return_tensors='pt', audio_info=audio_info)
        inputs = inputs.to(self.model.device)
        pred = self.model.generate(**inputs, audio_info=audio_info)
        response = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info)
        response = self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info)
        results = []
        result_i = {"key": key[0], "text": response}
funasr/models/sanm/attention.py
@@ -303,6 +303,64 @@
        att_outs = self.forward_attention(v_h, scores, None)
        return att_outs + fsmn_memory, cache
class MultiHeadedAttentionSANMExport(nn.Module):
    """Export variant of SANM attention: multi-head attention plus an FSMN memory branch.

    All layers are taken over from ``model``; the forward pass expects the
    two masks already prepared by the caller.
    """

    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_out = model.linear_out
        self.linear_q_k_v = model.linear_q_k_v
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn

        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, mask):
        """Return attention output plus FSMN memory.

        ``mask`` is a pair: a multiplicative mask for the FSMN branch and an
        additive mask for the attention scores.
        """
        mask_3d_btd, mask_4d_bhlt = mask
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)

        q_h = q_h * self.d_k**(-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
        return att_outs + fsmn_memory

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into (heads, d_k) and move heads to axis 1."""
        per_head_shape = x.size()[:-1] + (self.h, self.d_k)
        return x.view(per_head_shape).permute(0, 2, 1, 3)

    def forward_qkv(self, x):
        """Project ``x`` once and return per-head q, k, v plus the unsplit v."""
        projected = self.linear_q_k_v(x)
        q, k, v = torch.split(projected, int(self.h * self.d_k), dim=-1)
        return (
            self.transpose_for_scores(q),
            self.transpose_for_scores(k),
            self.transpose_for_scores(v),
            v,
        )

    def forward_fsmn(self, inputs, mask):
        """Run the padded FSMN block over time with a skip connection, masked on both ends."""
        masked_inputs = inputs * mask

        memory = masked_inputs.transpose(1, 2)
        memory = self.fsmn_block(self.pad_fn(memory))
        memory = memory.transpose(1, 2)

        return (memory + masked_inputs) * mask

    def forward_attention(self, value, scores, mask):
        """Softmax the masked scores, apply them to ``value`` and merge the heads."""
        self.attn = torch.softmax(scores + mask, dim=-1)
        context = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context = context.permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return self.linear_out(context.view(merged_shape))  # (batch, time1, d_model)
class MultiHeadedAttentionSANMDecoder(nn.Module):
funasr/models/sanm/encoder.py
@@ -456,3 +456,135 @@
            return (xs_pad, intermediate_outs), None, None
        return xs_pad, ilens, None
class EncoderLayerSANMExport(nn.Module):
    """Export-time view of one SANM encoder layer; reuses the sub-modules of ``model``."""

    def __init__(
        self,
        model,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = model.self_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.in_size = model.in_size
        self.size = model.size

    def forward(self, x, mask):
        """Apply pre-norm self-attention and feed-forward; return ``(x, mask)``."""
        attended = self.self_attn(self.norm1(x), mask)
        # The attention skip connection only exists when the layer keeps the feature size.
        if self.in_size == self.size:
            attended = attended + x

        transformed = self.feed_forward(self.norm2(attended))
        return transformed + attended, mask
@tables.register("encoder_classes", "SANMEncoderExport")
class SANMEncoderExport(nn.Module):
    """Export wrapper around a SANM encoder.

    The layers in ``model.encoders0`` (when present) and ``model.encoders``
    are replaced in place by ``EncoderLayerSANMExport``, with SANM attention
    swapped for its export variant, so ``model`` is modified by the constructor.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        # A streaming positional encoder is dropped; forward then feeds the features through as they are.
        if isinstance(self.embed, StreamSinusoidalPositionEncoder):
            self.embed = None
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size

        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask

        mask_factory = MakePadMask if onnx else sequence_mask
        self.make_pad_mask = mask_factory(max_seq_len, flip=False)

        from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport

        # Both layer stacks get the same treatment; the first one is optional.
        layer_stacks = []
        if hasattr(model, 'encoders0'):
            layer_stacks.append(self.model.encoders0)
        layer_stacks.append(self.model.encoders)
        for stack in layer_stacks:
            for idx, layer in enumerate(stack):
                if isinstance(layer.self_attn, MultiHeadedAttentionSANM):
                    layer.self_attn = MultiHeadedAttentionSANMExport(layer.self_attn)
                stack[idx] = EncoderLayerSANMExport(layer)

        self.model_name = model_name
        first_attn = model.encoders[0].self_attn
        self.num_heads = first_attn.h
        self.hidden_size = first_attn.linear_out.out_features

    def prepare_mask(self, mask):
        """Return ``mask`` with a trailing axis and its additive ``-10000`` counterpart."""
        mask_3d_btd = mask[:, :, None]

        rank = len(mask.shape)
        if rank == 2:
            inverted = 1 - mask[:, None, None, :]
        elif rank == 3:
            inverted = 1 - mask[:, None, :]

        return mask_3d_btd, inverted * -10000.0

    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                online: bool = False
                ):
        """Encode ``speech`` and return ``(encoder output, speech_lengths)``.

        Unless ``online`` is set, the features are first scaled by the square
        root of the encoder output size.
        """
        if not online:
            speech = speech * self._output_size ** 0.5

        mask = self.prepare_mask(self.make_pad_mask(speech_lengths))

        xs_pad = speech if self.embed is None else self.embed(speech)

        # Each stack returns (features, mask); only the features are carried on.
        stage_out = self.model.encoders0(xs_pad, mask)
        xs_pad, _ = stage_out[0], stage_out[1]

        stage_out = self.model.encoders(xs_pad, mask)
        xs_pad, _ = stage_out[0], stage_out[1]

        return self.model.after_norm(xs_pad), speech_lengths

    def get_output_size(self):
        return self.model.encoders[0].size

    def get_dummy_inputs(self):
        """Return one random feature tensor of shape (1, 100, feats_dim)."""
        return torch.randn(1, 100, self.feats_dim)

    def get_input_names(self):
        return list(('feats',))

    def get_output_names(self):
        return list(('encoder_out', 'encoder_out_lens', 'predictor_weight'))

    def get_dynamic_axes(self):
        """Return ``{tensor name: {axis index: symbolic axis name}}`` for the time axes."""
        axes = {}
        axes['feats'] = {1: 'feats_length'}
        axes['encoder_out'] = {1: 'enc_out_length'}
        axes['predictor_weight'] = {1: 'pre_out_length'}
        return axes
funasr/models/sanm/multihead_att.py
New file
@@ -0,0 +1,243 @@
import os
import math
import torch
import torch.nn as nn
class MultiHeadedAttentionSANMExport(nn.Module):
    """Export wrapper that reuses the sub-modules of a SANM self-attention layer.

    The wrapped layer supplies the fused q/k/v projection, the output
    projection, the FSMN block and its padding function. ``forward`` returns
    the attention output plus the FSMN memory term.
    """

    def __init__(self, model):
        super().__init__()
        # Copy in a fixed order so sub-module registration order is stable.
        for field in ("d_k", "h", "linear_out", "linear_q_k_v", "fsmn_block", "pad_fn"):
            setattr(self, field, getattr(model, field))
        self.attn = None  # last attention weights, set by forward_attention
        self.all_head_size = self.h * self.d_k

    def forward(self, x, mask):
        """Apply attention plus FSMN memory.

        ``mask`` is a pair: a multiplicative mask for the FSMN branch and an
        additive mask for the attention scores.
        """
        keep_btd, additive_bhlt = mask
        query, key, value_heads, value = self.forward_qkv(x)
        memory = self.forward_fsmn(value, keep_btd)
        scaled_query = query * self.d_k**(-0.5)
        scores = torch.matmul(scaled_query, key.transpose(-2, -1))
        attended = self.forward_attention(value_heads, scores, additive_bhlt)
        return attended + memory

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into ``(h, d_k)`` and move the head axis to position 1."""
        head_shape = x.size()[:-1] + (self.h, self.d_k)
        return x.view(head_shape).permute(0, 2, 1, 3)

    def forward_qkv(self, x):
        """Project ``x`` once and return per-head q, k, v plus the unsplit v."""
        fused = self.linear_q_k_v(x)
        q, k, v = torch.split(fused, int(self.h * self.d_k), dim=-1)
        q_h, k_h, v_h = (self.transpose_for_scores(part) for part in (q, k, v))
        return q_h, k_h, v_h, v

    def forward_fsmn(self, inputs, mask):
        """Run the FSMN block on masked inputs and add the masked residual."""
        masked = inputs * mask
        # fsmn_block receives the padded input with axes 1 and 2 swapped.
        memory = self.fsmn_block(self.pad_fn(masked.transpose(1, 2))).transpose(1, 2)
        return (memory + masked) * mask

    def forward_attention(self, value, scores, mask):
        """Softmax the biased scores, attend over ``value`` and merge the heads."""
        self.attn = torch.softmax(scores + mask, dim=-1)
        # (batch, head, time1, d_k) -> (batch, time1, head, d_k)
        context = torch.matmul(self.attn, value).permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return self.linear_out(context.view(merged_shape))  # (batch, time1, d_model)
def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
    """Mask the input, swap axes 1 and 2, then pad it or prepend the cache.

    Args:
        x: Input tensor; multiplied element-wise by ``mask`` first.
        mask: Multiplicative mask broadcastable against ``x``.
        cache: ``None`` to pad with ``pad_fn``; otherwise a tensor that is
            concatenated in front of the transposed input along axis 2.
        pad_fn: Callable applied to the transposed input when there is no cache.
        kernel_size: The refreshed cache is the slice ``-(kernel_size - 1):``
            of the concatenated tensor along axis 2.

    Returns:
        ``(prepared_input, cache)``; ``cache`` stays ``None`` when none was given.
    """
    transposed = (x * mask).transpose(1, 2)
    if cache is None:
        return pad_fn(transposed), cache
    joined = torch.cat((cache, transposed), dim=2)
    return joined, joined[:, :, -(kernel_size-1):]
# (major, minor) of the installed torch. On torch >= 1.8, register
# preprocess_for_attn with torch.fx.wrap; this must run at module level.
torch_version = tuple(map(int, torch.__version__.split(".")[:2]))
if torch_version >= (1, 8):
    import torch.fx
    torch.fx.wrap('preprocess_for_attn')
class MultiHeadedAttentionSANMDecoderExport(nn.Module):
    """Export wrapper around the FSMN memory block of a SANM decoder layer.

    Only the FSMN block, its padding function and the kernel size are taken
    from the wrapped layer.
    """

    def __init__(self, model):
        super().__init__()
        for field in ("fsmn_block", "pad_fn", "kernel_size"):
            setattr(self, field, getattr(model, field))
        self.attn = None

    def forward(self, inputs, mask, cache=None):
        """Return the masked FSMN output with residual, and the updated cache.

        Preparation (masking, transposing, padding or cache concatenation) is
        delegated to ``preprocess_for_attn``; the residual added afterwards is
        the unmasked ``inputs``.
        """
        prepared, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
        memory = self.fsmn_block(prepared).transpose(1, 2)
        return (memory + inputs) * mask, cache
class MultiHeadedAttentionCrossAttExport(nn.Module):
    """Export wrapper for cross-attention with a fused key/value projection.

    Queries come from ``x``; keys and values are both projected from
    ``memory`` by ``linear_k_v`` and then split in two.
    """

    def __init__(self, model):
        super().__init__()
        # Copy in a fixed order so sub-module registration order is stable.
        for field in ("d_k", "h", "linear_q", "linear_k_v", "linear_out"):
            setattr(self, field, getattr(model, field))
        self.attn = None  # last attention weights, set by forward_attention
        self.all_head_size = self.h * self.d_k

    def forward(self, x, memory, memory_mask):
        """Attend from ``x`` over ``memory``; ``memory_mask`` is added to the scores."""
        query, key, value = self.forward_qkv(x, memory)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(value, scores, memory_mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into ``(h, d_k)`` and move the head axis to position 1."""
        head_shape = x.size()[:-1] + (self.h, self.d_k)
        return x.view(head_shape).permute(0, 2, 1, 3)

    def forward_qkv(self, x, memory):
        """Return per-head query, key and value tensors."""
        query = self.linear_q(x)
        key, value = torch.split(self.linear_k_v(memory), int(self.h * self.d_k), dim=-1)
        return tuple(self.transpose_for_scores(part) for part in (query, key, value))

    def forward_attention(self, value, scores, mask):
        """Softmax the biased scores, attend over ``value`` and merge the heads."""
        self.attn = torch.softmax(scores + mask, dim=-1)
        # (batch, head, time1, d_k) -> (batch, time1, head, d_k)
        context = torch.matmul(self.attn, value).permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return self.linear_out(context.view(merged_shape))  # (batch, time1, d_model)
class OnnxMultiHeadedAttention(nn.Module):
    """Export wrapper for multi-head attention with separate q/k/v projections.

    All projection layers are taken from the wrapped attention module.
    """

    def __init__(self, model):
        super().__init__()
        # Copy in a fixed order so sub-module registration order is stable.
        for field in ("d_k", "h", "linear_q", "linear_k", "linear_v", "linear_out"):
            setattr(self, field, getattr(model, field))
        self.attn = None  # last attention weights, set by forward_attention
        self.all_head_size = self.h * self.d_k

    def forward(self, query, key, value, mask):
        """Scaled dot-product attention; ``mask`` is added to the scores."""
        q_h, k_h, v_h = self.forward_qkv(query, key, value)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v_h, scores, mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last axis into ``(h, d_k)`` and move the head axis to position 1."""
        head_shape = x.size()[:-1] + (self.h, self.d_k)
        return x.view(head_shape).permute(0, 2, 1, 3)

    def forward_qkv(self, query, key, value):
        """Project the three inputs and reshape each into per-head layout."""
        projected = (self.linear_q(query), self.linear_k(key), self.linear_v(value))
        q_h, k_h, v_h = (self.transpose_for_scores(part) for part in projected)
        return q_h, k_h, v_h

    def forward_attention(self, value, scores, mask):
        """Softmax the biased scores, attend over ``value`` and merge the heads."""
        self.attn = torch.softmax(scores + mask, dim=-1)
        # (batch, head, time1, d_k) -> (batch, time1, head, d_k)
        context = torch.matmul(self.attn, value).permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return self.linear_out(context.view(merged_shape))  # (batch, time1, d_model)
class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
    """Export wrapper for multi-head attention with relative positional terms.

    Adds the positional projection and the two learned biases of the wrapped
    layer on top of the plain attention wrapper.
    """

    def __init__(self, model):
        super().__init__(model)
        for field in ("linear_pos", "pos_bias_u", "pos_bias_v"):
            setattr(self, field, getattr(model, field))

    def forward(self, query, key, value, pos_emb, mask):
        """Attention whose scores combine a content term and a position term."""
        q_h, k_h, v_h = self.forward_qkv(query, key, value)
        q_bthd = q_h.transpose(1, 2)  # (batch, time1, head, d_k)
        # (batch, head, time1, d_k)
        pos_h = self.transpose_for_scores(self.linear_pos(pos_emb))

        # Add each bias in (batch, time1, head, d_k) layout, then move the
        # head axis back to position 1.
        q_with_bias_u = (q_bthd + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (q_bthd + self.pos_bias_v).transpose(1, 2)

        # Matrices a+c and b+d as described in
        # https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        content_scores = torch.matmul(q_with_bias_u, k_h.transpose(-2, -1))
        position_scores = self.rel_shift(
            torch.matmul(q_with_bias_v, pos_h.transpose(-2, -1))
        )

        # (batch, head, time1, time2)
        scores = (content_scores + position_scores) / math.sqrt(self.d_k)
        return self.forward_attention(v_h, scores, mask)

    def rel_shift(self, x):
        """Shift the relative-position scores and keep the first ``size(-1) // 2 + 1`` columns."""
        lead = x.size()[:2]
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        padded = torch.cat([zero_pad, x], dim=-1)
        padded = padded.view(*lead, x.size(3) + 1, x.size(2))
        shifted = padded[:, :, 1:].view_as(x)
        # only keep the positions from 0 to time2
        return shifted[:, :, :, : x.size(-1) // 2 + 1]

    def forward_attention(self, value, scores, mask):
        """Softmax the biased scores, attend over ``value`` and merge the heads."""
        self.attn = torch.softmax(scores + mask, dim=-1)
        # (batch, head, time1, d_k) -> (batch, time1, head, d_k)
        context = torch.matmul(self.attn, value).permute(0, 2, 1, 3).contiguous()
        merged_shape = context.size()[:-2] + (self.all_head_size,)
        return self.linear_out(context.view(merged_shape))  # (batch, time1, d_model)
funasr/models/whisper/template.yaml
@@ -11,29 +11,28 @@
    lsm_weight: 0.1
    length_normalized_loss: true
    hub: funasr # openai
    init_param_path: null # large-v2 or large-v3 if hub == "openai"
# only used when hub == funasr;
#  if hub == openai, whisper_dims is downloaded automatically
whisper_dims:
    'n_mels': 80
    'n_vocab': 51865
    'n_audio_ctx': 1500
    'n_audio_state': 1280
    'n_audio_head': 20
    'n_audio_layer': 32
    'n_text_ctx': 448
    'n_text_state': 1280
    'n_text_head': 20
    'n_text_layer': 32
#  if hub == openai, dims is downloaded automatically
dims:
    n_mels: 128
    n_vocab: 51866
    n_audio_ctx: 1500
    n_audio_state: 1280
    n_audio_head: 20
    n_audio_layer: 32
    n_text_ctx: 448
    n_text_state: 1280
    n_text_head: 20
    n_text_layer: 32
# frontend related
frontend: WhisperFrontend
frontend_conf:
    fs: 16000
    n_mels: 80
    n_mels: ${dims.n_mels}
    do_pad_trim: true
tokenizer: WhisperTokenizer
@@ -41,6 +40,6 @@
  language: null
  task: transcribe
  is_multilingual: true
  num_languages: 99
  num_languages: 100
scope_map: ['none', "model."]
scope_map: [none, "model."]
funasr/utils/export_utils.py
New file
@@ -0,0 +1,74 @@
import os
import torch
def export_onnx(model,
                data_in=None,
                type: str = "onnx",
                quantize: bool = False,
                fallback_num: int = 5,
                calib_num: int = 100,
                opset_version: int = 14,
                **kwargs):
    """Export every sub-model returned by ``model.export`` to ONNX.

    Args:
        model: Object whose ``export(**kwargs)`` returns one exportable module
            or a list/tuple of them.
        data_in: Forwarded unchanged to ``_onnx``.
        type: Forwarded unchanged to ``_onnx``.
        quantize: Forwarded to ``_onnx``; requests a quantized copy as well.
        fallback_num: Forwarded unchanged to ``_onnx``.
        calib_num: Forwarded unchanged to ``_onnx``.
        opset_version: ONNX opset version, forwarded to ``_onnx``.
        **kwargs: Passed to ``model.export`` and to ``_onnx``. ``output_dir``
            selects the export directory; when it is absent or ``None`` the
            directory containing ``init_param`` is used instead.

    Returns:
        The directory the models were written to.

    Raises:
        TypeError: If neither ``output_dir`` nor ``init_param`` is supplied.
    """
    model_scripts = model.export(**kwargs)

    # Resolve the fallback lazily: computing dirname(init_param) up front would
    # fail whenever ``init_param`` is missing, even if ``output_dir`` was given.
    export_dir = kwargs.get("output_dir")
    if export_dir is None:
        init_param = kwargs.get("init_param")
        if init_param is None:
            raise TypeError(
                "export_onnx() requires `output_dir` or `init_param` to locate the export directory"
            )
        # A bare file name has an empty dirname; fall back to the current directory.
        export_dir = os.path.dirname(init_param) or "."
    os.makedirs(export_dir, exist_ok=True)

    if not isinstance(model_scripts, (list, tuple)):
        model_scripts = (model_scripts,)
    for m in model_scripts:
        m.eval()
        _onnx(m,
              data_in=data_in,
              type=type,
              quantize=quantize,
              fallback_num=fallback_num,
              calib_num=calib_num,
              opset_version=opset_version,
              export_dir=export_dir,
              **kwargs
              )
        print("output dir: {}".format(export_dir))

    return export_dir
def _onnx(model,
            data_in=None,
            quantize: bool = False,
            opset_version: int = 14,
            export_dir:str = None,
            **kwargs):
    """Export one module to an ONNX file and optionally write a quantized copy.

    Args:
        model: Module exposing ``export_dummy_inputs``, ``export_input_names``,
            ``export_output_names`` and ``export_dynamic_axes``. An optional
            ``export_name`` method supplies the file name (default
            ``model.onnx``).
        data_in: Accepted but not used here.
        quantize: When true, also run onnxruntime dynamic quantization and
            write ``<name>_quant<ext>`` next to the exported model.
        opset_version: Forwarded to ``torch.onnx.export``.
        export_dir: Directory the files are written into.
        **kwargs: Only ``verbose`` is read; all other entries are ignored.
    """
    dummy_input = model.export_dummy_inputs()

    verbose = kwargs.get("verbose", False)

    export_name = model.export_name() if hasattr(model, "export_name") else "model.onnx"
    model_path = os.path.join(export_dir, export_name)
    torch.onnx.export(
        model,
        dummy_input,
        model_path,
        verbose=verbose,
        opset_version=opset_version,
        input_names=model.export_input_names(),
        output_names=model.export_output_names(),
        dynamic_axes=model.export_dynamic_axes()
    )

    if not quantize:
        return

    from onnxruntime.quantization import QuantType, quantize_dynamic
    import onnx

    # Insert the suffix before the file extension only. A plain
    # str.replace(".onnx", ...) would also rewrite ".onnx" inside directory
    # names, and would leave the path unchanged for other extensions.
    stem, extension = os.path.splitext(model_path)
    quant_model_path = stem + "_quant" + extension

    # An existing quantized file is kept as is and not regenerated.
    if os.path.exists(quant_model_path):
        return

    onnx_model = onnx.load(model_path)
    nodes = [n.name for n in onnx_model.graph.node]
    # Nodes whose names contain any of these markers are left unquantized.
    excluded_markers = ('output', 'bias_encoder', 'bias_decoder')
    nodes_to_exclude = [
        name for name in nodes
        if any(marker in name for marker in excluded_markers)
    ]
    quantize_dynamic(
        model_input=model_path,
        model_output=quant_model_path,
        op_types_to_quantize=['MatMul'],
        per_channel=True,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
        nodes_to_exclude=nodes_to_exclude,
    )
funasr/utils/misc.py
@@ -46,3 +46,10 @@
                value = value.split(' ')
            ret_dict.append((key, value))
        return ret_dict
def deep_update(original, update):
    """Recursively merge ``update`` into ``original`` in place.

    Dict values in ``update`` are merged key by key into the matching entry of
    ``original`` when that entry is itself a mutable mapping. In every other
    case the entry is overwritten (or created) with the new value.

    Args:
        original: Mapping that is modified in place.
        update: Mapping holding the overriding values.

    Returns:
        None; ``original`` is mutated.
    """
    from collections.abc import MutableMapping

    for key, value in update.items():
        # Recurse only into an existing mapping. An existing scalar or None
        # cannot absorb nested keys, so it is replaced by the new dict instead
        # of raising on item assignment.
        can_merge = (
            isinstance(value, dict)
            and key in original
            and isinstance(original[key], MutableMapping)
        )
        if can_merge:
            deep_update(original[key], value)
        else:
            original[key] = value
funasr/utils/torch_function.py
New file
@@ -0,0 +1,80 @@
from typing import Optional
import torch
import torch.nn as nn
import numpy as np
class MakePadMask(nn.Module):
    """Build padding masks by indexing a precomputed triangular lookup table.

    Row ``i`` of the table describes a sequence of length ``i + 1``. With
    ``flip=True`` the padded positions are marked; with ``flip=False`` the
    valid positions are marked. ``lengths - 1`` is used as the row index, so
    lengths are expected to lie in ``[1, max_seq_len]``.
    """

    def __init__(self, max_seq_len=512, flip=True):
        super().__init__()
        lower_tri = torch.ones(max_seq_len, max_seq_len).tril().type(torch.bool)
        table = ~lower_tri if flip else lower_tri
        # Register as a non-persistent buffer: the table then follows
        # ``.to(device)`` with the module but stays out of ``state_dict``.
        self.register_buffer("mask_pad", table, persistent=False)

    def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
        """Make mask tensor containing indices of padded part.

        This implementation creates the same mask tensor as the original
        make_pad_mask and can be converted into onnx format.
        Dimension length of xs should be 2 or 3.

        Args:
            lengths: Integer tensor of sequence lengths.
            xs: Optional reference tensor; its last dimension sets the mask
                width unless ``maxlen`` is given.
            length_dim: Must not be 0. When 1, the resulting mask has axes 1
                and 2 swapped.
            maxlen: Optional explicit mask width.

        Returns:
            A float32 mask tensor.

        Raises:
            ValueError: If ``length_dim`` is 0.
        """
        if length_dim == 0:
            raise ValueError("length_dim cannot be 0: {}".format(length_dim))

        if xs is not None and len(xs.shape) == 3:
            # Repeat every length across the non-time axis of ``xs``.
            if length_dim == 1:
                lengths = lengths.unsqueeze(1).expand(
                    *xs.transpose(1, 2).shape[:2])
            else:
                lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])

        if maxlen is not None:
            m = maxlen
        elif xs is not None:
            m = xs.shape[-1]
        else:
            m = torch.max(lengths)

        mask = self.mask_pad[lengths - 1][..., :m].type(torch.float32)

        if length_dim == 1:
            return mask.transpose(1, 2)
        else:
            return mask
class sequence_mask(nn.Module):
    """Module producing a mask that is 1 where ``position < length``, else 0."""

    def __init__(self, max_seq_len=512, flip=True):
        # Both arguments are accepted but not used by this module.
        super().__init__()

    def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
        """Return the mask for ``lengths``, cast to ``dtype``.

        The mask width is ``max_seq_len`` when given, otherwise
        ``lengths.max()``. The result is moved to ``device`` only when one is
        supplied.
        """
        width = lengths.max() if max_seq_len is None else max_seq_len
        positions = torch.arange(0, width, 1).to(lengths.device)
        # Compare every position against each length (broadcast over the last axis).
        mask = (positions < torch.unsqueeze(lengths, dim=-1)).type(dtype)
        if device is None:
            return mask
        return mask.to(device)
def normalize(input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Divide ``input`` by its ``p``-norm taken along ``dim``.

    Args:
        input: Tensor to normalise.
        p: Order of the norm.
        dim: Axis along which the norm is computed.
        out: Optional tensor that receives the result.

    Returns:
        The normalised tensor; the result of ``torch.div(..., out=out)`` when
        ``out`` is given. No epsilon is applied to the denominator.
    """
    denom = input.norm(p, dim, keepdim=True).expand_as(input)
    if out is not None:
        return torch.div(input, denom, out=out)
    return input / denom
def subsequent_mask(size: torch.Tensor):
    """Return a ``(size, size)`` lower-triangular matrix of ones, diagonal included."""
    full = torch.ones(size, size)
    return torch.tril(full)
def MakePadMask_test():
    """Smoke check: print the default MakePadMask output for one length-10 sequence."""
    lengths = torch.tensor([10]).type(torch.long)
    print(MakePadMask()(lengths))
# Running this file directly prints a sample mask as a quick manual check.
if __name__ == '__main__':
    MakePadMask_test()