From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py |  173 +++++++++++++++++++++++++++++++++------------------------
 1 files changed, 100 insertions(+), 73 deletions(-)

diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
index 6925960..3f63ea0 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -8,74 +8,78 @@
 import librosa
 import numpy as np
 
-from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
-                          OrtInferSession, TokenIDConverter, get_logger,
-                          read_yaml)
+from .utils.utils import (
+    CharTokenizer,
+    Hypothesis,
+    ONNXRuntimeError,
+    OrtInferSession,
+    TokenIDConverter,
+    get_logger,
+    read_yaml,
+)
 from .utils.postprocess_utils import sentence_postprocess
 from .utils.frontend import WavFrontendOnline, SinusoidalPositionEncoderOnline
 
 logging = get_logger()
 
 
-class Paraformer():
-    def __init__(self, model_dir: Union[str, Path] = None,
-                 batch_size: int = 1,
-                 chunk_size: List = [5, 10, 5],
-                 device_id: Union[str, int] = "-1",
-                 quantize: bool = False,
-                 intra_op_num_threads: int = 4,
-                 cache_dir: str = None
-                 ):
+class Paraformer:
+    def __init__(
+        self,
+        model_dir: Union[str, Path] = None,
+        batch_size: int = 1,
+        chunk_size: List = [5, 10, 5],
+        device_id: Union[str, int] = "-1",
+        quantize: bool = False,
+        intra_op_num_threads: int = 4,
+        cache_dir: str = None,
+        **kwargs,
+    ):
 
         if not Path(model_dir).exists():
             try:
                 from modelscope.hub.snapshot_download import snapshot_download
             except:
-                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
-                      "\npip3 install -U modelscope\n" \
-                      "For the users in China, you could install with the command:\n" \
-                      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
             try:
                 model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
             except:
-                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
-        
-        encoder_model_file = os.path.join(model_dir, 'model.onnx')
-        decoder_model_file = os.path.join(model_dir, 'decoder.onnx')
+                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
+                    model_dir
+                )
+
+        encoder_model_file = os.path.join(model_dir, "model.onnx")
+        decoder_model_file = os.path.join(model_dir, "decoder.onnx")
         if quantize:
-            encoder_model_file = os.path.join(model_dir, 'model_quant.onnx')
-            decoder_model_file = os.path.join(model_dir, 'decoder_quant.onnx')
+            encoder_model_file = os.path.join(model_dir, "model_quant.onnx")
+            decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
         if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
-            print(".onnx is not exist, begin to export onnx")
+            print(".onnx does not exist, begin to export onnx")
             try:
                 from funasr import AutoModel
             except:
-                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
-                      "\npip3 install -U funasr\n" \
-                      "For the users in China, you could install with the command:\n" \
-                      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
             model = AutoModel(model=model_dir)
-            model_dir = model.export(type="onnx", quantize=quantize)
+            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
 
-        config_file = os.path.join(model_dir, 'config.yaml')
-        cmvn_file = os.path.join(model_dir, 'am.mvn')
+        config_file = os.path.join(model_dir, "config.yaml")
+        cmvn_file = os.path.join(model_dir, "am.mvn")
         config = read_yaml(config_file)
-        token_list = os.path.join(model_dir, 'tokens.json')
-        with open(token_list, 'r', encoding='utf-8') as f:
+        token_list = os.path.join(model_dir, "tokens.json")
+        with open(token_list, "r", encoding="utf-8") as f:
             token_list = json.load(f)
 
         self.converter = TokenIDConverter(token_list)
         self.tokenizer = CharTokenizer()
-        self.frontend = WavFrontendOnline(
-            cmvn_file=cmvn_file,
-            **config['frontend_conf']
-        )
+        self.frontend = WavFrontendOnline(cmvn_file=cmvn_file, **config["frontend_conf"])
         self.pe = SinusoidalPositionEncoderOnline()
-        self.ort_encoder_infer = OrtInferSession(encoder_model_file, device_id,
-                                                 intra_op_num_threads=intra_op_num_threads)
-        self.ort_decoder_infer = OrtInferSession(decoder_model_file, device_id,
-                                                 intra_op_num_threads=intra_op_num_threads)
+        self.ort_encoder_infer = OrtInferSession(
+            encoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
+        )
+        self.ort_decoder_infer = OrtInferSession(
+            decoder_model_file, device_id, intra_op_num_threads=intra_op_num_threads
+        )
         self.batch_size = batch_size
         self.chunk_size = chunk_size
         self.encoder_output_size = config["encoder_conf"]["output_size"]
@@ -94,7 +98,9 @@
         cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
         cache["chunk_size"] = self.chunk_size
         cache["last_chunk"] = False
-        cache["feats"] = np.zeros((batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)).astype(np.float32)
+        cache["feats"] = np.zeros(
+            (batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)
+        ).astype(np.float32)
         cache["decoder_fsmn"] = []
         for i in range(self.fsmn_layer):
             fsmn_cache = np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
@@ -107,31 +113,31 @@
         # process last chunk
         overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
         if cache["is_final"]:
-            cache["feats"] = overlap_feats[:, -self.chunk_size[0]:, :]
+            cache["feats"] = overlap_feats[:, -self.chunk_size[0] :, :]
             if not cache["last_chunk"]:
-               padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
-               overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
+                padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
+                overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
         else:
-            cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]):, :]
+            cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]) :, :]
         return overlap_feats
 
     def __call__(self, audio_in: np.ndarray, **kwargs):
         waveforms = np.expand_dims(audio_in, axis=0)
