From 1d1ef01b4e23630a99a3be7e9d1dce9550a793e9 Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: 星期四, 11 五月 2023 16:26:24 +0800
Subject: [PATCH] Merge branch 'main' into dev_smohan
---
egs/alimeeting/sa-asr/asr_local.sh | 146 ++++++++----------------------------------------
1 files changed, 24 insertions(+), 122 deletions(-)
diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh
index 389bb01..30401b9 100755
--- a/egs/alimeeting/sa-asr/asr_local.sh
+++ b/egs/alimeeting/sa-asr/asr_local.sh
@@ -107,8 +107,8 @@
# inference_asr_model=valid.acc.best.pth
# inference_asr_model=valid.loss.ave.pth
inference_sa_asr_model=valid.acc_spk.ave.pb
-download_model= # Download a model from Model Zoo and use it for decoding.
-
+infer_with_pretrained_model=false # Use pretrained model for decoding
+download_sa_asr_model= # Download the SA-ASR model from ModelScope and use it for decoding.
# [Task dependent] Set the datadir name created by local/data.sh
train_set= # Name of training set.
valid_set= # Name of validation set used for monitoring/tuning network training.
@@ -203,7 +203,8 @@
# Note that it will overwrite args in inference config.
--inference_lm # Language modle path for decoding (default="${inference_lm}").
--inference_asr_model # ASR model path for decoding (default="${inference_asr_model}").
- --download_model # Download a model from Model Zoo and use it for decoding (default="${download_model}").
+ --infer_with_pretrained_model # Use pretrained model for decoding (default="${infer_with_pretrained_model}").
+ --download_sa_asr_model= # Download the SA-ASR model from ModelScope and use it for decoding(default="${download_sa_asr_model}").
# [Task dependent] Set the datadir name created by local/data.sh
--train_set # Name of training set (required).
@@ -304,6 +305,9 @@
lm_token_type="${token_type}"
fi
+if ${infer_with_pretrained_model}; then
+ skip_train=true
+fi
# Set tag for naming of model directory
if [ -z "${asr_tag}" ]; then
@@ -1220,122 +1224,20 @@
log "Skip the training stages"
fi
+if ${infer_with_pretrained_model}; then
+ log "Use ${download_sa_asr_model} for decoding and evaluation"
+ sa_asr_exp="${expdir}/${download_sa_asr_model}"
+ mkdir -p "${sa_asr_exp}"
+
+
+ python local/download_pretrained_model_from_modelscope.py $download_sa_asr_model ${expdir}
+ inference_sa_asr_model="model.pb"
+ inference_config=${sa_asr_exp}/decoding.yaml
+fi
if ! "${skip_eval}"; then
if [ ${stage} -le 13 ] && [ ${stop_stage} -ge 13 ]; then
- log "Stage 13: Decoding multi-talker ASR: training_dir=${asr_exp}"
-
- if ${gpu_inference}; then
- _cmd="${cuda_cmd}"
- inference_nj=$[${ngpu}*${njob_infer}]
- _ngpu=1
-
- else
- _cmd="${decode_cmd}"
- inference_nj=$inference_nj
- _ngpu=0
- fi
-
- _opts=
- if [ -n "${inference_config}" ]; then
- _opts+="--config ${inference_config} "
- fi
- if "${use_lm}"; then
- if "${use_word_lm}"; then
- _opts+="--word_lm_train_config ${lm_exp}/config.yaml "
- _opts+="--word_lm_file ${lm_exp}/${inference_lm} "
- else
- _opts+="--lm_train_config ${lm_exp}/config.yaml "
- _opts+="--lm_file ${lm_exp}/${inference_lm} "
- fi
- fi
-
- # 2. Generate run.sh
- log "Generate '${asr_exp}/${inference_tag}/run.sh'. You can resume the process from stage 13 using this script"
- mkdir -p "${asr_exp}/${inference_tag}"; echo "${run_args} --stage 13 \"\$@\"; exit \$?" > "${asr_exp}/${inference_tag}/run.sh"; chmod +x "${asr_exp}/${inference_tag}/run.sh"
-
- for dset in ${test_sets}; do
- _data="${data_feats}/${dset}"
- _dir="${asr_exp}/${inference_tag}/${dset}"
- _logdir="${_dir}/logdir"
- mkdir -p "${_logdir}"
-
- _feats_type="$(<${_data}/feats_type)"
- if [ "${_feats_type}" = raw ]; then
- _scp=wav.scp
- if [[ "${audio_format}" == *ark* ]]; then
- _type=kaldi_ark
- else
- _type=sound
- fi
- else
- _scp=feats.scp
- _type=kaldi_ark
- fi
-
- # 1. Split the key file
- key_file=${_data}/${_scp}
- split_scps=""
- _nj=$(min "${inference_nj}" "$(<${key_file} wc -l)")
- echo $_nj
- for n in $(seq "${_nj}"); do
- split_scps+=" ${_logdir}/keys.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
-
- # 2. Submit decoding jobs
- log "Decoding started... log: '${_logdir}/asr_inference.*.log'"
-
- ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
- python -m funasr.bin.asr_inference_launch \
- --batch_size 1 \
- --mc True \
- --nbest 1 \
- --ngpu "${_ngpu}" \
- --njob ${njob_infer} \
- --gpuid_list ${device} \
- --data_path_and_name_and_type "${_data}/${_scp},speech,${_type}" \
- --key_file "${_logdir}"/keys.JOB.scp \
- --asr_train_config "${asr_exp}"/config.yaml \
- --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
- --output_dir "${_logdir}"/output.JOB \
- --mode asr \
- ${_opts}
-
- # 3. Concatenates the output files from each jobs
- for f in token token_int score text; do
- for i in $(seq "${_nj}"); do
- cat "${_logdir}/output.${i}/1best_recog/${f}"
- done | LC_ALL=C sort -k1 >"${_dir}/${f}"
- done
- done
- fi
-
-
- if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then
- log "Stage 14: Scoring multi-talker ASR"
-
- for dset in ${test_sets}; do
- _data="${data_feats}/${dset}"
- _dir="${asr_exp}/${inference_tag}/${dset}"
-
- sed 's/\$//g' ${_data}/text > ${_data}/text_nosrc
- sed 's/\$//g' ${_dir}/text > ${_dir}/text_nosrc
-
- python utils/proce_text.py ${_data}/text_nosrc ${_data}/text.proc
- python utils/proce_text.py ${_dir}/text_nosrc ${_dir}/text.proc
-
- python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
- tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
- cat ${_dir}/text.cer.txt
-
- done
-
- fi
-
- if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then
- log "Stage 15: Decoding SA-ASR (oracle profile): training_dir=${sa_asr_exp}"
+ log "Stage 13: Decoding SA-ASR (oracle profile): training_dir=${sa_asr_exp}"
if ${gpu_inference}; then
_cmd="${cuda_cmd}"
@@ -1426,8 +1328,8 @@
done
fi
- if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then
- log "Stage 16: Scoring SA-ASR (oracle profile)"
+ if [ ${stage} -le 14 ] && [ ${stop_stage} -ge 14 ]; then
+ log "Stage 14: Scoring SA-ASR (oracle profile)"
for dset in ${test_sets}; do
_data="${data_feats}/${dset}"
@@ -1454,8 +1356,8 @@
fi
- if [ ${stage} -le 17 ] && [ ${stop_stage} -ge 17 ]; then
- log "Stage 17: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
+ if [ ${stage} -le 15 ] && [ ${stop_stage} -ge 15 ]; then
+ log "Stage 15: Decoding SA-ASR (cluster profile): training_dir=${sa_asr_exp}"
if ${gpu_inference}; then
_cmd="${cuda_cmd}"
@@ -1545,8 +1447,8 @@
done
fi
- if [ ${stage} -le 18 ] && [ ${stop_stage} -ge 18 ]; then
- log "Stage 18: Scoring SA-ASR (cluster profile)"
+ if [ ${stage} -le 16 ] && [ ${stop_stage} -ge 16 ]; then
+ log "Stage 16: Scoring SA-ASR (cluster profile)"
for dset in ${test_sets}; do
_data="${data_feats}/${dset}"
--
Gitblit v1.9.1