From 3c173e6187ea49520b1e946c9d5de66c824f0864 Mon Sep 17 00:00:00 2001
From: 语帆 <yf352572@alibaba-inc.com>
Date: 星期三, 21 二月 2024 17:02:15 +0800
Subject: [PATCH] Merge remote-tracking branch 'origin/main' into dev_yf
---
examples/aishell/conformer/run.sh | 22 +++++++++++++---------
1 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/examples/aishell/conformer/run.sh b/examples/aishell/conformer/run.sh
index ff99f9e..f5d993a 100755
--- a/examples/aishell/conformer/run.sh
+++ b/examples/aishell/conformer/run.sh
@@ -5,7 +5,7 @@
# general configuration
feats_dir="../DATA" #feature output dictionary
-exp_dir="."
+exp_dir=`pwd`
lang=zh
token_type=char
stage=0
@@ -14,10 +14,10 @@
# feature configuration
nj=32
-inference_device="cuda" #"cpu"
+inference_device="cuda" #"cpu", "cuda:0", "cuda:1"
inference_checkpoint="model.pt"
inference_scp="wav.scp"
-inference_batch_size=32
+inference_batch_size=1
# data
raw_data=../raw_data
@@ -109,6 +109,7 @@
log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}"
echo "log_file: ${log_file}"
+ export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
torchrun \
--nnodes 1 \
@@ -129,7 +130,7 @@
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "stage 5: Inference"
- if ${inference_device} == "cuda"; then
+ if [ ${inference_device} == "cuda" ]; then
nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
else
inference_batch_size=1
@@ -141,8 +142,9 @@
for dset in ${test_sets}; do
- inference_dir="${exp_dir}/exp/${model_dir}/${inference_checkpoint}/${dset}"
+ inference_dir="${exp_dir}/exp/${model_dir}/inference-${inference_checkpoint}/${dset}"
_logdir="${inference_dir}/logdir"
+ echo "inference_dir: ${inference_dir}"
mkdir -p "${_logdir}"
data_dir="${feats_dir}/data/${dset}"
@@ -154,7 +156,7 @@
done
utils/split_scp.pl "${key_file}" ${split_scps}
- gpuid_list_array=(${gpuid_list//,/ })
+ gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
for JOB in $(seq ${nj}); do
{
id=$((JOB-1))
@@ -170,7 +172,9 @@
++input="${_logdir}/keys.${JOB}.scp" \
++output_dir="${inference_dir}/${JOB}" \
++device="${inference_device}" \
- ++batch_size="${inference_batch_size}"
+ ++ncpu=1 \
+ ++disable_log=true \
+ ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
}&
done
@@ -186,8 +190,8 @@
done
echo "Computing WER ..."
- cp ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc
- cp ${data_dir}/text ${inference_dir}/1best_recog/text.ref
+ python utils/postprocess_text_zh.py ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc
+ python utils/postprocess_text_zh.py ${data_dir}/text ${inference_dir}/1best_recog/text.ref
python utils/compute_wer.py ${inference_dir}/1best_recog/text.ref ${inference_dir}/1best_recog/text.proc ${inference_dir}/1best_recog/text.cer
tail -n 3 ${inference_dir}/1best_recog/text.cer
done
--
Gitblit v1.9.1