游雁
2024-12-23 1e5ef6ed9a6f64ecca7b9ef9481519b271f793a3
bug fix
2个文件已修改
40 ■■■■■ 已修改文件
funasr/utils/export_utils.py 29 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/load_utils.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/export_utils.py
@@ -1,12 +1,10 @@
import os
import torch
import functools
import onnx
from onnxconverter_common import float16
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore")
def export(
@@ -44,8 +42,7 @@
                print(f"export_dir: {export_dir}")
                _torchscripts(m, path=export_dir, device="cuda")
        elif type=='onnx_fp16':
        elif type == "onnx_fp16":
            assert (
                torch.cuda.is_available()
            ), "Currently onnx_fp16 optimization for FunASR only supports GPU"  
@@ -73,7 +70,6 @@
    else:
        dummy_input = tuple([input.to(device) for input in dummy_input])
    verbose = kwargs.get("verbose", False)
    if isinstance(model.export_name, str):
@@ -94,8 +90,13 @@
    )
    if quantize:
        try:
        from onnxruntime.quantization import QuantType, quantize_dynamic
        import onnx
        except:
            raise RuntimeError(
                "You are quantizing the onnx model, please install onnxruntime first. via \n`pip install onnx`\n`pip install onnxruntime`."
            )
        quant_model_path = model_path.replace(".onnx", "_quant.onnx")
        onnx_model = onnx.load(model_path)
@@ -129,7 +130,9 @@
    if isinstance(model.export_name, str):
        model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript")))
    else:
        model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript")))
        model_script.save(
            os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript"))
        )
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
@@ -225,7 +228,6 @@
    model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
def _onnx_opt_for_encdec(model, path, enable_fp16):
    # Get input data
@@ -269,14 +271,17 @@
            dynamic_axes=model.export_dynamic_axes(),
        )    
    # fp32 to fp16
    fp16_model_path = f"{path}/{model.export_name}_hook_fp16.onnx"
    print("*" * 50)
    print(f"[_onnx_opt_for_encdec(fp16)]: {fp16_model_path}\n\n")
    if os.path.exists(fp32_model_path) and not os.path.exists(fp16_model_path):
        try:
            from onnxconverter_common import float16
        except:
            raise RuntimeError(
                "You are converting the onnx model to fp16, please install onnxconverter-common first. via `pip install onnxconverter-common`."
            )
        fp32_onnx_model = onnx.load(fp32_model_path)
        fp16_onnx_model = float16.convert_float_to_float16(fp32_onnx_model, keep_io_types=True)
        onnx.save(
            fp16_onnx_model, fp16_model_path
        )
        onnx.save(fp16_onnx_model, fp16_model_path)
funasr/utils/load_utils.py
@@ -10,7 +10,6 @@
import time
import logging
from torch.nn.utils.rnn import pad_sequence
from pydub import AudioSegment
try:
    from funasr.download.file import download_from_url
@@ -19,6 +18,11 @@
import pdb
import subprocess
from subprocess import CalledProcessError, run
try:
    from pydub import AudioSegment
except:
    pass
def is_ffmpeg_installed():
@@ -166,7 +170,12 @@
    byte_data = BytesIO(input)
    # 使用 pydub 加载音频
    try:
    audio = AudioSegment.from_file(byte_data)
    except:
        raise RuntimeError(
            "You are decoding the pcm data, please install pydub first. via `pip install pydub`."
        )
    # 确保采样率为 16000 Hz
    if audio.frame_rate != fs: