游雁
2023-03-17 dc5367bbf12ad99a0df242506429f33554ccdea5
rtf benchmark
6个文件已修改
133 ■■■■■ 已修改文件
funasr/export/README.md 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/utils/test_rtf.py 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/utils/test_rtf.sh 109 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/README.md
@@ -50,6 +50,3 @@
```shell
python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch
```
## Acknowledge
1. We acknowledge
funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py
@@ -25,6 +25,7 @@
                 plot_timestamp_to: str = "",
                 pred_bias: int = 1,
                 quantize: bool = False,
                 intra_op_num_threads: int = 1,
                 ):
        if not Path(model_dir).exists():
funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py
@@ -27,6 +27,7 @@
                 plot_timestamp_to: str = "",
                 pred_bias: int = 1,
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 ):
        if not Path(model_dir).exists():
@@ -45,7 +46,7 @@
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id)
        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        self.pred_bias = pred_bias
funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
@@ -147,10 +147,10 @@
class OrtInferSession():
    def __init__(self, model_file, device_id=-1):
    def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
        device_id = str(device_id)
        sess_opt = SessionOptions()
        sess_opt.intra_op_num_threads = 4
        sess_opt.intra_op_num_threads = intra_op_num_threads
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
funasr/runtime/python/utils/test_rtf.py
@@ -8,23 +8,24 @@
parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', type=str, required=True)
parser.add_argument('--backend', type=str, default='onnx', help='["onnx", "torch"]')
parser.add_argument('--wav_file', type=int, default=0, help='amp fallback number')
parser.add_argument('--wav_file', type=str, default=None, help='amp fallback number')
parser.add_argument('--quantize', type=bool, default=False, help='quantized model')
parser.add_argument('--intra_op_num_threads', type=int, default=1, help='intra_op_num_threads for onnx')
args = parser.parse_args()
from torch_paraformer import Paraformer
if args.backend == "onnxruntime":
    from rapid_paraformer import Paraformer
from funasr.runtime.python.libtorch.torch_paraformer import Paraformer
if args.backend == "onnx":
    from funasr.runtime.python.onnxruntime.rapid_paraformer import Paraformer
    
model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize)
model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
wav_file_f = open(args.wav_file, 'r')
wav_files = wav_file_f.readlines()
# warm-up
total = 0.0
num = 100
num = 30
wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
for i in range(num):
    beg_time = time.time()
funasr/runtime/python/utils/test_rtf.sh
@@ -1,74 +1,95 @@
nj=64
#:<<!
backend=libtorch
model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/libtorch"
tag=${backend}_fp32
quantize='False'
!
:<<!
backend=libtorch
model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/libtorch_fb20"
tag=${backend}_amp_fb20
quantize='True'
!
:<<!
backend=onnxruntime
model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/onnx"
tag=${backend}_fp32
quantize='False'
!
:<<!
backend=onnxruntime
model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/onnx_dynamic"
tag=${backend}_fp32
quantize='True'
!
stage=0
scp=/nfs/haoneng.lhn/funasr_data/aishell-1/data/test/wav.scp
local_scp_dir=/nfs/zhifu.gzf/data_debug/test/${tag}/split$nj
logs_outputs_dir=/nfs/zhifu.gzf/data_debug/test/${tag}/split$nj
split_scps_tool=../../../egs/aishell/transformer/utils/split_scp.pl
rtf_tool=test_rtf.py
mkdir -p ${local_scp_dir}
echo ${local_scp_dir}
##:<<!
#backend=libtorch
#model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/libtorch"
#tag=${backend}_fp32
#quantize='False'
#!
#
#:<<!
#backend=libtorch
#model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/libtorch_fb20"
#tag=${backend}_amp_fb20
#quantize='True'
#!
#
#:<<!
#backend=onnxruntime
#model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/onnx"
#tag=${backend}_fp32
#quantize='False'
#!
#
#:<<!
#backend=onnxruntime
#model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/onnx_dynamic"
#tag=${backend}_fp32
#quantize='True'
#!
#:<<!
model_name="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
export_root="/nfs/zhifu.gzf/export"
backend=onnx
quantize='True'
tag=${model_name}/${backend}_${quantize}
!
mkdir -p ${logs_outputs_dir}
echo ${logs_outputs_dir}
if [ $stage == 0 ];then
  if [ $quantize == 'True' ];then
    python -m funasr.export.export_model --model-name ${model_name} --export-dir ${export_root} --type ${backend} --quantize --audio_in ${scp}
  else
    python -m funasr.export.export_model --model-name ${model_name} --export-dir ${export_root} --type ${backend}
  fi
