From 2ae59b6ce06305724e2eaf30b9f9e93447a7832e Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: Mon, 22 Jul 2024 16:58:27 +0800
Subject: [PATCH] ONNX and torchscript export for sensevoice
---
funasr/models/sense_voice/export_meta.py | 58 ++----
funasr/utils/export_utils.py | 44 +++--
runtime/python/onnxruntime/funasr_onnx/__init__.py | 1
runtime/python/onnxruntime/demo_sencevoicesmall.py | 38 ++++
runtime/python/libtorch/demo_sensevoicesmall.py | 38 ++++
runtime/python/libtorch/funasr_torch/sensevoice_bin.py | 130 ++++++++++++++++
runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py | 145 ++++++++++++++++++
runtime/python/libtorch/funasr_torch/__init__.py | 1
examples/industrial_data_pretraining/sense_voice/export.py | 15 +
9 files changed, 412 insertions(+), 58 deletions(-)
diff --git a/examples/industrial_data_pretraining/sense_voice/export.py b/examples/industrial_data_pretraining/sense_voice/export.py
new file mode 100644
index 0000000..7376c8a
--- /dev/null
+++ b/examples/industrial_data_pretraining/sense_voice/export.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+
+model_dir = "iic/SenseVoiceSmall"
+model = AutoModel(
+ model=model_dir,
+ device="cuda:0",
+)
+
+res = model.export(type="onnx", quantize=False)
\ No newline at end of file
diff --git a/funasr/models/sense_voice/export_meta.py b/funasr/models/sense_voice/export_meta.py
index fe09ee1..449388e 100644
--- a/funasr/models/sense_voice/export_meta.py
+++ b/funasr/models/sense_voice/export_meta.py
@@ -5,30 +5,19 @@
import types
import torch
-import torch.nn as nn
-from funasr.register import tables
+from funasr.utils.torch_function import sequence_mask
def export_rebuild_model(model, **kwargs):
model.device = kwargs.get("device")
- is_onnx = kwargs.get("type", "onnx") == "onnx"
- # encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
- # model.encoder = encoder_class(model.encoder, onnx=is_onnx)
-
- from funasr.utils.torch_function import sequence_mask
-
model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
-
model.forward = types.MethodType(export_forward, model)
model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
model.export_input_names = types.MethodType(export_input_names, model)
model.export_output_names = types.MethodType(export_output_names, model)
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
model.export_name = types.MethodType(export_name, model)
-
- model.export_name = "model"
return model
-
def export_forward(
self,
@@ -38,32 +27,28 @@
textnorm: torch.Tensor,
**kwargs,
):
- speech = speech.to(device=kwargs["device"])
- speech_lengths = speech_lengths.to(device=kwargs["device"])
-
- language_query = self.embed(language).to(speech.device)
-
- textnorm_query = self.embed(textnorm).to(speech.device)
+ # speech = speech.to(device="cuda")
+ # speech_lengths = speech_lengths.to(device="cuda")
+ language_query = self.embed(language.to(speech.device)).unsqueeze(1)
+ textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1)
+
speech = torch.cat((textnorm_query, speech), dim=1)
- speech_lengths += 1
-
+
event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
speech.size(0), 1, 1
)
input_query = torch.cat((language_query, event_emo_query), dim=1)
speech = torch.cat((input_query, speech), dim=1)
- speech_lengths += 3
-
- # Encoder
- encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
+
+ speech_lengths_new = speech_lengths + 4
+ encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths_new)
+
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
- # c. Passed the encoder result and the beam search
- ctc_logits = self.ctc.log_softmax(encoder_out)
-
+ ctc_logits = self.ctc.ctc_lo(encoder_out)
+
return ctc_logits, encoder_out_lens
-
def export_dummy_inputs(self):
speech = torch.randn(2, 30, 560)
@@ -72,26 +57,21 @@
textnorm = torch.tensor([15, 15], dtype=torch.int32)
return (speech, speech_lengths, language, textnorm)
-
def export_input_names(self):
return ["speech", "speech_lengths", "language", "textnorm"]
-
def export_output_names(self):
return ["ctc_logits", "encoder_out_lens"]
-
def export_dynamic_axes(self):
return {
"speech": {0: "batch_size", 1: "feats_length"},
- "speech_lengths": {
- 0: "batch_size",
- },
- "logits": {0: "batch_size", 1: "logits_length"},
+ "speech_lengths": {0: "batch_size"},
+ "language": {0: "batch_size"},
+ "textnorm": {0: "batch_size"},
+ "ctc_logits": {0: "batch_size", 1: "logits_length"},
+ "encoder_out_lens": {0: "batch_size"},
}
-
-def export_name(
- self,
-):
+def export_name(self):
return "model.onnx"
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index a6d0798..af9f37b 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -54,7 +54,10 @@
verbose = kwargs.get("verbose", False)
- export_name = model.export_name + ".onnx"
+ if isinstance(model.export_name, str):
+ export_name = model.export_name + ".onnx"
+ else:
+ export_name = model.export_name()
model_path = os.path.join(export_dir, export_name)
torch.onnx.export(
model,
@@ -72,35 +75,38 @@
import onnx
quant_model_path = model_path.replace(".onnx", "_quant.onnx")
- if not os.path.exists(quant_model_path):
- onnx_model = onnx.load(model_path)
- nodes = [n.name for n in onnx_model.graph.node]
- nodes_to_exclude = [
- m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
- ]
- quantize_dynamic(
- model_input=model_path,
- model_output=quant_model_path,
- op_types_to_quantize=["MatMul"],
- per_channel=True,
- reduce_range=False,
- weight_type=QuantType.QUInt8,
- nodes_to_exclude=nodes_to_exclude,
- )
+ onnx_model = onnx.load(model_path)
+ nodes = [n.name for n in onnx_model.graph.node]
+ nodes_to_exclude = [
+ m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
+ ]
+ print("Quantizing model from {} to {}".format(model_path, quant_model_path))
+ quantize_dynamic(
+ model_input=model_path,
+ model_output=quant_model_path,
+ op_types_to_quantize=["MatMul"],
+ per_channel=True,
+ reduce_range=False,
+ weight_type=QuantType.QUInt8,
+ nodes_to_exclude=nodes_to_exclude,
+ )
def _torchscripts(model, path, device="cuda"):
dummy_input = model.export_dummy_inputs()
-
+
if device == "cuda":
model = model.cuda()
if isinstance(dummy_input, torch.Tensor):
dummy_input = dummy_input.cuda()
else:
dummy_input = tuple([i.cuda() for i in dummy_input])
-
+
model_script = torch.jit.trace(model, dummy_input)
- model_script.save(os.path.join(path, f"{model.export_name}.torchscript"))
+ if isinstance(model.export_name, str):
+ model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript")))
+ else:
+ model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript")))
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
diff --git a/runtime/python/libtorch/demo_sensevoicesmall.py b/runtime/python/libtorch/demo_sensevoicesmall.py
new file mode 100644
index 0000000..5c70f34
--- /dev/null
+++ b/runtime/python/libtorch/demo_sensevoicesmall.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import os
+import torch
+from pathlib import Path
+from funasr import AutoModel
+from funasr_torch import SenseVoiceSmallTorchScript as SenseVoiceSmall
+from funasr.utils.postprocess_utils import rich_transcription_postprocess
+
+
+model_dir = "iic/SenseVoiceSmall"
+model = AutoModel(
+ model=model_dir,
+ device="cuda:0",
+)
+
+# res = model.export(type="torchscript", quantize=False)
+
+# export model init
+model_path = "{}/.cache/modelscope/hub/{}".format(Path.home(), model_dir)
+model_bin = SenseVoiceSmall(model_path)
+
+# build tokenizer
+try:
+ from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
+ tokenizer = SentencepiecesTokenizer(bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"))
+except:
+ tokenizer = None
+
+# inference
+wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav"
+language_list = [0]
+textnorm_list = [15]
+res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer)
+print([rich_transcription_postprocess(i) for i in res])
diff --git a/runtime/python/libtorch/funasr_torch/__init__.py b/runtime/python/libtorch/funasr_torch/__init__.py
index 647f9fa..4669ced 100644
--- a/runtime/python/libtorch/funasr_torch/__init__.py
+++ b/runtime/python/libtorch/funasr_torch/__init__.py
@@ -1,2 +1,3 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer
+from .sensevoice_bin import SenseVoiceSmallTorchScript
diff --git a/runtime/python/libtorch/funasr_torch/sensevoice_bin.py b/runtime/python/libtorch/funasr_torch/sensevoice_bin.py
new file mode 100644
index 0000000..d2e3cde
--- /dev/null
+++ b/runtime/python/libtorch/funasr_torch/sensevoice_bin.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+
+import torch
+import os.path
+import librosa
+import numpy as np
+from pathlib import Path
+from typing import List, Union, Tuple
+
+from .utils.utils import (
+ CharTokenizer,
+ get_logger,
+ read_yaml,
+)
+from .utils.frontend import WavFrontend
+
+logging = get_logger()
+
+
+class SenseVoiceSmallTorchScript:
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+    SenseVoiceSmall: multilingual speech understanding model (ASR/LID/SER/AED), TorchScript runtime
+    https://github.com/FunAudioLLM/SenseVoice
+ """
+
+ def __init__(
+ self,
+ model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ plot_timestamp_to: str = "",
+ quantize: bool = False,
+ intra_op_num_threads: int = 4,
+ cache_dir: str = None,
+ **kwargs,
+ ):
+ if quantize:
+ model_file = os.path.join(model_dir, "model_quant.torchscript")
+ else:
+ model_file = os.path.join(model_dir, "model.torchscript")
+
+ config_file = os.path.join(model_dir, "config.yaml")
+ cmvn_file = os.path.join(model_dir, "am.mvn")
+ config = read_yaml(config_file)
+ # token_list = os.path.join(model_dir, "tokens.json")
+ # with open(token_list, "r", encoding="utf-8") as f:
+ # token_list = json.load(f)
+
+ # self.converter = TokenIDConverter(token_list)
+ self.tokenizer = CharTokenizer()
+ config["frontend_conf"]['cmvn_file'] = cmvn_file
+ self.frontend = WavFrontend(**config["frontend_conf"])
+ self.ort_infer = torch.jit.load(model_file)
+ self.batch_size = batch_size
+ self.blank_id = 0
+
+ def __call__(self,
+ wav_content: Union[str, np.ndarray, List[str]],
+ language: List,
+ textnorm: List,
+ tokenizer=None,
+ **kwargs) -> List:
+ waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
+ waveform_nums = len(waveform_list)
+ asr_res = []
+ for beg_idx in range(0, waveform_nums, self.batch_size):
+ end_idx = min(waveform_nums, beg_idx + self.batch_size)
+ feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
+ ctc_logits, encoder_out_lens = self.ort_infer(torch.Tensor(feats),
+ torch.Tensor(feats_len),
+ torch.tensor(language),
+ torch.tensor(textnorm)
+ )
+ # support batch_size=1 only currently
+ x = ctc_logits[0, : encoder_out_lens[0].item(), :]
+ yseq = x.argmax(dim=-1)
+ yseq = torch.unique_consecutive(yseq, dim=-1)
+
+ mask = yseq != self.blank_id
+ token_int = yseq[mask].tolist()
+
+ if tokenizer is not None:
+ asr_res.append(tokenizer.tokens2text(token_int))
+ else:
+ asr_res.append(token_int)
+ return asr_res
+
+ def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+ def load_wav(path: str) -> np.ndarray:
+ waveform, _ = librosa.load(path, sr=fs)
+ return waveform
+
+ if isinstance(wav_content, np.ndarray):
+ return [wav_content]
+
+ if isinstance(wav_content, str):
+ return [load_wav(wav_content)]
+
+ if isinstance(wav_content, list):
+ return [load_wav(path) for path in wav_content]
+
+ raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
+
+ def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
+ feats, feats_len = [], []
+ for waveform in waveform_list:
+ speech, _ = self.frontend.fbank(waveform)
+ feat, feat_len = self.frontend.lfr_cmvn(speech)
+ feats.append(feat)
+ feats_len.append(feat_len)
+
+ feats = self.pad_feats(feats, np.max(feats_len))
+ feats_len = np.array(feats_len).astype(np.int32)
+ return feats, feats_len
+
+ @staticmethod
+ def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+ def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+ pad_width = ((0, max_feat_len - cur_len), (0, 0))
+ return np.pad(feat, pad_width, "constant", constant_values=0)
+
+ feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
+ feats = np.array(feat_res).astype(np.float32)
+ return feats
+
diff --git a/runtime/python/onnxruntime/demo_sencevoicesmall.py b/runtime/python/onnxruntime/demo_sencevoicesmall.py
new file mode 100644
index 0000000..27f0179
--- /dev/null
+++ b/runtime/python/onnxruntime/demo_sencevoicesmall.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import os
+import torch
+from pathlib import Path
+from funasr import AutoModel
+from funasr_onnx import SenseVoiceSmallONNX as SenseVoiceSmall
+from funasr.utils.postprocess_utils import rich_transcription_postprocess
+
+
+model_dir = "iic/SenseVoiceSmall"
+model = AutoModel(
+ model=model_dir,
+ device="cuda:0",
+)
+
+res = model.export(type="onnx", quantize=False)
+
+# export model init
+model_path = "{}/.cache/modelscope/hub/{}".format(Path.home(), model_dir)
+model_bin = SenseVoiceSmall(model_path)
+
+# build tokenizer
+try:
+ from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
+ tokenizer = SentencepiecesTokenizer(bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"))
+except:
+ tokenizer = None
+
+# inference
+wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav"
+language_list = [0]
+textnorm_list = [15]
+res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer)
+print([rich_transcription_postprocess(i) for i in res])
diff --git a/runtime/python/onnxruntime/funasr_onnx/__init__.py b/runtime/python/onnxruntime/funasr_onnx/__init__.py
index d0d6651..4256629 100644
--- a/runtime/python/onnxruntime/funasr_onnx/__init__.py
+++ b/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -4,3 +4,4 @@
from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
from .punc_bin import CT_Transformer_VadRealtime
+from .sensevoice_bin import SenseVoiceSmallONNX
diff --git a/runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py b/runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py
new file mode 100644
index 0000000..fcfcede
--- /dev/null
+++ b/runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py
@@ -0,0 +1,145 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+
+import torch
+import os.path
+import librosa
+import numpy as np
+from pathlib import Path
+from typing import List, Union, Tuple
+
+from .utils.utils import (
+ CharTokenizer,
+ Hypothesis,
+ ONNXRuntimeError,
+ OrtInferSession,
+ TokenIDConverter,
+ get_logger,
+ read_yaml,
+)
+from .utils.frontend import WavFrontend
+
+logging = get_logger()
+
+
+class SenseVoiceSmallONNX:
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+    SenseVoiceSmall: multilingual speech understanding model (ASR/LID/SER/AED), ONNXRuntime runtime
+    https://github.com/FunAudioLLM/SenseVoice
+ """
+
+ def __init__(
+ self,
+ model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ plot_timestamp_to: str = "",
+ quantize: bool = False,
+ intra_op_num_threads: int = 4,
+ cache_dir: str = None,
+ **kwargs,
+ ):
+ if quantize:
+ model_file = os.path.join(model_dir, "model_quant.onnx")
+ else:
+ model_file = os.path.join(model_dir, "model.onnx")
+
+ config_file = os.path.join(model_dir, "config.yaml")
+ cmvn_file = os.path.join(model_dir, "am.mvn")
+ config = read_yaml(config_file)
+ # token_list = os.path.join(model_dir, "tokens.json")
+ # with open(token_list, "r", encoding="utf-8") as f:
+ # token_list = json.load(f)
+
+ # self.converter = TokenIDConverter(token_list)
+ self.tokenizer = CharTokenizer()
+ config["frontend_conf"]['cmvn_file'] = cmvn_file
+ self.frontend = WavFrontend(**config["frontend_conf"])
+ self.ort_infer = OrtInferSession(
+ model_file, device_id, intra_op_num_threads=intra_op_num_threads
+ )
+ self.batch_size = batch_size
+ self.blank_id = 0
+
+ def __call__(self,
+ wav_content: Union[str, np.ndarray, List[str]],
+ language: List,
+ textnorm: List,
+ tokenizer=None,
+ **kwargs) -> List:
+ waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
+ waveform_nums = len(waveform_list)
+ asr_res = []
+ for beg_idx in range(0, waveform_nums, self.batch_size):
+ end_idx = min(waveform_nums, beg_idx + self.batch_size)
+ feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
+ ctc_logits, encoder_out_lens = self.infer(feats,
+ feats_len,
+ np.array(language, dtype=np.int32),
+ np.array(textnorm, dtype=np.int32)
+ )
+ # back to torch.Tensor
+ ctc_logits = torch.from_numpy(ctc_logits).float()
+ # support batch_size=1 only currently
+ x = ctc_logits[0, : encoder_out_lens[0].item(), :]
+ yseq = x.argmax(dim=-1)
+ yseq = torch.unique_consecutive(yseq, dim=-1)
+
+ mask = yseq != self.blank_id
+ token_int = yseq[mask].tolist()
+
+ if tokenizer is not None:
+ asr_res.append(tokenizer.tokens2text(token_int))
+ else:
+ asr_res.append(token_int)
+ return asr_res
+
+ def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+ def load_wav(path: str) -> np.ndarray:
+ waveform, _ = librosa.load(path, sr=fs)
+ return waveform
+
+ if isinstance(wav_content, np.ndarray):
+ return [wav_content]
+
+ if isinstance(wav_content, str):
+ return [load_wav(wav_content)]
+
+ if isinstance(wav_content, list):
+ return [load_wav(path) for path in wav_content]
+
+ raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
+
+ def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
+ feats, feats_len = [], []
+ for waveform in waveform_list:
+ speech, _ = self.frontend.fbank(waveform)
+ feat, feat_len = self.frontend.lfr_cmvn(speech)
+ feats.append(feat)
+ feats_len.append(feat_len)
+
+ feats = self.pad_feats(feats, np.max(feats_len))
+ feats_len = np.array(feats_len).astype(np.int32)
+ return feats, feats_len
+
+ @staticmethod
+ def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+ def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+ pad_width = ((0, max_feat_len - cur_len), (0, 0))
+ return np.pad(feat, pad_width, "constant", constant_values=0)
+
+ feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
+ feats = np.array(feat_res).astype(np.float32)
+ return feats
+
+ def infer(self,
+ feats: np.ndarray,
+ feats_len: np.ndarray,
+ language: np.ndarray,
+ textnorm: np.ndarray,) -> Tuple[np.ndarray, np.ndarray]:
+ outputs = self.ort_infer([feats, feats_len, language, textnorm])
+ return outputs
--
Gitblit v1.9.1