From dc5367bbf12ad99a0df242506429f33554ccdea5 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 三月 2023 14:04:49 +0800
Subject: [PATCH] rtf benchmark

---
 funasr/export/README.md                                               |    3 -
 funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py     |    1 
 funasr/runtime/python/utils/test_rtf.sh                               |  109 +++++++++++++++++++++--------------
 funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py |    3 
 funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py     |    4 
 funasr/runtime/python/utils/test_rtf.py                               |   13 ++--
 6 files changed, 77 insertions(+), 56 deletions(-)

diff --git a/funasr/export/README.md b/funasr/export/README.md
index 2c9be4f..9d02b53 100644
--- a/funasr/export/README.md
+++ b/funasr/export/README.md
@@ -50,6 +50,3 @@
 ```shell
 python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch
 ```
-
-## Acknowledge
-1. We acknowledge
diff --git a/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py b/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py
index ce975f0..d47135a 100644
--- a/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py
+++ b/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py
@@ -25,6 +25,7 @@
                  plot_timestamp_to: str = "",
                  pred_bias: int = 1,
                  quantize: bool = False,
+                 intra_op_num_threads: int = 1,
                  ):
 
         if not Path(model_dir).exists():
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py
index 422cb67..61c85ec 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py
+++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/paraformer_onnx.py
@@ -27,6 +27,7 @@
                  plot_timestamp_to: str = "",
                  pred_bias: int = 1,
                  quantize: bool = False,
+                 intra_op_num_threads: int = 4,
                  ):
 
         if not Path(model_dir).exists():
@@ -45,7 +46,7 @@
             cmvn_file=cmvn_file,
             **config['frontend_conf']
         )
-        self.ort_infer = OrtInferSession(model_file, device_id)
+        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
         self.batch_size = batch_size
         self.plot_timestamp_to = plot_timestamp_to
         self.pred_bias = pred_bias
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
index 392fe6b..ec907c0 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
@@ -147,10 +147,10 @@
 
 
 class OrtInferSession():
-    def __init__(self, model_file, device_id=-1):
+    def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
         device_id = str(device_id)
         sess_opt = SessionOptions()
-        sess_opt.intra_op_num_threads = 4
+        sess_opt.intra_op_num_threads = intra_op_num_threads
         sess_opt.log_severity_level = 4
         sess_opt.enable_cpu_mem_arena = False
         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
diff --git a/funasr/runtime/python/utils/test_rtf.py b/funasr/runtime/python/utils/test_rtf.py
index 46f204d..fd26fad 100644
--- a/funasr/runtime/python/utils/test_rtf.py
+++ b/funasr/runtime/python/utils/test_rtf.py
@@ -8,23 +8,24 @@
 parser = argparse.ArgumentParser()
 parser.add_argument('--model_dir', type=str, required=True)
 parser.add_argument('--backend', type=str, default='onnx', help='["onnx", "torch"]')
-parser.add_argument('--wav_file', type=int, default=0, help='amp fallback number')
+parser.add_argument('--wav_file', type=str, default=None, help='path to wav.scp file (one "id path" entry per line)')
 parser.add_argument('--quantize', type=bool, default=False, help='quantized model')
+parser.add_argument('--intra_op_num_threads', type=int, default=1, help='intra_op_num_threads for onnx')
 args = parser.parse_args()
 
 
-from torch_paraformer import Paraformer
-if args.backend == "onnxruntime":
-	from rapid_paraformer import Paraformer
+from funasr.runtime.python.libtorch.torch_paraformer import Paraformer
+if args.backend == "onnx":
+	from funasr.runtime.python.onnxruntime.rapid_paraformer import Paraformer
 	
-model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize)
+model = Paraformer(args.model_dir, batch_size=1, quantize=args.quantize, intra_op_num_threads=args.intra_op_num_threads)
 
 wav_file_f = open(args.wav_file, 'r')
 wav_files = wav_file_f.readlines()
 
 # warm-up
 total = 0.0
-num = 100
+num = 30
 wav_path = wav_files[0].split("\t")[1].strip() if "\t" in wav_files[0] else wav_files[0].split(" ")[1].strip()
 for i in range(num):
 	beg_time = time.time()
diff --git a/funasr/runtime/python/utils/test_rtf.sh b/funasr/runtime/python/utils/test_rtf.sh
index cada080..7399c88 100644
--- a/funasr/runtime/python/utils/test_rtf.sh
+++ b/funasr/runtime/python/utils/test_rtf.sh
@@ -1,74 +1,95 @@
 
 nj=64
-
-#:<<!
-backend=libtorch
-model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/libtorch"
-tag=${backend}_fp32
-quantize='False'
-!
-
-:<<!
-backend=libtorch
-model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/libtorch_fb20"
-tag=${backend}_amp_fb20
-quantize='True'
-!
-
-:<<!
-backend=onnxruntime
-model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/onnx"
-tag=${backend}_fp32
-quantize='False'
-!
-
-:<<!
-backend=onnxruntime
-model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/onnx_dynamic"
-tag=${backend}_fp32
-quantize='True'
-!
-
+stage=0
 scp=/nfs/haoneng.lhn/funasr_data/aishell-1/data/test/wav.scp