fi
model_dir=${export_root}/${model_name}
split_scps=""
for JOB in $(seq ${nj}); do
    split_scps="$split_scps $local_scp_dir/wav.$JOB.scp"
    split_scps="$split_scps $logs_outputs_dir/wav.$JOB.scp"
done
perl ../../../egs/aishell/transformer/utils/split_scp.pl $scp ${split_scps}
perl ${split_scps_tool} $scp ${split_scps}
for JOB in $(seq ${nj}); do
  {
    core_id=`expr $JOB - 1`
    taskset -c ${core_id} python ${rtf_tool} --backend ${backend} --model_dir ${model_dir} --wav_file ${local_scp_dir}/wav.$JOB.scp --quantize ${quantize} &> ${local_scp_dir}/log.$JOB.txt
    taskset -c ${core_id} python ${rtf_tool} --backend ${backend} --model_dir ${model_dir} --wav_file ${logs_outputs_dir}/wav.$JOB.scp --quantize ${quantize} &> ${logs_outputs_dir}/log.$JOB.txt
  }&
done
wait
rm -rf ${local_scp_dir}/total_time_comput.txt
rm -rf ${local_scp_dir}/total_time_wav.txt
rm -rf ${local_scp_dir}/total_rtf.txt
rm -rf ${logs_outputs_dir}/total_time_comput.txt
rm -rf ${logs_outputs_dir}/total_time_wav.txt
rm -rf ${logs_outputs_dir}/total_rtf.txt
for JOB in $(seq ${nj}); do
  {
    cat ${local_scp_dir}/log.$JOB.txt | grep "total_time_comput" | awk -F ' '  '{print $2}' >> ${local_scp_dir}/total_time_comput.txt
    cat ${local_scp_dir}/log.$JOB.txt | grep "total_time_wav" | awk -F ' '  '{print $2}' >> ${local_scp_dir}/total_time_wav.txt
    cat ${local_scp_dir}/log.$JOB.txt | grep "total_rtf" | awk -F ' '  '{print $2}' >> ${local_scp_dir}/total_rtf.txt
    cat ${logs_outputs_dir}/log.$JOB.txt | grep "total_time_comput" | awk -F ' '  '{print $2}' >> ${logs_outputs_dir}/total_time_comput.txt
    cat ${logs_outputs_dir}/log.$JOB.txt | grep "total_time_wav" | awk -F ' '  '{print $2}' >> ${logs_outputs_dir}/total_time_wav.txt
    cat ${logs_outputs_dir}/log.$JOB.txt | grep "total_rtf" | awk -F ' '  '{print $2}' >> ${logs_outputs_dir}/total_rtf.txt
  }
done
total_time_comput=`cat ${local_scp_dir}/total_time_comput.txt | awk 'BEGIN {max = 0} {if ($1+0>max+0) max=$1 fi} END {print max}'`
total_time_wav=`cat ${local_scp_dir}/total_time_wav.txt | awk '{sum +=$1};END {print sum}'`
total_time_comput=`cat ${logs_outputs_dir}/total_time_comput.txt | awk 'BEGIN {max = 0} {if ($1+0>max+0) max=$1 fi} END {print max}'`
total_time_wav=`cat ${logs_outputs_dir}/total_time_wav.txt | awk '{sum +=$1};END {print sum}'`
rtf=`awk 'BEGIN{printf "%.5f\n",'$total_time_comput'/'$total_time_wav'}'`
speed=`awk 'BEGIN{printf "%.2f\n",1/'$rtf'}'`