kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -3,80 +3,83 @@
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import json
import copy
import librosa
import numpy as np
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                          OrtInferSession, TokenIDConverter, get_logger,
                          read_yaml)
from .utils.utils import (
    CharTokenizer,
    Hypothesis,
    ONNXRuntimeError,
    OrtInferSession,
    TokenIDConverter,
    get_logger,
    read_yaml,
)
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontendOnline, SinusoidalPositionEncoderOnline
logging = get_logger()
class Paraformer():
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 chunk_size: List = [5, 10, 5],
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
class Paraformer:
    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        chunk_size: List = [5, 10, 5],
        device_id: Union[str, int] = "-1",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except:
                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
                      "\npip3 install -U modelscope\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
        encoder_model_file = os.path.join(model_dir, 'model.onnx')
        decoder_model_file = os.path.join(model_dir, 'decoder.onnx')
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                    model_dir
                )
        encoder_model_file = os.path.join(model_dir, "model.onnx")
        decoder_model_file = os.path.join(model_dir, "decoder.onnx")
        if quantize:
            encoder_model_file = os.path.join(model_dir, 'model_quant.onnx')
            decoder_model_file = os.path.join(model_dir, 'decoder_quant.onnx')
            encoder_model_file = os.path.join(model_dir, "model_quant.onnx")
            decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
        if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
            print(".onnx is not exist, begin to export onnx")
            print(".onnx does not exist, begin to export onnx")
            try:
                from funasr.export.export_model import ModelExport
                from funasr import AutoModel
            except:
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
                      "\npip3 install -U funasr\n" \
                      "For the users in China, you could install with the command:\n" \
                      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, "tokens.json")
        with open(token_list, "r", encoding="utf-8") as f:
            token_list = json.load(f)
        self.converter = TokenIDConverter(config['token_list'])
        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontendOnline(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.frontend = WavFrontendOnline(cmvn_file=cmvn_file, **config["frontend_conf"])
        self.pe = SinusoidalPositionEncoderOnline()
        self.ort_encoder_infer = OrtInferSession(encoder_model_file, device_id,
                                                 intra_op_num_threads=intra_op_num_threads)
        self.ort_decoder_infer = OrtInferSession(decoder_model_file, device_id,
                                                 intra_op_num_threads=intra_op_num_threads)
        self.ort_encoder_infer = OrtInferSession(
            encoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.ort_decoder_infer = OrtInferSession(
            decoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.encoder_output_size = config["encoder_conf"]["output_size"]
@@ -95,7 +98,9 @@
        cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
        cache["chunk_size"] = self.chunk_size
        cache["last_chunk"] = False
        cache["feats"] = np.zeros((batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)).astype(np.float32)
        cache["feats"] = np.zeros(
            (batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)
        ).astype(np.float32)
        cache["decoder_fsmn"] = []
        for i in range(self.fsmn_layer):
            fsmn_cache = np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
@@ -108,31 +113,31 @@
        # process last chunk
        overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
        if cache["is_final"]:
            cache["feats"] = overlap_feats[:, -self.chunk_size[0]:, :]
            cache["feats"] = overlap_feats[:, -self.chunk_size[0] :, :]
            if not cache["last_chunk"]:
               padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
               overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
                padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
                overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
        else:
            cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]):, :]
            cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]) :, :]
        return overlap_feats
    def __call__(self, audio_in: np.ndarray, **kwargs):
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        cache = param_dict.get('cache', dict())
        param_dict = kwargs.get("param_dict", dict())
        is_final = param_dict.get("is_final", False)
        cache = param_dict.get("cache", dict())
        asr_res = []
        if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
            cache["last_chunk"] = True
            feats = cache["feats"]
            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)
            return asr_res
        feats, feats_len = self.extract_feat(waveforms, is_final)
        if feats.shape[1] != 0:
            feats *= self.encoder_output_size ** 0.5
            feats *= self.encoder_output_size**0.5
            cache = self.prepare_cache(cache)
            cache["is_final"] = is_final
@@ -145,16 +150,19 @@
                    feats = self.add_overlap_chunk(feats, cache)
                else:
                    # first chunk
                    feats_chunk1 = self.add_overlap_chunk(feats[:, :self.chunk_size[1], :], cache)
                    feats_chunk1 = self.add_overlap_chunk(feats[:, : self.chunk_size[1], :], cache)
                    feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
                    asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)
                    # last chunk
                    cache["last_chunk"] = True
                    feats_chunk2 = self.add_overlap_chunk(feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]):, :], cache)
                    feats_chunk2 = self.add_overlap_chunk(
                        feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]) :, :],
                        cache,
                    )
                    feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
                    asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)
                    asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
                    res = {}
                    for pred in asr_res_chunk:
@@ -188,18 +196,36 @@
            dec_input.extend(cache["decoder_fsmn"])
            dec_output = self.ort_decoder_infer(dec_input)
            logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
            cache["decoder_fsmn"] = [item[:, :, -self.fsmn_lorder:] for item in cache["decoder_fsmn"]]
            cache["decoder_fsmn"] = [
                item[:, :, -self.fsmn_lorder :] for item in cache["decoder_fsmn"]
            ]
            preds = self.decode(logits, acoustic_embeds_len)
            for pred in preds:
                pred = sentence_postprocess(pred)
                asr_res.append({'preds': pred})
                asr_res.append({"preds": pred})
        return asr_res
    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def convert_to_wav(input_path, output_path):
            from pydub import AudioSegment
            try:
                audio = AudioSegment.from_mp3(input_path)
                audio.export(output_path, format="wav")
                print("音频文件为mp3格式,已转换为wav格式")
            except Exception as e:
                print(f"转换失败:{e}")
        def load_wav(path: str) -> np.ndarray:
            if not path.lower().endswith('.wav'):
                import os
                input_path = path
                path = os.path.splitext(path)[0]+'.wav'
                convert_to_wav(input_path,path) #将mp3格式转换成wav格式
            waveform, _ = librosa.load(path, sr=fs)
            return waveform
@@ -212,12 +238,11 @@
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')
        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
    def extract_feat(self,
                     waveforms: np.ndarray, is_final: bool = False
                     ) -> Tuple[np.ndarray, np.ndarray]:
    def extract_feat(
        self, waveforms: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
        for idx, waveform in enumerate(waveforms):
            waveforms_lens[idx] = waveform.shape[-1]
@@ -226,12 +251,12 @@
        return feats.astype(np.float32), feats_len.astype(np.int32)
    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]
        return [
            self.decode_one(am_score, token_num)
            for am_score, token_num in zip(am_scores, token_nums)
        ]
    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
    def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)
@@ -261,15 +286,15 @@
        list_frames = []
        cache_alphas = []
        cache_hiddens = []
        alphas[:, :self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]):] = 0.0
        alphas[:, : self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]) :] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
            alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
            tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
            tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
            tail_alphas =np.tile(tail_alphas, (batch_size, 1))
            tail_alphas = np.tile(tail_alphas, (batch_size, 1))
            hidden = np.concatenate((hidden, tail_hidden), axis=1)
            alphas = np.concatenate((alphas, tail_alphas), axis=1)
@@ -317,5 +342,6 @@
        cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)
        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)
        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(
            np.int32
        )