-local_scp_dir=/nfs/zhifu.gzf/data_debug/test/${tag}/split$nj
-
+logs_outputs_dir=/nfs/zhifu.gzf/data_debug/test/${tag}/split$nj
+split_scps_tool=../../../egs/aishell/transformer/utils/split_scp.pl
 rtf_tool=test_rtf.py
 
-mkdir -p ${local_scp_dir}
-echo ${local_scp_dir}
+##:<<!
+#backend=libtorch
+#model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/libtorch"
+#tag=${backend}_fp32
+#quantize='False'
+#!
+#
+#:<<!
+#backend=libtorch
+#model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/libtorch_fb20"
+#tag=${backend}_amp_fb20
+#quantize='True'
+#!
+#
+#:<<!
+#backend=onnxruntime
+#model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/onnx"
+#tag=${backend}_fp32
+#quantize='False'
+#!
+#
+#:<<!
+#backend=onnxruntime
+#model_dir="/nfs/zhifu.gzf/export/damo/amp_int8/onnx_dynamic"
+#tag=${backend}_fp32
+#quantize='True'
+#!
 
+#:<<!
+model_name="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+export_root="/nfs/zhifu.gzf/export"
+backend=onnx
+quantize='True'
+tag=${model_name}/${backend}_${quantize}
+#!
+
+
+mkdir -p ${logs_outputs_dir}
+echo ${logs_outputs_dir}
+
+
+if [ $stage == 0 ];then
+
+  if [ $quantize == 'True' ];then
+    python -m funasr.export.export_model --model-name ${model_name} --export-dir ${export_root} --type ${backend} --quantize --audio_in ${scp}
+  else
+    python -m funasr.export.export_model --model-name ${model_name} --export-dir ${export_root} --type ${backend}
+  fi
+
+fi
+
+model_dir=${export_root}/${model_name}
 split_scps=""
 for JOB in $(seq ${nj}); do
-    split_scps="$split_scps $local_scp_dir/wav.$JOB.scp"
+    split_scps="$split_scps $logs_outputs_dir/wav.$JOB.scp"
 done
 
-perl ../../../egs/aishell/transformer/utils/split_scp.pl $scp ${split_scps}
+perl ${split_scps_tool} $scp ${split_scps}
 
 
 for JOB in $(seq ${nj}); do
   {
     core_id=`expr $JOB - 1`
-    taskset -c ${core_id} python ${rtf_tool} --backend ${backend} --model_dir ${model_dir} --wav_file ${local_scp_dir}/wav.$JOB.scp --quantize ${quantize} &> ${local_scp_dir}/log.$JOB.txt
+    taskset -c ${core_id} python ${rtf_tool} --backend ${backend} --model_dir ${model_dir} --wav_file ${logs_outputs_dir}/wav.$JOB.scp --quantize ${quantize} &> ${logs_outputs_dir}/log.$JOB.txt
   }&
 
 done
 wait
 
 
-rm -rf ${local_scp_dir}/total_time_comput.txt
-rm -rf ${local_scp_dir}/total_time_wav.txt
-rm -rf ${local_scp_dir}/total_rtf.txt
+rm -rf ${logs_outputs_dir}/total_time_comput.txt
+rm -rf ${logs_outputs_dir}/total_time_wav.txt
+rm -rf ${logs_outputs_dir}/total_rtf.txt
 for JOB in $(seq ${nj}); do
   {
-    cat ${local_scp_dir}/log.$JOB.txt | grep "total_time_comput" | awk -F ' '  '{print $2}' >> ${local_scp_dir}/total_time_comput.txt
-    cat ${local_scp_dir}/log.$JOB.txt | grep "total_time_wav" | awk -F ' '  '{print $2}' >> ${local_scp_dir}/total_time_wav.txt
-    cat ${local_scp_dir}/log.$JOB.txt | grep "total_rtf" | awk -F ' '  '{print $2}' >> ${local_scp_dir}/total_rtf.txt
+    cat ${logs_outputs_dir}/log.$JOB.txt | grep "total_time_comput" | awk -F ' '  '{print $2}' >> ${logs_outputs_dir}/total_time_comput.txt
+    cat ${logs_outputs_dir}/log.$JOB.txt | grep "total_time_wav" | awk -F ' '  '{print $2}' >> ${logs_outputs_dir}/total_time_wav.txt
+    cat ${logs_outputs_dir}/log.$JOB.txt | grep "total_rtf" | awk -F ' '  '{print $2}' >> ${logs_outputs_dir}/total_rtf.txt
   }
 
 done
 
-total_time_comput=`cat ${local_scp_dir}/total_time_comput.txt | awk 'BEGIN {max = 0} {if ($1+0>max+0) max=$1 fi} END {print max}'`
-total_time_wav=`cat ${local_scp_dir}/total_time_wav.txt | awk '{sum +=$1};END {print sum}'`
+total_time_comput=`cat ${logs_outputs_dir}/total_time_comput.txt | awk 'BEGIN {max = 0} {if ($1+0>max+0) max=$1} END {print max}'`
+total_time_wav=`cat ${logs_outputs_dir}/total_time_wav.txt | awk '{sum +=$1};END {print sum}'`
 rtf=`awk 'BEGIN{printf "%.5f\n",'$total_time_comput'/'$total_time_wav'}'`
 speed=`awk 'BEGIN{printf "%.2f\n",1/'$rtf'}'`
 

--
Gitblit v1.9.1