-        param_dict = kwargs.get('param_dict', dict())
-        is_final = param_dict.get('is_final', False)
-        cache = param_dict.get('cache', dict())
+        param_dict = kwargs.get("param_dict", dict())
+        is_final = param_dict.get("is_final", False)
+        cache = param_dict.get("cache", dict())
         asr_res = []
-        
+
         if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
             cache["last_chunk"] = True
             feats = cache["feats"]
             feats_len = np.array([feats.shape[1]]).astype(np.int32)
             asr_res = self.infer(feats, feats_len, cache)
             return asr_res
-            
+
         feats, feats_len = self.extract_feat(waveforms, is_final)
         if feats.shape[1] != 0:
-            feats *= self.encoder_output_size ** 0.5
+            feats *= self.encoder_output_size**0.5
             cache = self.prepare_cache(cache)
             cache["is_final"] = is_final
 
@@ -144,16 +150,19 @@
                     feats = self.add_overlap_chunk(feats, cache)
                 else:
                     # first chunk
-                    feats_chunk1 = self.add_overlap_chunk(feats[:, :self.chunk_size[1], :], cache)
+                    feats_chunk1 = self.add_overlap_chunk(feats[:, : self.chunk_size[1], :], cache)
                     feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
                     asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)
 
                     # last chunk
                     cache["last_chunk"] = True
-                    feats_chunk2 = self.add_overlap_chunk(feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]):, :], cache)
+                    feats_chunk2 = self.add_overlap_chunk(
+                        feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]) :, :],
+                        cache,
+                    )
                     feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
                     asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)
-                    
+
                     asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
                     res = {}
                     for pred in asr_res_chunk:
@@ -187,18 +196,36 @@
             dec_input.extend(cache["decoder_fsmn"])
             dec_output = self.ort_decoder_infer(dec_input)
             logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
-            cache["decoder_fsmn"] = [item[:, :, -self.fsmn_lorder:] for item in cache["decoder_fsmn"]]
+            cache["decoder_fsmn"] = [
+                item[:, :, -self.fsmn_lorder :] for item in cache["decoder_fsmn"]
+            ]
 
             preds = self.decode(logits, acoustic_embeds_len)
             for pred in preds:
                 pred = sentence_postprocess(pred)
-                asr_res.append({'preds': pred})
+                asr_res.append({"preds": pred})
 
         return asr_res
 
-    def load_data(self,
-                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+        
+        def convert_to_wav(input_path, output_path):
+            from pydub import AudioSegment
+            try:
+                audio = AudioSegment.from_mp3(input_path)
+                audio.export(output_path, format="wav")
+                print("闊抽鏂囦欢涓簃p3鏍煎紡锛屽凡杞崲涓簑av鏍煎紡")
+                
+            except Exception as e:
+                print(f"杞崲澶辫触:{e}")
+
         def load_wav(path: str) -> np.ndarray:
+            if not path.lower().endswith('.wav'):
+                import os
+                input_path = path
+                path = os.path.splitext(path)[0]+'.wav'
+                convert_to_wav(input_path,path) #灏唌p3鏍煎紡杞崲鎴恮av鏍煎紡
+
             waveform, _ = librosa.load(path, sr=fs)
             return waveform
 
@@ -211,12 +238,11 @@
         if isinstance(wav_content, list):
             return [load_wav(path) for path in wav_content]
 
-        raise TypeError(
-            f'The type of {wav_content} is not in [str, np.ndarray, list]')
+        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
 
-    def extract_feat(self,
-                     waveforms: np.ndarray, is_final: bool = False
-                     ) -> Tuple[np.ndarray, np.ndarray]:
+    def extract_feat(
+        self, waveforms: np.ndarray, is_final: bool = False
+    ) -> Tuple[np.ndarray, np.ndarray]:
         waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
         for idx, waveform in enumerate(waveforms):
             waveforms_lens[idx] = waveform.shape[-1]
@@ -225,12 +251,12 @@
         return feats.astype(np.float32), feats_len.astype(np.int32)
 
     def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
-        return [self.decode_one(am_score, token_num)
-                for am_score, token_num in zip(am_scores, token_nums)]
+        return [
+            self.decode_one(am_score, token_num)
+            for am_score, token_num in zip(am_scores, token_nums)
+        ]
 
-    def decode_one(self,
-                   am_score: np.ndarray,
-                   valid_token_num: int) -> List[str]:
+    def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
         yseq = am_score.argmax(axis=-1)
         score = am_score.max(axis=-1)
         score = np.sum(score, axis=-1)
@@ -260,15 +286,15 @@
         list_frames = []
         cache_alphas = []
         cache_hiddens = []
-        alphas[:, :self.chunk_size[0]] = 0.0
-        alphas[:, sum(self.chunk_size[:2]):] = 0.0
+        alphas[:, : self.chunk_size[0]] = 0.0
+        alphas[:, sum(self.chunk_size[:2]) :] = 0.0
         if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
             hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
             alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
         if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
             tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
             tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
-            tail_alphas =np.tile(tail_alphas, (batch_size, 1))
+            tail_alphas = np.tile(tail_alphas, (batch_size, 1))
             hidden = np.concatenate((hidden, tail_hidden), axis=1)
             alphas = np.concatenate((alphas, tail_alphas), axis=1)
 
@@ -316,5 +342,6 @@
         cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
         cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)
 
-        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)
-
+        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(
+            np.int32
+        )

--
Gitblit v1.9.1