游雁
2024-12-23 1e5ef6ed9a6f64ecca7b9ef9481519b271f793a3
bug fix
2个文件已修改
56 ■■■■■ 已修改文件
funasr/utils/export_utils.py 43 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/load_utils.py 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/export_utils.py
@@ -1,12 +1,10 @@
import os
import torch
import functools
import onnx
from onnxconverter_common import float16
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore")
def export(
@@ -44,14 +42,13 @@
                print(f"export_dir: {export_dir}")
                _torchscripts(m, path=export_dir, device="cuda")
        elif type=='onnx_fp16':
        elif type == "onnx_fp16":
            assert (
                torch.cuda.is_available()
            ), "Currently onnx_fp16 optimization for FunASR only supports GPU"
            ), "Currently onnx_fp16 optimization for FunASR only supports GPU"
            if hasattr(m, "encoder") and hasattr(m, "decoder"):
                _onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True)
                _onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True)
    return export_dir
@@ -73,7 +70,6 @@
    else:
        dummy_input = tuple([input.to(device) for input in dummy_input])
    verbose = kwargs.get("verbose", False)
    if isinstance(model.export_name, str):
@@ -94,8 +90,13 @@
    )
    if quantize:
        from onnxruntime.quantization import QuantType, quantize_dynamic
        import onnx
        try:
            from onnxruntime.quantization import QuantType, quantize_dynamic
            import onnx
        except:
            raise RuntimeError(
                "You are quantizing the onnx model, please install onnxruntime first. via \n`pip install onnx`\n`pip install onnxruntime`."
            )
        quant_model_path = model_path.replace(".onnx", "_quant.onnx")
        onnx_model = onnx.load(model_path)
@@ -117,19 +118,21 @@
def _torchscripts(model, path, device="cuda"):
    dummy_input = model.export_dummy_inputs()
    if device == "cuda":
        model = model.cuda()
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.cuda()
        else:
            dummy_input = tuple([i.cuda() for i in dummy_input])
    model_script = torch.jit.trace(model, dummy_input)
    if isinstance(model.export_name, str):
        model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript")))
    else:
        model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript")))
        model_script.save(
            os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript"))
        )
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
@@ -225,7 +228,6 @@
    model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
def _onnx_opt_for_encdec(model, path, enable_fp16):
    # Get input data
@@ -267,16 +269,19 @@
            input_names=model.export_input_names(),
            output_names=model.export_output_names(),
            dynamic_axes=model.export_dynamic_axes(),
        )
        )
    # fp32 to fp16
    fp16_model_path = f"{path}/{model.export_name}_hook_fp16.onnx"
    print("*" * 50)
    print(f"[_onnx_opt_for_encdec(fp16)]: {fp16_model_path}\n\n")
    if os.path.exists(fp32_model_path) and not os.path.exists(fp16_model_path):
        try:
            from onnxconverter_common import float16
        except:
            raise RuntimeError(
                "You are converting the onnx model to fp16, please install onnxconverter-common first. via `pip install onnxconverter-common`."
            )
        fp32_onnx_model = onnx.load(fp32_model_path)
        fp16_onnx_model = float16.convert_float_to_float16(fp32_onnx_model, keep_io_types=True)
        onnx.save(
            fp16_onnx_model, fp16_model_path
        )
        onnx.save(fp16_onnx_model, fp16_model_path)
funasr/utils/load_utils.py
@@ -10,7 +10,6 @@
import time
import logging
from torch.nn.utils.rnn import pad_sequence
from pydub import AudioSegment
try:
    from funasr.download.file import download_from_url
@@ -19,6 +18,11 @@
import pdb
import subprocess
from subprocess import CalledProcessError, run
try:
    from pydub import AudioSegment
except:
    pass
def is_ffmpeg_installed():
@@ -166,7 +170,12 @@
    byte_data = BytesIO(input)
    # 使用 pydub 加载音频
    audio = AudioSegment.from_file(byte_data)
    try:
        audio = AudioSegment.from_file(byte_data)
    except:
        raise RuntimeError(
            "You are decoding the pcm data, please install pydub first. via `pip install pydub`."
        )
    # 确保采样率为 16000 Hz
    if audio.frame_rate != fs: