维石
2024-07-22 2ae59b6ce06305724e2eaf30b9f9e93447a7832e
ONNX and torchscript export for sensevoice
4个文件已修改
5个文件已添加
428 ■■■■■ 已修改文件
examples/industrial_data_pretraining/sense_voice/export.py 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/sense_voice/export_meta.py 50 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/export_utils.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/libtorch/demo_sensevoicesmall.py 38 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/libtorch/funasr_torch/__init__.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/libtorch/funasr_torch/sensevoice_bin.py 130 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/onnxruntime/demo_sencevoicesmall.py 38 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/onnxruntime/funasr_onnx/__init__.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py 145 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/sense_voice/export.py
New file
@@ -0,0 +1,15 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
"""Export the SenseVoice-Small model to ONNX via the FunASR AutoModel API."""

from funasr import AutoModel

# ModelScope model id; resolved/downloaded to the local cache on first use.
model_dir = "iic/SenseVoiceSmall"

# Load the model onto the first GPU before export.
model = AutoModel(model=model_dir, device="cuda:0")

# Export an fp32 ONNX graph (no dynamic quantization).
res = model.export(type="onnx", quantize=False)
funasr/models/sense_voice/export_meta.py
@@ -5,30 +5,19 @@
import types
import torch
import torch.nn as nn
from funasr.register import tables
from funasr.utils.torch_function import sequence_mask
def export_rebuild_model(model, **kwargs):
    """Patch a SenseVoice model instance in place for export.

    Rebinds ``forward`` and the ``export_*`` hooks expected by the FunASR
    export utilities onto ``model`` and returns the same instance.
    """
    model.device = kwargs.get("device")
    # Export backend selector: "onnx" (default) or "torchscript".
    is_onnx = kwargs.get("type", "onnx") == "onnx"
    # encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    # model.encoder = encoder_class(model.encoder, onnx=is_onnx)
    from funasr.utils.torch_function import sequence_mask
    # Pad-mask builder sized to the maximum sequence length the export supports.
    model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    model.export_name = types.MethodType(export_name, model)
    # NOTE(review): the assignment below immediately overwrites the MethodType
    # binding above with a plain string — these look like old/new lines of a
    # diff pasted together; only one of the two should survive. Verify which
    # form funasr.utils.export_utils expects.
    model.export_name = "model"
    return model
def export_forward(
    self,
@@ -38,32 +27,28 @@
    textnorm: torch.Tensor,
    **kwargs,
):
    speech = speech.to(device=kwargs["device"])
    speech_lengths = speech_lengths.to(device=kwargs["device"])
    # speech = speech.to(device="cuda")
    # speech_lengths = speech_lengths.to(device="cuda")
    language_query = self.embed(language.to(speech.device)).unsqueeze(1)
    textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1)
    language_query = self.embed(language).to(speech.device)
    textnorm_query = self.embed(textnorm).to(speech.device)
    speech = torch.cat((textnorm_query, speech), dim=1)
    speech_lengths += 1
    event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
        speech.size(0), 1, 1
    )
    input_query = torch.cat((language_query, event_emo_query), dim=1)
    speech = torch.cat((input_query, speech), dim=1)
    speech_lengths += 3
    # Encoder
    encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
    speech_lengths_new = speech_lengths + 4
    encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths_new)
    if isinstance(encoder_out, tuple):
        encoder_out = encoder_out[0]
    # c. Passed the encoder result and the beam search
    ctc_logits = self.ctc.log_softmax(encoder_out)
    ctc_logits = self.ctc.ctc_lo(encoder_out)
    return ctc_logits, encoder_out_lens
def export_dummy_inputs(self):
    speech = torch.randn(2, 30, 560)
@@ -72,26 +57,21 @@
    textnorm = torch.tensor([15, 15], dtype=torch.int32)
    return (speech, speech_lengths, language, textnorm)
def export_input_names(self):
    """Graph input names, in the positional order used by export_forward."""
    names = ("speech", "speech_lengths", "language", "textnorm")
    return list(names)
def export_output_names(self):
    """Graph output names, matching export_forward's return values."""
    names = ("ctc_logits", "encoder_out_lens")
    return list(names)
def export_dynamic_axes(self):
    """Dynamic-axis spec passed to ``torch.onnx.export``.

    Keys must match export_input_names()/export_output_names(); the batch
    and time dimensions are left symbolic so the exported graph accepts
    variable batch sizes and feature lengths.

    Fixes diff residue in the original: a duplicate "speech_lengths" key
    (the second silently won) and a stale "logits" entry that matched no
    declared output name ("ctc_logits" is the real output).
    """
    return {
        "speech": {0: "batch_size", 1: "feats_length"},
        "speech_lengths": {0: "batch_size"},
        "language": {0: "batch_size"},
        "textnorm": {0: "batch_size"},
        "ctc_logits": {0: "batch_size", 1: "logits_length"},
        "encoder_out_lens": {0: "batch_size"},
    }
def export_name(self):
    """File name under which the exported ONNX model is saved.

    Fixes diff residue in the original: a dangling multi-line signature
    with no body preceded the real one-line definition, which is a
    SyntaxError when the pasted lines are taken literally.
    """
    return "model.onnx"
funasr/utils/export_utils.py
@@ -54,7 +54,10 @@
    verbose = kwargs.get("verbose", False)
    if isinstance(model.export_name, str):
    export_name = model.export_name + ".onnx"
    else:
        export_name = model.export_name()
    model_path = os.path.join(export_dir, export_name)
    torch.onnx.export(
        model,
@@ -72,12 +75,12 @@
        import onnx
        quant_model_path = model_path.replace(".onnx", "_quant.onnx")
        if not os.path.exists(quant_model_path):
            onnx_model = onnx.load(model_path)
            nodes = [n.name for n in onnx_model.graph.node]
            nodes_to_exclude = [
                m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
            ]
        print("Quantizing model from {} to {}".format(model_path, quant_model_path))
            quantize_dynamic(
                model_input=model_path,
                model_output=quant_model_path,
@@ -100,7 +103,10 @@
            dummy_input = tuple([i.cuda() for i in dummy_input])
    model_script = torch.jit.trace(model, dummy_input)
    model_script.save(os.path.join(path, f"{model.export_name}.torchscript"))
    if isinstance(model.export_name, str):
        model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript")))
    else:
        model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript")))
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
runtime/python/libtorch/demo_sensevoicesmall.py
New file
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
"""Demo: run the TorchScript export of SenseVoice-Small end to end."""

import os
import torch
from pathlib import Path

from funasr import AutoModel
from funasr_torch import SenseVoiceSmallTorchScript as SenseVoiceSmall
from funasr.utils.postprocess_utils import rich_transcription_postprocess

model_dir = "iic/SenseVoiceSmall"

model = AutoModel(
    model=model_dir,
    device="cuda:0",
)

# Uncomment to (re-)export the TorchScript artifact before running the demo.
# res = model.export(type="torchscript", quantize=False)

# Exported-model runtime init (ModelScope default cache location).
model_path = "{}/.cache/modelscope/hub/{}".format(Path.home(), model_dir)
model_bin = SenseVoiceSmall(model_path)

# Build the tokenizer; fall back to raw token ids if the BPE model is
# unavailable. Narrowed from a bare `except:` which would also swallow
# KeyboardInterrupt/SystemExit.
try:
    from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer

    tokenizer = SentencepiecesTokenizer(
        bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model")
    )
except Exception:
    tokenizer = None

# Inference. language/textnorm are per-utterance query ids for the model.
wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav"  # TODO: point at a local wav
language_list = [0]
textnorm_list = [15]
res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer)
print([rich_transcription_postprocess(i) for i in res])
runtime/python/libtorch/funasr_torch/__init__.py
@@ -1,2 +1,3 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer
from .sensevoice_bin import SenseVoiceSmallTorchScript
runtime/python/libtorch/funasr_torch/sensevoice_bin.py
New file
@@ -0,0 +1,130 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import os.path
import librosa
import numpy as np
from pathlib import Path
from typing import List, Union, Tuple
from .utils.utils import (
    CharTokenizer,
    get_logger,
    read_yaml,
)
from .utils.frontend import WavFrontend
logging = get_logger()
class SenseVoiceSmallTorchScript:
    """
    Offline TorchScript runtime for the SenseVoice-Small ASR model
    (https://github.com/FunAudioLLM/SenseVoice), exported via FunASR.

    Loads the traced ``model.torchscript`` together with the fbank frontend
    config (``config.yaml``) and CMVN stats (``am.mvn``) from ``model_dir``.
    (The original docstring described Paraformer — a copy/paste error.)
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        device_id: Union[str, int] = "-1",
        plot_timestamp_to: str = "",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):
        # Pick the quantized artifact when requested; both live in model_dir.
        if quantize:
            model_file = os.path.join(model_dir, "model_quant.torchscript")
        else:
            model_file = os.path.join(model_dir, "model.torchscript")
        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
        config = read_yaml(config_file)

        self.tokenizer = CharTokenizer()
        config["frontend_conf"]["cmvn_file"] = cmvn_file
        self.frontend = WavFrontend(**config["frontend_conf"])
        # NOTE: attribute kept as ``ort_infer`` for interface parity with the
        # ONNX runtime class, even though this is a TorchScript module.
        self.ort_infer = torch.jit.load(model_file)
        self.batch_size = batch_size
        # CTC blank token id.
        self.blank_id = 0

    def __call__(
        self,
        wav_content: Union[str, np.ndarray, List[str]],
        language: List,
        textnorm: List,
        tokenizer=None,
        **kwargs,
    ) -> List:
        """Run recognition; returns one decoded result per input waveform.

        ``language``/``textnorm`` are per-batch query token ids for the
        exported graph. With ``tokenizer`` the CTC ids are detokenized to
        text, otherwise the raw token-id lists are returned.
        """
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            # The graph was traced with float32 feats and int32 lengths/query
            # ids (see export_dummy_inputs). The original torch.Tensor(...)
            # calls silently cast everything to float32, so convert with the
            # dtypes preserved instead. TODO(review): confirm int32 matches
            # the dtypes actually used when tracing.
            ctc_logits, encoder_out_lens = self.ort_infer(
                torch.from_numpy(feats),
                torch.from_numpy(feats_len),
                torch.tensor(language, dtype=torch.int32),
                torch.tensor(textnorm, dtype=torch.int32),
            )
            # Greedy CTC decode; supports batch_size=1 only currently.
            x = ctc_logits[0, : encoder_out_lens[0].item(), :]
            yseq = x.argmax(dim=-1)
            yseq = torch.unique_consecutive(yseq, dim=-1)
            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()
            if tokenizer is not None:
                asr_res.append(tokenizer.tokens2text(token_int))
            else:
                asr_res.append(token_int)
        return asr_res

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Normalize ``wav_content`` into a list of 1-D waveform arrays.

        Accepts a raw numpy waveform, one audio-file path, or a list of
        paths; files are decoded (and resampled to ``fs``) with librosa.
        """

        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """Compute LFR+CMVN fbank features and pad them to a common length.

        Returns ``(feats, feats_len)``: a float32 (batch, max_len, dim)
        array and the unpadded lengths as int32.
        """
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)
        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each (len, dim) feature matrix along time to
        ``max_feat_len`` and stack into one float32 batch array."""

        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, "constant", constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        return np.array(feat_res).astype(np.float32)
runtime/python/onnxruntime/demo_sencevoicesmall.py
New file
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
"""Demo: export SenseVoice-Small to ONNX and run it with ONNX Runtime."""

import os
import torch
from pathlib import Path

from funasr import AutoModel
from funasr_onnx import SenseVoiceSmallONNX as SenseVoiceSmall
from funasr.utils.postprocess_utils import rich_transcription_postprocess

model_dir = "iic/SenseVoiceSmall"

model = AutoModel(
    model=model_dir,
    device="cuda:0",
)

# Export the fp32 ONNX artifact next to the cached model files.
res = model.export(type="onnx", quantize=False)

# Exported-model runtime init (ModelScope default cache location).
model_path = "{}/.cache/modelscope/hub/{}".format(Path.home(), model_dir)
model_bin = SenseVoiceSmall(model_path)

# Build the tokenizer; fall back to raw token ids if the BPE model is
# unavailable. Narrowed from a bare `except:` which would also swallow
# KeyboardInterrupt/SystemExit.
try:
    from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer

    tokenizer = SentencepiecesTokenizer(
        bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model")
    )
except Exception:
    tokenizer = None

# Inference. language/textnorm are per-utterance query ids for the model.
wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav"  # TODO: point at a local wav
language_list = [0]
textnorm_list = [15]
res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer)
print([rich_transcription_postprocess(i) for i in res])
runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -4,3 +4,4 @@
from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
from .punc_bin import CT_Transformer_VadRealtime
from .sensevoice_bin import SenseVoiceSmallONNX
runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py
New file
@@ -0,0 +1,145 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import os.path
import librosa
import numpy as np
from pathlib import Path
from typing import List, Union, Tuple
from .utils.utils import (
    CharTokenizer,
    Hypothesis,
    ONNXRuntimeError,
    OrtInferSession,
    TokenIDConverter,
    get_logger,
    read_yaml,
)
from .utils.frontend import WavFrontend
logging = get_logger()
class SenseVoiceSmallONNX:
    """
    Offline ONNX Runtime inference for the SenseVoice-Small ASR model
    (https://github.com/FunAudioLLM/SenseVoice), exported via FunASR.

    Loads ``model.onnx`` (or ``model_quant.onnx``) plus the fbank frontend
    config (``config.yaml``) and CMVN stats (``am.mvn``) from ``model_dir``.
    (The original docstring described Paraformer — a copy/paste error.)
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        device_id: Union[str, int] = "-1",
        plot_timestamp_to: str = "",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):
        # Pick the dynamically-quantized artifact when requested.
        if quantize:
            model_file = os.path.join(model_dir, "model_quant.onnx")
        else:
            model_file = os.path.join(model_dir, "model.onnx")
        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
        config = read_yaml(config_file)

        self.tokenizer = CharTokenizer()
        config["frontend_conf"]["cmvn_file"] = cmvn_file
        self.frontend = WavFrontend(**config["frontend_conf"])
        self.ort_infer = OrtInferSession(
            model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.batch_size = batch_size
        # CTC blank token id.
        self.blank_id = 0

    def __call__(
        self,
        wav_content: Union[str, np.ndarray, List[str]],
        language: List,
        textnorm: List,
        tokenizer=None,
        **kwargs,
    ) -> List:
        """Run recognition; returns one decoded result per input waveform.

        ``language``/``textnorm`` are per-batch query token ids for the
        exported graph. With ``tokenizer`` the CTC ids are detokenized to
        text, otherwise the raw token-id lists are returned.
        """
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            ctc_logits, encoder_out_lens = self.infer(
                feats,
                feats_len,
                np.array(language, dtype=np.int32),
                np.array(textnorm, dtype=np.int32),
            )
            # Back to torch.Tensor for the greedy CTC decode below.
            ctc_logits = torch.from_numpy(ctc_logits).float()
            # Greedy CTC decode; supports batch_size=1 only currently.
            x = ctc_logits[0, : encoder_out_lens[0].item(), :]
            yseq = x.argmax(dim=-1)
            yseq = torch.unique_consecutive(yseq, dim=-1)
            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()
            if tokenizer is not None:
                asr_res.append(tokenizer.tokens2text(token_int))
            else:
                asr_res.append(token_int)
        return asr_res

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Normalize ``wav_content`` into a list of 1-D waveform arrays.

        Accepts a raw numpy waveform, one audio-file path, or a list of
        paths; files are decoded (and resampled to ``fs``) with librosa.
        """

        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """Compute LFR+CMVN fbank features and pad them to a common length.

        Returns ``(feats, feats_len)``: a float32 (batch, max_len, dim)
        array and the unpadded lengths as int32.
        """
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)
        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        """Zero-pad each (len, dim) feature matrix along time to
        ``max_feat_len`` and stack into one float32 batch array."""

        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, "constant", constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        return np.array(feat_res).astype(np.float32)

    def infer(
        self,
        feats: np.ndarray,
        feats_len: np.ndarray,
        language: np.ndarray,
        textnorm: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Run the ORT session and return ``(ctc_logits, encoder_out_lens)``.

        Unpacks the session output so the return value actually matches the
        declared two-tuple annotation (the graph declares exactly these two
        outputs).
        """
        ctc_logits, encoder_out_lens = self.ort_infer([feats, feats_len, language, textnorm])
        return ctc_logits, encoder_out_lens