From f2d8ded57f6403696001d39dd07a1396e5a03658 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 11 Mar 2024 01:24:43 +0800
Subject: [PATCH] export onnx (#1455)
---
funasr/models/paraformer/model.py | 77 ++
funasr/bin/export.py | 40 +
funasr/models/llm_asr_nar/model.py | 13
funasr/models/sanm/multihead_att.py | 243 +++++++
examples/industrial_data_pretraining/paraformer/export.py | 13
funasr/models/qwen_audio/audio.py | 446 +++++++++++++
funasr/models/qwen_audio/model.py | 2
funasr/utils/export_utils.py | 74 ++
examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear.yaml | 94 ++
examples/industrial_data_pretraining/paraformer/export.sh | 18
funasr/datasets/llm_datasets_qwenaudio/__init__.py | 0
funasr/models/sanm/attention.py | 58 +
funasr/models/whisper/template.yaml | 31
funasr/utils/torch_function.py | 80 ++
funasr/models/fsmn_vad_streaming/model.py | 40 +
funasr/auto/auto_model.py | 54 +
funasr/datasets/llm_datasets/preprocessor.py | 2
funasr/datasets/llm_datasets_qwenaudio/datasets.py | 150 ++++
funasr/models/paraformer/cif_predictor.py | 116 +++
funasr/models/fsmn_vad_streaming/encoder.py | 52 +
examples/industrial_data_pretraining/fsmn_vad_streaming/export.py | 12
funasr/models/paraformer/decoder.py | 171 +++++
funasr/utils/misc.py | 7
funasr/datasets/llm_datasets_vicuna/preprocessor.py | 37 -
examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh | 19
funasr/models/sanm/encoder.py | 132 +++
26 files changed, 1914 insertions(+), 67 deletions(-)
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
new file mode 100644
index 0000000..d259104
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4", export_model=True)
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
new file mode 100644
index 0000000..0bb4617
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/export.sh
@@ -0,0 +1,19 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+model_revision="v2.0.4"
+
+python funasr/bin/export.py \
+++model=${model} \
+++model_revision=${model_revision} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu" \
+++debug=false
diff --git a/examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear.yaml b/examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear.yaml
new file mode 100644
index 0000000..4603f80
--- /dev/null
+++ b/examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear.yaml
@@ -0,0 +1,94 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: LLMASR
+model_conf:
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: true
+
+# encoder
+audio_encoder: iic/Whisper-large-v2 #iic/Whisper-large-v3
+audio_encoder_conf:
+ hub: ms
+ freeze: true
+ init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
+
+llm: Vicuna
+llm_conf:
+ hub: hf
+ freeze: true
+ init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
+
+audio_adaptor: Linear
+audio_adaptor_conf:
+ downsample_rate: 5
+ llm_dim: 4096
+ encoder_dim: 512
+
+# frontend related
+frontend: WhisperFrontend
+frontend_conf:
+ fs: 16000
+ whisper_model: large-v2
+ do_pad_trim: true
+ permute: true # true: [bs, frames, dims]; false: [bs, dims, frames]
+
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 6
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 15
+ keep_nbest_models: 10
+ log_interval: 10
+
+optim: adamw
+optim_conf:
+ lr: 0.0001
+ weight_decay: 0.000001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 1000
+
+dataset: AudioLLMQwenAudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: CustomDistributedBatchSampler
+ batch_type: example # example or length
+ batch_size: 4 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+ max_token_length: 3000 # filter samples if source_token_len+target_token_len > max_token_length,
+ shuffle: True
+ num_workers: 4
+ preprocessor_text: TextPreprocessRemovePunctuation
+ audio_adaptor_downsample_rate: ${audio_adaptor_conf.downsample_rate}
+ audio_encoder_downsample_rate: 2
+ prompt: "<|startoftranscription|><|zh|><|transcribe|><|zh|><|notimestamps|><|wo_itn|>"
+
+
+
+tokenizer: HuggingfaceTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+ init_param_path: "/nfs/maziyang.mzy/models/vicuna-7b-v1.5"
+
diff --git a/examples/industrial_data_pretraining/paraformer/export.py b/examples/industrial_data_pretraining/paraformer/export.py
new file mode 100644
index 0000000..613c3a9
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer/export.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
+
+model = AutoModel(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ model_revision="v2.0.4", export_model=True)
+
+res = model.export(input=wav_file, type="onnx", quantize=False)
+print(res)
diff --git a/examples/industrial_data_pretraining/paraformer/export.sh b/examples/industrial_data_pretraining/paraformer/export.sh
new file mode 100644
index 0000000..9f45a5a
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer/export.sh
@@ -0,0 +1,18 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+export HYDRA_FULL_ERROR=1
+
+
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model_revision="v2.0.4"
+
+python funasr/bin/export.py \
+++model=${model} \
+++model_revision=${model_revision} \
+++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+++type="onnx" \
+++quantize=false \
+++device="cpu" \
+++debug=false
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 9ae9f18..d7b6cb9 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -14,6 +14,7 @@
import numpy as np
from tqdm import tqdm
+from funasr.utils.misc import deep_update
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
@@ -23,12 +24,13 @@
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
-from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
+from funasr.utils import export_utils
try:
+ from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.models.campplus.cluster_backend import ClusterBackend
except:
print("If you want to use the speaker diarization, please `pip install hdbscan`")
-import pdb
+
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
"""
@@ -41,7 +43,7 @@
"""
data_list = []
key_list = []
- filelist = [".scp", ".txt", ".json", ".jsonl"]
+ filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
chars = string.ascii_letters + string.digits
if isinstance(data_in, str) and data_in.startswith('http'): # url
@@ -98,7 +100,9 @@
def __init__(self, **kwargs):
if not kwargs.get("disable_log", True):
tables.print()
-
+ if kwargs.get("export_model", False):
+ os.environ['EXPORTING_MODEL'] = 'TRUE'
+
model, kwargs = self.build_model(**kwargs)
# if vad_model is not None, build vad model else None
@@ -203,7 +207,7 @@
def __call__(self, *args, **cfg):
kwargs = self.kwargs
- kwargs.update(cfg)
+ deep_update(kwargs, cfg)
res = self.model(*args, kwargs)
return res
@@ -216,7 +220,7 @@
def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
kwargs = self.kwargs if kwargs is None else kwargs
- kwargs.update(cfg)
+ deep_update(kwargs, cfg)
model = self.model if model is None else model
model.eval()
@@ -279,7 +283,7 @@
def inference_with_vad(self, input, input_len=None, **cfg):
kwargs = self.kwargs
# step.1: compute the vad model
- self.vad_kwargs.update(cfg)
+ deep_update(self.vad_kwargs, cfg)
beg_vad = time.time()
res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg)
end_vad = time.time()
@@ -287,7 +291,7 @@
# step.2 compute asr model
model = self.model
- kwargs.update(cfg)
+ deep_update(kwargs, cfg)
batch_size = int(kwargs.get("batch_size_s", 300))*1000
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
kwargs["batch_size"] = batch_size
@@ -399,7 +403,7 @@
if return_raw_text:
result['raw_text'] = ''
else:
- self.punc_kwargs.update(cfg)
+ deep_update(self.punc_kwargs, cfg)
punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
raw_text = copy.copy(result["text"])
if return_raw_text: result['raw_text'] = raw_text
@@ -467,3 +471,35 @@
# f"time_escape_all: {time_escape_total_all_samples:0.3f}")
return results_ret_list
+ def export(self, input=None,
+ type : str = "onnx",
+ quantize: bool = False,
+ fallback_num: int = 5,
+ calib_num: int = 100,
+ opset_version: int = 14,
+ **cfg):
+ os.environ['EXPORTING_MODEL'] = 'TRUE'
+ kwargs = self.kwargs
+ deep_update(kwargs, cfg)
+ del kwargs["model"]
+ model = self.model
+ model.eval()
+
+ batch_size = 1
+
+ key_list, data_list = prepare_data_iterator(input, input_len=None, data_type=kwargs.get("data_type", None), key=None)
+
+ with torch.no_grad():
+
+ if type == "onnx":
+ export_dir = export_utils.export_onnx(
+ model=model,
+ data_in=data_list,
+ **kwargs)
+ else:
+ export_dir = export_utils.export_torchscripts(
+ model=model,
+ data_in=data_list,
+ **kwargs)
+
+ return export_dir
\ No newline at end of file
diff --git a/funasr/bin/export.py b/funasr/bin/export.py
new file mode 100644
index 0000000..68acc17
--- /dev/null
+++ b/funasr/bin/export.py
@@ -0,0 +1,40 @@
+import os
+import hydra
+import logging
+from omegaconf import DictConfig, OmegaConf, ListConfig
+
+from funasr.auto.auto_model import AutoModel
+
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(cfg: DictConfig):
+ def to_plain_list(cfg_item):
+ if isinstance(cfg_item, ListConfig):
+ return OmegaConf.to_container(cfg_item, resolve=True)
+ elif isinstance(cfg_item, DictConfig):
+ return {k: to_plain_list(v) for k, v in cfg_item.items()}
+ else:
+ return cfg_item
+
+ kwargs = to_plain_list(cfg)
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+
+ logging.basicConfig(level=log_level)
+
+ if kwargs.get("debug", False):
+ import pdb; pdb.set_trace()
+
+
+ model = AutoModel(export_model=True, **kwargs)
+ res = model.export(input=kwargs.get("input", None),
+ type=kwargs.get("type", "onnx"),
+ quantize=kwargs.get("quantize", False),
+                      fallback_num=kwargs.get("fallback_num", 5),
+ calib_num=kwargs.get("calib_num", 100),
+ opset_version=kwargs.get("opset_version", 14),
+ )
+ print(res)
+
+
+if __name__ == '__main__':
+ main_hydra()
\ No newline at end of file
diff --git a/funasr/datasets/llm_datasets/preprocessor.py b/funasr/datasets/llm_datasets/preprocessor.py
index 9f20672..4cbef33 100644
--- a/funasr/datasets/llm_datasets/preprocessor.py
+++ b/funasr/datasets/llm_datasets/preprocessor.py
@@ -18,7 +18,7 @@
@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
-class TextPreprocessSegDict(nn.Module):
+class TextPreprocessRemovePunctuation(nn.Module):
def __init__(self,
**kwargs):
super().__init__()
diff --git a/funasr/datasets/llm_datasets_qwenaudio/__init__.py b/funasr/datasets/llm_datasets_qwenaudio/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/llm_datasets_qwenaudio/__init__.py
diff --git a/funasr/datasets/llm_datasets_qwenaudio/datasets.py b/funasr/datasets/llm_datasets_qwenaudio/datasets.py
new file mode 100644
index 0000000..674217c
--- /dev/null
+++ b/funasr/datasets/llm_datasets_qwenaudio/datasets.py
@@ -0,0 +1,150 @@
+import torch
+import copy
+
+from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
+
+
+@tables.register("dataset_classes", "AudioLLMQwenAudioDataset")
+class AudioLLMQwenAudioDataset(torch.utils.data.Dataset):
+ """
+ AudioLLMDataset
+ """
+
+ def __init__(self,
+ path,
+ index_ds: str = None,
+ frontend=None,
+ tokenizer=None,
+ int_pad_value: int = -1,
+ float_pad_value: float = 0.0,
+ **kwargs):
+ super().__init__()
+ index_ds_class = tables.index_ds_classes.get(index_ds)
+ self.index_ds = index_ds_class(path, **kwargs)
+ preprocessor_speech = kwargs.get("preprocessor_speech", None)
+ if preprocessor_speech:
+ preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+ preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
+ self.preprocessor_speech = preprocessor_speech
+ preprocessor_text = kwargs.get("preprocessor_text", None)
+ if preprocessor_text:
+ preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+ preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
+ self.preprocessor_text = preprocessor_text
+
+ self.frontend = frontend
+ self.fs = 16000 if frontend is None else frontend.fs
+ self.data_type = "sound"
+ self.tokenizer = tokenizer
+
+ self.float_pad_value = float_pad_value
+ self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
+        # self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(self.prompt)  # "USER: \nINSTRUCTION: {}\nINPUT: {}\nASSISTANT: "
+ self.prompt_af = ""
+ self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
+ self.int_pad_value = self.IGNORE_INDEX
+ self.audio_adaptor_downsample_rate = kwargs.get("audio_adaptor_downsample_rate", 5)
+ self.audio_encoder_downsample_rate = kwargs.get("audio_encoder_downsample_rate", 2)
+ self.prompt_template = "{}"
+ self.answer_template = "{}"
+
+
+ def get_source_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_source_len(item)
+
+ def get_target_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_target_len(item)
+
+ def __len__(self):
+ return len(self.index_ds)
+
+ def __getitem__(self, index):
+ item = self.index_ds[index]
+ # import pdb;
+ # pdb.set_trace()
+ source = item["source"]
+ data_src = load_audio_text_image_video(source, fs=self.fs)
+ if self.preprocessor_speech:
+ data_src = self.preprocessor_speech(data_src, fs=self.fs)
+ speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend,
+ is_final=True) # speech: [b, T, d]
+ speech = speech.squeeze(0)
+
+ audio_pseudo_length = (speech.shape[0] + 1) // self.audio_adaptor_downsample_rate // self.audio_encoder_downsample_rate
+ audio_pseudo = torch.full((audio_pseudo_length,), -1) # placeholder
+
+ target = item["target"]
+ if self.preprocessor_text:
+ target = self.preprocessor_text(target)
+
+ self.prompt_pre = self.prompt_template.format(self.prompt)
+ prompt_ids_pre = self.tokenizer.encode(self.prompt_pre) # [bos,prompt]
+ prompt_pre_length = len(prompt_ids_pre)
+
+ # input
+ input = self.answer_template.format(target.lower())
+ prompt_input = "{}{}".format(self.prompt_pre, input)
+ prompt_input_ids = self.tokenizer.encode(prompt_input) # [bos, prompt, input]
+ # audio_length = len(prompt_input_ids) - prompt_pre_length
+ input_ids = prompt_input_ids + [self.tokenizer.pad_token_id] # [bos, prompt, input, pad]
+ input_ids_length = len(input_ids)
+ input_ids = torch.tensor(input_ids, dtype=torch.int64) # [bos, prompt, input, pad]
+ input_ids = torch.cat((audio_pseudo, input_ids)) # [audio, bos, prompt, input, pad]
+ # input_ids[:audio_pseudo_length] = -1 # [-1, bos, prompt, input, pad]
+ attention_mask = input_ids.ge(-1) # [true, true, true, true, true], length mask
+ # input_ids[prompt_pre_length:] = -1 # [bos, prompt,-1,-1]
+ # attention_mask = input_ids.ge(-1) # [true, true, true, true], length mask
+
+ # label
+ answer = self.answer_template.format(target.lower())
+ prompt_answer = "{}{}".format(self.prompt_pre, answer)
+ prompt_answer_ids = self.tokenizer.encode(prompt_answer)
+ # answer_length = len(prompt_answer_ids) - prompt_pre_length
+ labels_ids = copy.deepcopy(prompt_answer_ids) + [self.tokenizer.eos_token_id]
+ labels_ids = torch.tensor(labels_ids, dtype=torch.int64) # [bos, prompt, answer, eos]
+ labels_ids = torch.cat((audio_pseudo, labels_ids)) # [audio, bos, prompt, answer, eos]
+ labels_ids[:audio_pseudo_length+prompt_pre_length] = -1 # [-1, -1, -1, answer, eos]
+ # labels_ids[:prompt_pre_length] = -1 # [-1, -1, input, eos]
+ label_mask = labels_ids.ge(0) # [false, false, false, true, true]
+ labels_ids[~label_mask] = self.IGNORE_INDEX # [-100, -100, -100, answer, eos]
+
+ # audio_mask for input_ids
+ audio_mask = [1]*audio_pseudo_length + [0]*input_ids_length
+ audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
+
+ ids = self.tokenizer.encode(target) # token ids is different from labels_ids
+ text = torch.tensor(ids, dtype=torch.int64)
+ text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
+
+ return {"speech": speech,
+ "speech_lengths": speech_lengths,
+ "text": text,
+ "text_lengths": text_lengths,
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels_ids": labels_ids,
+ "label_mask": label_mask,
+ "audio_mask": audio_mask,
+ }
+
+ def collator(self, samples: list = None):
+ outputs = {}
+ for sample in samples:
+ for key in sample.keys():
+ if key not in outputs:
+ outputs[key] = []
+ outputs[key].append(sample[key])
+
+ for key, data_list in outputs.items():
+ if isinstance(data_list[0], torch.Tensor):
+ if data_list[0].dtype == torch.int64:
+
+ pad_value = self.int_pad_value
+ else:
+ pad_value = self.float_pad_value
+
+ outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+ return outputs
diff --git a/funasr/datasets/llm_datasets_vicuna/preprocessor.py b/funasr/datasets/llm_datasets_vicuna/preprocessor.py
deleted file mode 100644
index 9f20672..0000000
--- a/funasr/datasets/llm_datasets_vicuna/preprocessor.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import os
-import json
-import torch
-import logging
-import concurrent.futures
-import librosa
-import torch.distributed as dist
-from typing import Collection
-import torch
-import torchaudio
-from torch import nn
-import random
-import re
-import string
-from funasr.tokenizer.cleaner import TextCleaner
-from funasr.register import tables
-
-
-
-@tables.register("preprocessor_classes", "TextPreprocessRemovePunctuation")
-class TextPreprocessSegDict(nn.Module):
- def __init__(self,
- **kwargs):
- super().__init__()
-
-
- def forward(self, text, **kwargs):
- # 瀹氫箟鑻辨枃鏍囩偣绗﹀彿
- en_punct = string.punctuation
- # 瀹氫箟涓枃鏍囩偣绗﹀彿锛堥儴鍒嗗父鐢ㄧ殑锛�
- cn_punct = '銆傦紵锛侊紝銆侊紱锛氣�溾�濃�樷�欙紙锛夈�娿�嬨�愩�戔�︹�旓綖路'
- # 鍚堝苟鑻辨枃鍜屼腑鏂囨爣鐐圭鍙�
- all_punct = en_punct + cn_punct
- # 鍒涘缓姝e垯琛ㄨ揪寮忔ā寮忥紝鍖归厤浠讳綍鍦╝ll_punct涓殑瀛楃
- punct_pattern = re.compile('[{}]'.format(re.escape(all_punct)))
- # 浣跨敤姝e垯琛ㄨ揪寮忕殑sub鏂规硶鏇挎崲鎺夎繖浜涘瓧绗�
- return punct_pattern.sub('', text)
diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py
index ae91852..e7c0e8b 100755
--- a/funasr/models/fsmn_vad_streaming/encoder.py
+++ b/funasr/models/fsmn_vad_streaming/encoder.py
@@ -1,5 +1,6 @@
from typing import Tuple, Dict
import copy
+import os
import numpy as np
import torch
@@ -134,6 +135,25 @@
x3 = self.affine(x2)
x4 = self.relu(x3)
return x4
+class BasicBlock_export(nn.Module):
+ def __init__(self,
+ model,
+ ):
+ super(BasicBlock_export, self).__init__()
+ self.linear = model.linear
+ self.fsmn_block = model.fsmn_block
+ self.affine = model.affine
+ self.relu = model.relu
+
+ def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
+ x = self.linear(input) # B T D
+ # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+ # if cache_layer_name not in in_cache:
+ # in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+ x, out_cache = self.fsmn_block(x, in_cache)
+ x = self.affine(x)
+ x = self.relu(x)
+ return x, out_cache
class FsmnStack(nn.Sequential):
@@ -192,6 +212,12 @@
self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
self.softmax = nn.Softmax(dim=-1)
+
+ # export onnx or torchscripts
+ if "EXPORTING_MODEL" in os.environ and os.environ['EXPORTING_MODEL'] == 'TRUE':
+ for i, d in enumerate(self.fsmn):
+ if isinstance(d, BasicBlock):
+ self.fsmn[i] = BasicBlock_export(d)
def fuse_modules(self):
pass
@@ -218,6 +244,32 @@
return x7
+ def export_forward(
+ self,
+ input: torch.Tensor,
+ *args,
+ ):
+ """
+ Args:
+ input (torch.Tensor): Input tensor (B, T, D)
+ in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
+ {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
+ """
+
+ x = self.in_linear1(input)
+ x = self.in_linear2(x)
+ x = self.relu(x)
+ # x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
+ out_caches = list()
+ for i, d in enumerate(self.fsmn):
+ in_cache = args[i]
+ x, out_cache = d(x, in_cache)
+ out_caches.append(out_cache)
+ x = self.out_linear1(x)
+ x = self.out_linear2(x)
+ x = self.softmax(x)
+
+ return x, out_caches
'''
one deep fsmn layer
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index e29af57..9649ee8 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -284,6 +284,7 @@
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(**encoder_conf)
self.encoder = encoder
+ self.encoder_conf = encoder_conf
def ResetDetection(self, cache: dict = {}):
cache["stats"].continous_silence_frame_count = 0
@@ -642,6 +643,45 @@
return results, meta_data
+ def export(self, **kwargs):
+ self.forward = self._export_forward
+
+ return self
+
+ def _export_forward(self, feats: torch.Tensor, *args, **kwargs):
+ scores, out_caches = self.encoder.export_forward(feats, *args)
+ return scores, out_caches
+
+ def export_dummy_inputs(self, data_in=None, frame=30):
+ if data_in is None:
+ speech = torch.randn(1, frame, self.encoder_conf.get("input_dim"))
+ else:
+ speech = None # Undo
+
+ cache_frames = self.encoder_conf.get("lorder") + self.encoder_conf.get("rorder") - 1
+ in_cache0 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+ in_cache1 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+ in_cache2 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+ in_cache3 = torch.randn(1, self.encoder_conf.get("proj_dim"), cache_frames, 1)
+
+ return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
+
+ def export_input_names(self):
+ return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
+
+ def export_output_names(self):
+ return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
+
+ def export_dynamic_axes(self):
+ return {
+ 'speech': {
+ 1: 'feats_length'
+ },
+ }
+
+ def export_name(self, ):
+ return "model.onnx"
+
def DetectCommonFrames(self, cache: dict = {}) -> int:
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index 0f4973d..d83f571 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -218,6 +218,9 @@
audio_mask = kwargs.get("audio_mask", None)
audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
+ text_token_int = kwargs.get("text_token_int", None)
+ if audio_token_lengths is None:
+ audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_lens = self.audio_encoder.encode(**batch)
@@ -259,7 +262,13 @@
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer)
+ tokenizer=None)
+ if len(kwargs.get("data_type")) > 1:
+ audio_sample_list, text_token_int_list = audio_sample_list
+ text_token_int = text_token_int_list[0].replace(" ", "")
+ text_token_int = tokenizer.encode(text_token_int)
+ else:
+ text_token_int = None
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
@@ -272,7 +281,7 @@
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text_token_int=text_token_int)
# adaptor
encoder_out = self.adaptor(encoder_out)
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 4d9f5d8..ddcfb5a 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -155,7 +155,7 @@
tf2torch_tensor_name_prefix_tf="seq2seq/cif",
tail_mask=True,
):
- super(CifPredictorV2, self).__init__()
+ super().__init__()
self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1)
@@ -375,6 +375,120 @@
predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
return predictor_alignments.detach(), predictor_alignments_length.detach()
+@tables.register("predictor_classes", "CifPredictorV2Export")
+class CifPredictorV2(torch.nn.Module):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+
+ self.pad = model.pad
+ self.cif_conv1d = model.cif_conv1d
+ self.cif_output = model.cif_output
+ self.threshold = model.threshold
+ self.smooth_factor = model.smooth_factor
+ self.noise_threshold = model.noise_threshold
+ self.tail_threshold = model.tail_threshold
+
+ def forward(self, hidden: torch.Tensor,
+ mask: torch.Tensor,
+ ):
+ alphas, token_num = self.forward_cnn(hidden, mask)
+ mask = mask.transpose(-1, -2).float()
+ mask = mask.squeeze(-1)
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
+ acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold)
+
+ return acoustic_embeds, token_num, alphas, cif_peak
+
+ def forward_cnn(self, hidden: torch.Tensor,
+ mask: torch.Tensor,
+ ):
+ h = hidden
+ context = h.transpose(1, 2)
+ queries = self.pad(context)
+ output = torch.relu(self.cif_conv1d(queries))
+ output = output.transpose(1, 2)
+
+ output = self.cif_output(output)
+ alphas = torch.sigmoid(output)
+ alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+ mask = mask.transpose(-1, -2).float()
+ alphas = alphas * mask
+ alphas = alphas.squeeze(-1)
+ token_num = alphas.sum(-1)
+
+ return alphas, token_num
+
+ def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
+ b, t, d = hidden.size()
+ tail_threshold = self.tail_threshold
+
+ zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
+ ones_t = torch.ones_like(zeros_t)
+
+ mask_1 = torch.cat([mask, zeros_t], dim=1)
+ mask_2 = torch.cat([ones_t, mask], dim=1)
+ mask = mask_2 - mask_1
+ tail_threshold = mask * tail_threshold
+ alphas = torch.cat([alphas, zeros_t], dim=1)
+ alphas = torch.add(alphas, tail_threshold)
+
+ zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
+ hidden = torch.cat([hidden, zeros], dim=1)
+ token_num = alphas.sum(dim=-1)
+ token_num_floor = torch.floor(token_num)
+
+ return hidden, alphas, token_num_floor
+
+@torch.jit.script
+def cif_export(hidden, alphas, threshold: float):
+ batch_size, len_time, hidden_size = hidden.size()
+ threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
+
+ # loop varss
+ integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
+ frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
+ # intermediate vars along time
+ list_fires = []
+ list_frames = []
+
+ for t in range(len_time):
+ alpha = alphas[:, t]
+ distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
+
+ integrate += alpha
+ list_fires.append(integrate)
+
+ fire_place = integrate >= threshold
+ integrate = torch.where(fire_place,
+ integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
+ integrate)
+ cur = torch.where(fire_place,
+ distribution_completion,
+ alpha)
+ remainds = alpha - cur
+
+ frame += cur[:, None] * hidden[:, t, :]
+ list_frames.append(frame)
+ frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
+ remainds[:, None] * hidden[:, t, :],
+ frame)
+
+ fires = torch.stack(list_fires, 1)
+ frames = torch.stack(list_frames, 1)
+
+ fire_idxs = fires >= threshold
+ frame_fires = torch.zeros_like(hidden)
+ max_label_len = frames[0, fire_idxs[0]].size(0)
+ for b in range(batch_size):
+ frame_fire = frames[b, fire_idxs[b]]
+ frame_len = frame_fire.size(0)
+ frame_fires[b, :frame_len, :] = frame_fire
+
+ if frame_len >= max_label_len:
+ max_label_len = frame_len
+ frame_fires = frame_fires[:, :max_label_len, :]
+ return frame_fires, fires
+
class mae_loss(torch.nn.Module):
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index ad321e4..ce018f4 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -581,6 +581,177 @@
return y, new_cache
class DecoderLayerSANMExport(torch.nn.Module):
    """Export-friendly wrapper around one trained SANM decoder layer.

    Copies the trained submodules from ``model`` and implements a plain
    pre-norm forward pass without training-only branches.
    """

    def __init__(
        self,
        model
    ):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        # norm2/norm3 are optional on some layer variants
        self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
        self.size = model.size


    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Feed-forward first, then optional self-attention and cross-attention.

        Returns the same 5-tuple it consumes so layers can be chained through
        sequential containers; ``cache`` carries the streaming FSMN state.
        """
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
        x = residual + x

        if self.src_attn is not None:
            residual = x
            x = self.norm3(x)
            x = residual + self.src_attn(x, memory, memory_mask)


        return x, tgt_mask, memory, memory_mask, cache
+
+
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
class ParaformerSANMDecoderExport(torch.nn.Module):
    """ONNX-export wrapper for a trained Paraformer SANM decoder.

    Replaces the in-graph attention modules with their *Export counterparts
    (done in place on ``model``) and exposes a trace-friendly forward().
    """

    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True, ):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask

        self.model = model
        # MakePadMask is used for ONNX tracing; sequence_mask otherwise
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        from funasr.models.sanm.multihead_att import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.multihead_att import MultiHeadedAttentionCrossAttExport

        # swap every decoder layer's attention modules for export versions,
        # then wrap the layer itself
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANMExport(d)

        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANMExport(d)

        for i, d in enumerate(self.model.decoders3):
            self.model.decoders3[i] = DecoderLayerSANMExport(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        """Derive the (B,T,1) multiplicative and 4-D additive masks.

        NOTE(review): if ``mask`` has a rank other than 2 or 3,
        ``mask_4d_bhlt`` is never assigned and this raises UnboundLocalError —
        confirm callers only pass 2-D/3-D masks.
        """
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        # large negative bias on padded positions (added to attention scores)
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Non-autoregressive decode: encoder memory + predicted embeddings -> logits."""
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        # NOTE(review): this dummy signature (tgt/memory/embeds/cache) does not
        # match forward() above — presumably kept for a streaming export path;
        # confirm against the exporter that consumes it.
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
            + ['cache_%d' % i for i in range(cache_num)]

    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
            + ['out_cache_%d' % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        """torch.onnx.export dynamic_axes spec: batch and time axes are variable."""
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret
+
+
@tables.register("decoder_classes", "ParaformerSANDecoder")
class ParaformerSANDecoder(BaseTransformerDecoder):
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 51c9bdb..8739ed5 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -549,3 +549,80 @@
return results, meta_data
    def export(
        self,
        max_seq_len=512,
        **kwargs,
    ):
        """Rewrite this model in place into its export-friendly form and return self.

        kwargs must carry the registry names under "encoder", "predictor" and
        "decoder"; the matching "<Name>Export" class is looked up in the
        tables registry and wraps the trained submodule.
        """
        onnx = kwargs.get("onnx", True)
        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
        self.encoder = encoder_class(self.encoder, onnx=onnx)

        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
        self.predictor = predictor_class(self.predictor, onnx=onnx)


        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
        self.decoder = decoder_class(self.decoder, onnx=onnx)

        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        # route calls through the trace-friendly graph below
        self.forward = self._export_forward

        return self
+
    def _export_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Trace-friendly forward: encoder -> CIF predictor -> NAR decoder.

        Returns (log-softmax logits, predicted token count per utterance).
        """
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)

        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        # CIF token count is fractional; floor to an integer token count
        pre_token_length = pre_token_length.floor().type(torch.int32)

        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)

        return decoder_out, pre_token_length
+
+ def export_dummy_inputs(self):
+ speech = torch.randn(2, 30, 560)
+ speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+ return (speech, speech_lengths)
+
+
+ def export_input_names(self):
+ return ['speech', 'speech_lengths']
+
+ def export_output_names(self):
+ return ['logits', 'token_num']
+
+ def export_dynamic_axes(self):
+ return {
+ 'speech': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'speech_lengths': {
+ 0: 'batch_size',
+ },
+ 'logits': {
+ 0: 'batch_size',
+ 1: 'logits_length'
+ },
+ }
+
+ def export_name(self, ):
+ return "model.onnx"
diff --git a/funasr/models/qwen_audio/audio.py b/funasr/models/qwen_audio/audio.py
new file mode 100644
index 0000000..1a37897
--- /dev/null
+++ b/funasr/models/qwen_audio/audio.py
@@ -0,0 +1,446 @@
+import base64
+import gzip
+from dataclasses import dataclass
+from typing import Dict, Iterable, Optional, List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from subprocess import CalledProcessError, run, Popen, PIPE
+
+import os
+from functools import lru_cache
+from typing import Optional, Union
+from funasr.register import tables
+
+
+
def exact_div(x, y):
    """Integer division of x by y, asserting that y divides x exactly."""
    quotient, remainder = divmod(x, y)
    assert remainder == 0
    return quotient
+
+# hard-coded audio hyperparameters
+SAMPLE_RATE = 16000
+N_FFT = 400
+N_MELS = 80
+HOP_LENGTH = 160
+CHUNK_LENGTH = 30
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
+
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
+
+
+
def get_T_after_cnn(L_in, dilation=1):
    """Output length after the encoder's two Conv1d layers.

    Mirrors conv1 (kernel 3, stride 1, padding 1) followed by conv2
    (kernel 3, stride 2, padding 1) using the standard Conv1d length formula.

    Args:
        L_in: input sequence length (frames).
        dilation: convolution dilation (both layers use the same value).

    Returns:
        The sequence length after both convolutions (roughly (L_in + 1) // 2
        for the default hyper-parameters).
    """
    # (padding, kernel_size, stride) for conv1 and conv2.  The original code
    # built this constant list via eval() on a string literal — slow, unsafe
    # and needless; a plain tuple literal is equivalent.
    L_out = L_in
    for padding, kernel_size, stride in ((1, 3, 1), (1, 3, 2)):
        L_out = L_out + 2 * padding - dilation * (kernel_size - 1) - 1
        L_out = 1 + L_out // stride
    return L_out
+
def load_bytesio_audio(content, sr: int = SAMPLE_RATE):
    """Decode in-memory audio bytes to mono float32 PCM at ``sr`` via ffmpeg.

    ``content`` is piped to ffmpeg's stdin; the s16le stdout stream is scaled
    to [-1, 1].  Requires the ffmpeg CLI in PATH.  Decoding errors are not
    surfaced: a failed decode yields an empty array.
    """
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", "pipe:",
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "pipe:"
    ]
    # communicate() feeds all bytes and drains stdout/stderr, avoiding pipe deadlock
    p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, bufsize=-1)
    out, _ = p.communicate(input=content)
    # int16 -> float32 in [-1, 1]
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status (check=True below).
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    # int16 PCM -> float32 in [-1, 1]
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
+
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """Zero-pad or trim ``array`` along ``axis`` to exactly ``length`` samples.

    Works on both torch tensors and NumPy arrays; the input type is preserved.
    """
    current = array.shape[axis]
    if torch.is_tensor(array):
        if current > length:
            keep = torch.arange(length, device=array.device)
            array = array.index_select(dim=axis, index=keep)
        elif current < length:
            widths = [(0, 0)] * array.ndim
            widths[axis] = (0, length - current)
            # F.pad wants a flat (last-dim-first) list of pad amounts
            flat = [amount for pair in widths[::-1] for amount in pair]
            array = F.pad(array, flat)
        return array

    if current > length:
        array = array.take(indices=range(length), axis=axis)
    elif current < length:
        widths = [(0, 0)] * array.ndim
        widths[axis] = (0, length - current)
        array = np.pad(array, widths)
    return array
+
def trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """Trim ``array`` along ``axis`` to at most ``length`` samples (no padding)."""
    if array.shape[axis] <= length:
        return array
    if torch.is_tensor(array):
        keep = torch.arange(length, device=array.device)
        return array.index_select(dim=axis, index=keep)
    return array.take(indices=range(length), axis=axis)
+
+
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )

    Cached per (device, n_mels); the npz file must sit next to this module.
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
    with np.load(
        os.path.join(os.path.dirname(__file__), "mel_filters.npz")  # todo
        # os.path.join("assets", "mel_filters.npz")
    ) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+
+
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = N_MELS,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    # drop the last frame so the frame count matches the model's expectation
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    # floor at 1e-10 before log10 to avoid -inf on silent bins
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # limit dynamic range to 8 log10-units (80 dB) below the peak
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    # affine rescale used by Whisper-style models
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
+
+
@dataclass
class ModelDimensions:
    """Hyper-parameters of a Whisper-style audio/text transformer checkpoint."""
    n_mels: int          # mel bins of the input spectrogram
    n_audio_ctx: int     # audio positions (frames after conv front-end)
    n_audio_state: int   # audio encoder hidden size
    n_audio_head: int    # audio encoder attention heads
    n_audio_layer: int   # audio encoder transformer layers
    n_vocab: int         # text vocabulary size
    n_text_ctx: int      # max text positions
    n_text_state: int    # text decoder hidden size
    n_text_head: int     # text decoder attention heads
    n_text_layer: int    # text decoder transformer layers
+
+
class LayerNorm(nn.LayerNorm):
    # LayerNorm that keeps the input dtype (e.g. fp16); the commented-out
    # variant below is the upstream version that upcast to float32 first.
    def forward(self, x: Tensor) -> Tensor:
        # return super().forward(x.float()).type(x.dtype)
        return super().forward(x).type(x.dtype)
+
+
+
+
class Linear(nn.Linear):
    """nn.Linear that casts its parameters to the input dtype (fp16-friendly)."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return F.linear(x, weight, bias)
+
+
class Conv1d(nn.Conv1d):
    # Conv1d that casts weight/bias to the input dtype so fp16 inputs work
    # without manually converting the module's parameters.
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )
+
+
def sinusoids(length, channels, max_timescale=10000):
    """Sinusoidal positional-embedding table of shape (length, channels).

    First half of the channels holds sines, second half cosines, over a
    geometric progression of timescales (as in "Attention Is All You Need").
    """
    assert channels % 2 == 0
    half = channels // 2
    increment = np.log(max_timescale) / (half - 1)
    timescales = torch.exp(torch.arange(half) * -increment)
    angles = torch.arange(length).unsqueeze(1) * timescales.unsqueeze(0)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
+
+
class MultiHeadAttention(nn.Module):
    """Whisper-style multi-head attention (self- or cross-) with optional kv cache."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        # key projection has no bias, matching the upstream Whisper layout
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Attend x to itself (xa is None) or to the cross-attention memory xa.

        Returns (attended output, raw pre-softmax attention logits).
        """
        q = self.query(x)

        # kv_cache is keyed by the projection *modules* themselves
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        """Scaled dot-product attention over (batch, head, time, d_k) layouts."""
        n_batch, n_ctx, n_state = q.shape
        # split the 1/sqrt(d_k) scale between q and k for numerical symmetry
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk += mask

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attn (+ optional cross-attn) + MLP."""

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        # cross-attention modules only exist when requested (decoder-side use)
        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        # standard 4x MLP expansion
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x
+
@tables.register("encoder_classes", "QwenAudioEncoder")
class QwenAudioEncoder(nn.Module):
    """Whisper-style audio encoder used by Qwen-Audio.

    Pipeline: two convolutions (conv2 has stride 2) -> transformer blocks ->
    optional 2x average pooling -> linear projection to ``output_dim``.
    Optionally owns learned audio BOS/EOS embeddings appended by encode().
    """
    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        output_dim: int = 512,
        avg_pool: bool = True,
        add_audio_bos_eos_token: bool = True,
        **kwargs
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        # fixed (non-trained) sinusoidal positions, stored as a buffer
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

        if avg_pool:
            self.avg_pooler = nn.AvgPool1d(2, stride=2)
        else:
            self.avg_pooler = None
        self.proj = nn.Linear(n_state, output_dim)
        if add_audio_bos_eos_token:
            # row 0 = BOS embedding, row 1 = EOS embedding (see forward())
            self.audio_bos_eos_token = nn.Embedding(2, output_dim)
        else:
            self.audio_bos_eos_token = None
        self.output_dim = output_dim
        self.n_head = n_head

    def forward(self, x: Tensor, padding_mask: Tensor=None, audio_lengths: Tensor=None):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio

        padding_mask: 1 on padded frames (post-conv time axis), 0 on valid ones.
        audio_lengths: per-sample lengths; column 0 appears to be the
            post-conv length (it is doubled to get the mel frame count) —
            TODO confirm against the data pipeline.

        Returns (features, bos_embedding, eos_embedding); bos/eos are None
        when add_audio_bos_eos_token is False.
        """
        x = x.to(dtype=self.conv1.weight.dtype,
                 device=self.conv1.weight.device)
        if audio_lengths is not None:
            # trim the mel axis to the longest sample in the batch
            input_mel_len = audio_lengths[:,0] * 2
            max_mel_len_in_batch = input_mel_len.max()
            x = x[:, :, :max_mel_len_in_batch]
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)  # B, L, D
        bsz = x.size(0)
        src_len = x.size(1)


        # slice the position table to the actual sequence length
        self.input_positional_embedding = self.positional_embedding[:src_len]
        assert x.shape[1:] == self.input_positional_embedding.shape, f"incorrect audio shape: {x.shape[1:], self.input_positional_embedding.shape}"
        x = (x + self.input_positional_embedding).to(x.dtype)
        if padding_mask is not None:
            padding_mask = padding_mask.to(dtype=self.conv1.weight.dtype,
                                           device=self.conv1.weight.device)
            batch_src_len = padding_mask.size(1)
            x = x[:, :batch_src_len, :]
            padding_mask = padding_mask.view(
                bsz, -1, batch_src_len
            )
            padding_mask_ = padding_mask.all(1)
            # zero out fully-padded frames in place
            x[padding_mask_] = 0
            # expand to (B, heads, 1, T) and turn into an additive -inf mask
            key_padding_mask = padding_mask_.view(bsz, 1, 1, batch_src_len). \
                expand(-1, self.n_head, -1, -1).reshape(bsz, self.n_head, 1, batch_src_len)
            new_padding_mask = torch.zeros_like(key_padding_mask, dtype=x.dtype)
            padding_mask = new_padding_mask.masked_fill(key_padding_mask, float("-inf"))

        for block in self.blocks:
            x = block(x, mask=padding_mask)


        # halve the time axis (AvgPool1d operates on the channel-last layout)
        if self.avg_pooler:
            x = x.permute(0, 2, 1)
            x = self.avg_pooler(x)
            x = x.permute(0, 2, 1)


        x = self.ln_post(x)
        x = self.proj(x)

        if self.audio_bos_eos_token is not None:
            bos = self.audio_bos_eos_token.weight[0][None, :]
            eos = self.audio_bos_eos_token.weight[1][None, :]
        else:
            bos, eos = None, None
        return x, bos, eos

    def encode(self, input_audios: Tensor, input_audio_lengths: Tensor, audio_span_tokens: List):
        """Encode a padded batch and return per-sample [bos, features, eos] lists.

        audio_span_tokens[i] is the full span length including bos/eos, hence
        the ``audio_span - 2`` feature slice below.
        """
        real_input_audio_lens = input_audio_lengths[:, 0].tolist()
        max_len_in_batch = max(real_input_audio_lens)
        # mask = 1 on padding, 0 on valid frames
        padding_mask = torch.ones([input_audios.size(0), max_len_in_batch]).to(dtype=self.conv1.weight.dtype,
                                                                               device=self.conv1.weight.device)
        for index in range(len(input_audios)):
            padding_mask[index, :input_audio_lengths[index][0].item()] = 0
        x, bos, eos = self(input_audios, padding_mask,input_audio_lengths)
        output_audios = []
        for i in range(len(audio_span_tokens)):
            audio_span = audio_span_tokens[i]
            audio = x[i][:audio_span-2]
            if bos is not None:
                audio = torch.concat([bos, audio, eos])
            assert len(audio) == audio_span
            output_audios.append(audio)
        return output_audios
diff --git a/funasr/models/qwen_audio/model.py b/funasr/models/qwen_audio/model.py
index 86b4105..045cfe1 100644
--- a/funasr/models/qwen_audio/model.py
+++ b/funasr/models/qwen_audio/model.py
@@ -60,7 +60,7 @@
inputs = self.tokenizer(query, return_tensors='pt', audio_info=audio_info)
inputs = inputs.to(self.model.device)
pred = self.model.generate(**inputs, audio_info=audio_info)
- response = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info)
+ response = self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=False, audio_info=audio_info)
results = []
result_i = {"key": key[0], "text": response}
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index 10f0a3b..09a1f07 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -303,6 +303,64 @@
att_outs = self.forward_attention(v_h, scores, None)
return att_outs + fsmn_memory, cache
class MultiHeadedAttentionSANMExport(nn.Module):
    """Export-friendly SANM self-attention: multi-head attention + FSMN memory.

    Takes the trained parameters from ``model`` and expects ``mask`` as the
    (mask_3d_btd, mask_4d_bhlt) pair produced by the encoder's prepare_mask().
    """
    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_out = model.linear_out
        self.linear_q_k_v = model.linear_q_k_v
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn

        # last attention weights, kept for inspection/debugging
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, mask):
        mask_3d_btd, mask_4d_bhlt = mask
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
        # fold the 1/sqrt(d_k) scale into q before the matmul
        q_h = q_h * self.d_k**(-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
        return att_outs + fsmn_memory

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H*D) -> (B, H, T, D)
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward_qkv(self, x):
        # single fused projection split into q, k, v
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = self.transpose_for_scores(q)
        k_h = self.transpose_for_scores(k)
        v_h = self.transpose_for_scores(v)
        return q_h, k_h, v_h, v

    def forward_fsmn(self, inputs, mask):
        """Depthwise FSMN memory branch with a residual connection; masked in and out."""
        # b, t, d = inputs.size()
        # mask = torch.reshape(mask, (b, -1, 1))
        inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x = x + inputs
        x = x * mask
        return x

    def forward_attention(self, value, scores, mask):
        # mask is additive (-10000 on padded positions)
        scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
+
class MultiHeadedAttentionSANMDecoder(nn.Module):
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index 069c527..561179b 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -456,3 +456,135 @@
return (xs_pad, intermediate_outs), None, None
return xs_pad, ilens, None
class EncoderLayerSANMExport(nn.Module):
    """Export-friendly SANM encoder layer: pre-norm self-attention + feed-forward."""

    def __init__(
        self,
        model,
    ):
        """Take over the trained submodules of the original encoder layer."""
        super().__init__()
        self.self_attn = model.self_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.in_size = model.in_size
        self.size = model.size

    def forward(self, x, mask):
        attn_out = self.self_attn(self.norm1(x), mask)
        # the attention residual only applies when input/output widths match
        x = x + attn_out if self.in_size == self.size else attn_out
        x = x + self.feed_forward(self.norm2(x))
        return x, mask
+
+
@tables.register("encoder_classes", "SANMEncoderExport")
class SANMEncoderExport(nn.Module):
    """ONNX-export wrapper around a trained SANM encoder.

    Replaces each layer's attention with its export counterpart (in place on
    ``model``) and exposes a trace-friendly forward().
    """
    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        # the streaming positional encoder keeps state and is not traced;
        # positions are then applied outside the graph
        if isinstance(self.embed, StreamSinusoidalPositionEncoder):
            self.embed = None
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size

        from funasr.utils.torch_function import MakePadMask
        from funasr.utils.torch_function import sequence_mask

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport
        # encoders0 only exists on some SANM variants
        if hasattr(model, 'encoders0'):
            for i, d in enumerate(self.model.encoders0):
                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                    d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
                self.model.encoders0[i] = EncoderLayerSANMExport(d)

        for i, d in enumerate(self.model.encoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
            self.model.encoders[i] = EncoderLayerSANMExport(d)

        self.model_name = model_name
        self.num_heads = model.encoders[0].self_attn.h
        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features

    def prepare_mask(self, mask):
        """Split a padding mask into (B,T,1) multiplicative and 4-D additive forms."""
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        else:
            # 3-D mask; previously any other rank left mask_4d_bhlt unbound
            # and raised UnboundLocalError
            mask_4d_bhlt = 1 - mask[:, None, :]
        # large negative bias on padded positions (added to attention scores)
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                online: bool = False
                ):
        """Run the export-ready encoder; returns (encoded features, input lengths)."""
        if not online:
            speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None:
            xs_pad = speech
        else:
            xs_pad = self.embed(speech)

        # __init__ only wraps encoders0 when it exists, so guard here too
        # instead of crashing on models without it
        if hasattr(self.model, 'encoders0'):
            encoder_outs = self.model.encoders0(xs_pad, mask)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        encoder_outs = self.model.encoders(xs_pad, mask)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.model.after_norm(xs_pad)

        return xs_pad, speech_lengths

    def get_output_size(self):
        return self.model.encoders[0].size

    def get_dummy_inputs(self):
        feats = torch.randn(1, 100, self.feats_dim)
        # was `return (feats)`, which is just the tensor — exporters that
        # unpack dummy inputs expect a tuple
        return (feats,)

    def get_input_names(self):
        return ['feats']

    def get_output_names(self):
        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']

    def get_dynamic_axes(self):
        """torch.onnx.export dynamic_axes spec: time axes are variable."""
        return {
            'feats': {
                1: 'feats_length'
            },
            'encoder_out': {
                1: 'enc_out_length'
            },
            'predictor_weight': {
                1: 'pre_out_length'
            }

        }
+
diff --git a/funasr/models/sanm/multihead_att.py b/funasr/models/sanm/multihead_att.py
new file mode 100644
index 0000000..5ef36ed
--- /dev/null
+++ b/funasr/models/sanm/multihead_att.py
@@ -0,0 +1,243 @@
+import os
+import math
+
+import torch
+import torch.nn as nn
+
+
class MultiHeadedAttentionSANMExport(nn.Module):
    """Export-friendly SANM self-attention: multi-head attention plus FSMN memory.

    ``mask`` in forward() is the (mask_3d_btd, mask_4d_bhlt) pair built by the
    encoder wrapper's prepare_mask().
    """

    def __init__(self, model):
        super().__init__()
        # adopt the trained parameters of the original attention module
        self.d_k = model.d_k
        self.h = model.h
        self.linear_out = model.linear_out
        self.linear_q_k_v = model.linear_q_k_v
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn

        self.attn = None  # last attention weights, for inspection
        self.all_head_size = self.h * self.d_k

    def forward(self, x, mask):
        pad_mask_btd, att_mask_bhlt = mask
        q_h, k_h, v_h, v = self.forward_qkv(x)
        memory = self.forward_fsmn(v, pad_mask_btd)
        scaled_q = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(scaled_q, k_h.transpose(-2, -1))
        attended = self.forward_attention(v_h, scores, att_mask_bhlt)
        return attended + memory

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H*D) -> (B, H, T, D)
        shaped = x.view(x.size()[:-1] + (self.h, self.d_k))
        return shaped.permute(0, 2, 1, 3)

    def forward_qkv(self, x):
        # one fused projection, split into the three attention inputs
        projected = self.linear_q_k_v(x)
        q, k, v = torch.split(projected, int(self.h * self.d_k), dim=-1)
        return (self.transpose_for_scores(q),
                self.transpose_for_scores(k),
                self.transpose_for_scores(v),
                v)

    def forward_fsmn(self, inputs, mask):
        """Depthwise FSMN memory branch with residual; masked on entry and exit."""
        masked = inputs * mask
        convolved = self.fsmn_block(self.pad_fn(masked.transpose(1, 2)))
        return (convolved.transpose(1, 2) + masked) * mask

    def forward_attention(self, value, scores, mask):
        # mask is additive (-10000 on padded positions)
        self.attn = torch.softmax(scores + mask, dim=-1)
        context = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(context.size()[:-2] + (self.all_head_size,))
        return self.linear_out(context)  # (batch, time1, d_model)
+
+
def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
    """Mask, switch to (B, D, T) layout, and left-extend for the FSMN conv.

    First chunk (cache is None): zero-pad on the left via ``pad_fn``.
    Streaming chunk: prepend the cached history instead of padding.
    Returns the extended tensor and the new cache (last kernel_size-1 frames).
    """
    masked = (x * mask).transpose(1, 2)
    if cache is None:
        extended = pad_fn(masked)
    else:
        extended = torch.cat((cache, masked), dim=2)
    new_cache = extended[:, :, -(kernel_size - 1):]
    return extended, new_cache
+
+
# Register preprocess_for_attn as a torch.fx leaf so its data-dependent
# control flow (cache None vs. tensor) is not traced into the graph.
# torch.fx only exists from torch 1.8 on; "+cu"/dev suffixes only appear
# after the first two version components, so the int() parse is safe.
torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
if torch_version >= (1, 8):
    import torch.fx
    torch.fx.wrap('preprocess_for_attn')
+
+
class MultiHeadedAttentionSANMDecoderExport(nn.Module):
    """Export wrapper for the decoder-side FSMN block (no attention weights).

    ``cache`` carries the last kernel_size-1 frames between streaming calls;
    see preprocess_for_attn above (fx-wrapped so the cache branch stays opaque).
    """
    def __init__(self, model):
        super().__init__()
        self.fsmn_block = model.fsmn_block
        self.pad_fn = model.pad_fn
        self.kernel_size = model.kernel_size
        # kept for interface parity with the attention modules; never set here
        self.attn = None

    def forward(self, inputs, mask, cache=None):
        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)

        # residual over the original (B, T, D) inputs, re-masked
        x = x + inputs
        x = x * mask
        return x, cache
+
+
class MultiHeadedAttentionCrossAttExport(nn.Module):
    """Export-friendly cross-attention: queries from the decoder stream,
    keys/values from the encoder memory (fused k/v projection)."""
    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_q = model.linear_q
        self.linear_k_v = model.linear_k_v
        self.linear_out = model.linear_out
        # last attention weights, kept for inspection
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, x, memory, memory_mask):
        q, k, v = self.forward_qkv(x, memory)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, memory_mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H*D) -> (B, H, T, D)
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward_qkv(self, x, memory):
        q = self.linear_q(x)

        # fused key/value projection of the encoder memory, split in two
        k_v = self.linear_k_v(memory)
        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        # mask is additive (-10000 on padded positions)
        scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
+
+
class OnnxMultiHeadedAttention(nn.Module):
    """ONNX-friendly standard multi-head attention (separate q/k/v projections).

    Wraps a trained attention module and re-implements forward without
    dropout or training-only branches.
    """
    def __init__(self, model):
        super().__init__()
        self.d_k = model.d_k
        self.h = model.h
        self.linear_q = model.linear_q
        self.linear_k = model.linear_k
        self.linear_v = model.linear_v
        self.linear_out = model.linear_out
        # last attention weights, kept for inspection
        self.attn = None
        self.all_head_size = self.h * self.d_k

    def forward(self, query, key, value, mask):
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, T, H*D) -> (B, H, T, D)
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward_qkv(self, query, key, value):
        q = self.linear_q(query)
        k = self.linear_k(key)
        v = self.linear_v(value)
        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        # mask is additive (large negative on padded positions)
        scores = scores + mask

        self.attn = torch.softmax(scores, dim=-1)
        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return self.linear_out(context_layer)  # (batch, time1, d_model)
+
+
class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
    """ONNX-friendly multi-head attention with Transformer-XL style relative
    positional encoding (https://arxiv.org/abs/1901.02860, Section 3.3)."""

    def __init__(self, model):
        super().__init__(model)
        # additional trained parameters of the relative-position variant
        self.linear_pos = model.linear_pos
        self.pos_bias_u = model.pos_bias_u
        self.pos_bias_v = model.pos_bias_v

    def forward(self, query, key, value, pos_emb, mask):
        """Attend with content (a+c) and position (b+d) score terms combined."""
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        p = self.transpose_for_scores(self.linear_pos(pos_emb))  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time1)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # (batch, head, time1, time2)

        # forward_attention is inherited from OnnxMultiHeadedAttention; the
        # previous byte-identical override here was removed as redundant.
        return self.forward_attention(v, scores, mask)

    def rel_shift(self, x):
        """Shift absolute-position scores into relative-position alignment."""
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x
+
diff --git a/funasr/models/whisper/template.yaml b/funasr/models/whisper/template.yaml
index 40b902c..568b948 100644
--- a/funasr/models/whisper/template.yaml
+++ b/funasr/models/whisper/template.yaml
@@ -11,29 +11,28 @@
lsm_weight: 0.1
length_normalized_loss: true
hub: funasr # openai
- init_param_path: null # large-v2 or large-v3 if hub == "openai"
# only use for hub == funasr,
-# if hub == openai, whisper_dims is automaticall download
-whisper_dims:
- 'n_mels': 80
- 'n_vocab': 51865
- 'n_audio_ctx': 1500
- 'n_audio_state': 1280
- 'n_audio_head': 20
- 'n_audio_layer': 32
- 'n_text_ctx': 448
- 'n_text_state': 1280
- 'n_text_head': 20
- 'n_text_layer': 32
+# if hub == openai, dims is automatically downloaded
+dims:
+ n_mels: 128
+ n_vocab: 51866
+ n_audio_ctx: 1500
+ n_audio_state: 1280
+ n_audio_head: 20
+ n_audio_layer: 32
+ n_text_ctx: 448
+ n_text_state: 1280
+ n_text_head: 20
+ n_text_layer: 32
# frontend related
frontend: WhisperFrontend
frontend_conf:
fs: 16000
- n_mels: 80
+ n_mels: ${dims.n_mels}
do_pad_trim: true
tokenizer: WhisperTokenizer
@@ -41,6 +40,6 @@
language: null
task: transcribe
is_multilingual: true
- num_languages: 99
+ num_languages: 100
-scope_map: ['none', "model."]
+scope_map: [none, "model."]
\ No newline at end of file
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
new file mode 100644
index 0000000..640be05
--- /dev/null
+++ b/funasr/utils/export_utils.py
@@ -0,0 +1,74 @@
+import os
+import torch
+
+def export_onnx(model,
+ data_in=None,
+ type: str = "onnx",
+ quantize: bool = False,
+ fallback_num: int = 5,
+ calib_num: int = 100,
+ opset_version: int = 14,
+ **kwargs):
+ model_scripts = model.export(**kwargs)
+ export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
+ os.makedirs(export_dir, exist_ok=True)
+
+ if not isinstance(model_scripts, (list, tuple)):
+ model_scripts = (model_scripts,)
+ for m in model_scripts:
+ m.eval()
+ _onnx(m,
+ data_in=data_in,
+ type=type,
+ quantize=quantize,
+ fallback_num=fallback_num,
+ calib_num=calib_num,
+ opset_version=opset_version,
+ export_dir=export_dir,
+ **kwargs
+ )
+ print("output dir: {}".format(export_dir))
+
+ return export_dir
+
+def _onnx(model,
+ data_in=None,
+ quantize: bool = False,
+ opset_version: int = 14,
+ export_dir:str = None,
+ **kwargs):
+
+ dummy_input = model.export_dummy_inputs()
+
+ verbose = kwargs.get("verbose", False)
+
+ export_name = model.export_name() if hasattr(model, "export_name") else "model.onnx"
+ model_path = os.path.join(export_dir, export_name)
+ torch.onnx.export(
+ model,
+ dummy_input,
+ model_path,
+ verbose=verbose,
+ opset_version=opset_version,
+ input_names=model.export_input_names(),
+ output_names=model.export_output_names(),
+ dynamic_axes=model.export_dynamic_axes()
+ )
+
+ if quantize:
+ from onnxruntime.quantization import QuantType, quantize_dynamic
+ import onnx
+ quant_model_path = model_path.replace(".onnx", "_quant.onnx")
+ if not os.path.exists(quant_model_path):
+ onnx_model = onnx.load(model_path)
+ nodes = [n.name for n in onnx_model.graph.node]
+ nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m or 'bias_decoder' in m]
+ quantize_dynamic(
+ model_input=model_path,
+ model_output=quant_model_path,
+ op_types_to_quantize=['MatMul'],
+ per_channel=True,
+ reduce_range=False,
+ weight_type=QuantType.QUInt8,
+ nodes_to_exclude=nodes_to_exclude,
+ )
\ No newline at end of file
diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py
index e9050f5..a08f263 100644
--- a/funasr/utils/misc.py
+++ b/funasr/utils/misc.py
@@ -46,3 +46,10 @@
value = value.split(' ')
ret_dict.append((key, value))
return ret_dict
+
+def deep_update(original, update):
+ for key, value in update.items():
+ if isinstance(value, dict) and key in original:
+ deep_update(original[key], value)
+ else:
+ original[key] = value
\ No newline at end of file
diff --git a/funasr/utils/torch_function.py b/funasr/utils/torch_function.py
new file mode 100644
index 0000000..a078a7e
--- /dev/null
+++ b/funasr/utils/torch_function.py
@@ -0,0 +1,80 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+import numpy as np
+
+
+class MakePadMask(nn.Module):
+ def __init__(self, max_seq_len=512, flip=True):
+ super().__init__()
+ if flip:
+ self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
+ else:
+ self.mask_pad = torch.Tensor(np.tri(max_seq_len)).type(torch.bool)
+
+ def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
+ """Make mask tensor containing indices of padded part.
+        This implementation creates the same mask tensor as the original make_pad_mask,
+        which can be converted into ONNX format.
+ Dimension length of xs should be 2 or 3.
+ """
+ if length_dim == 0:
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+ if xs is not None and len(xs.shape) == 3:
+ if length_dim == 1:
+ lengths = lengths.unsqueeze(1).expand(
+ *xs.transpose(1, 2).shape[:2])
+ else:
+ lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])
+
+ if maxlen is not None:
+ m = maxlen
+ elif xs is not None:
+ m = xs.shape[-1]
+ else:
+ m = torch.max(lengths)
+
+ mask = self.mask_pad[lengths - 1][..., :m].type(torch.float32)
+
+ if length_dim == 1:
+ return mask.transpose(1, 2)
+ else:
+ return mask
+
+class sequence_mask(nn.Module):
+ def __init__(self, max_seq_len=512, flip=True):
+ super().__init__()
+
+ def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
+ if max_seq_len is None:
+ max_seq_len = lengths.max()
+ row_vector = torch.arange(0, max_seq_len, 1).to(lengths.device)
+ matrix = torch.unsqueeze(lengths, dim=-1)
+ mask = row_vector < matrix
+
+ return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
+
+def normalize(input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if out is None:
+ denom = input.norm(p, dim, keepdim=True).expand_as(input)
+ return input / denom
+ else:
+ denom = input.norm(p, dim, keepdim=True).expand_as(input)
+ return torch.div(input, denom, out=out)
+
+def subsequent_mask(size: torch.Tensor):
+ return torch.ones(size, size).tril()
+
+
+def MakePadMask_test():
+ feats_length = torch.tensor([10]).type(torch.long)
+ mask_fn = MakePadMask()
+ mask = mask_fn(feats_length)
+ print(mask)
+
+
+if __name__ == '__main__':
+ MakePadMask_test()
\ No newline at end of file
--
Gitblit v1.9.1