From af6740a2207840a772261b8a033ab9996f862529 Mon Sep 17 00:00:00 2001
From: smohan-speech <smohan@mail.ustc.edu.cn>
Date: Mon, 08 May 2023 16:13:23 +0800
Subject: [PATCH] add speaker-attributed ASR task for alimeeting
---
 egs/alimeeting/sa-asr/local/proce_text.py | 32 ----------
egs/alimeeting/sa-asr/asr_local.sh | 15 +++-
funasr/bin/asr_inference.py | 9 --
funasr/losses/label_smoothing_loss.py | 2
egs/alimeeting/sa-asr/README.md | 79 ++++++++++++++++++++++++++
egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh | 24 ++++----
funasr/bin/sa_asr_inference.py | 9 --
egs/alimeeting/sa-asr/run_m2met_2023_infer.sh | 2
egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh | 4
egs/alimeeting/sa-asr/run.sh | 1
10 files changed, 108 insertions(+), 69 deletions(-)
diff --git a/egs/alimeeting/sa-asr/README.md b/egs/alimeeting/sa-asr/README.md
new file mode 100644
index 0000000..882345c
--- /dev/null
+++ b/egs/alimeeting/sa-asr/README.md
@@ -0,0 +1,79 @@
+# Get Started
+Speaker-Attributed Automatic Speech Recognition (SA-ASR) addresses the problem of "who spoke what": the goal is not only to obtain multi-speaker transcriptions, but also to identify the corresponding speaker for each utterance. The method used in this example follows the paper [End-to-End Speaker-Attributed ASR with Transformer](https://www.isca-speech.org/archive/pdfs/interspeech_2021/kanda21b_interspeech.pdf).
+To run this recipe, you first need to install FunASR and ModelScope ([installation guide](https://alibaba-damo-academy.github.io/FunASR/en/installation.html)).
+There are two entry scripts: `run.sh` for training and evaluation on the original Eval and Test sets, and `run_m2met_2023_infer.sh` for inference on the new test set of the Multi-Channel Multi-Party Meeting Transcription 2.0 ([M2MET2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)) Challenge.
+Before running `run.sh`, you must manually download and unpack the [AliMeeting](http://www.openslr.org/119/) corpus and place it under the `./dataset` directory:
+```shell
+dataset
+|—— Eval_Ali_far
+|—— Eval_Ali_near
+|—— Test_Ali_far
+|—— Test_Ali_near
+|—— Train_Ali_far
+|—— Train_Ali_near
+```
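+If you want to double-check the layout before launching the recipe, the following is a minimal sanity check (a sketch only; the directory names are taken from the listing above):
+```python
+import os
+
+expected = [
+    "Eval_Ali_far", "Eval_Ali_near",
+    "Test_Ali_far", "Test_Ali_near",
+    "Train_Ali_far", "Train_Ali_near",
+]
+# Report any sub-directory that is missing under ./dataset.
+missing = [d for d in expected if not os.path.isdir(os.path.join("dataset", d))]
+if missing:
+    raise SystemExit(f"missing sub-directories under ./dataset: {missing}")
+print("dataset layout looks complete")
+```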
+There are 18 stages in `run.sh`:
+```shell
+stage 1 - 5: Data preparation and processing.
+stage 6: Generate speaker profiles (this stage is time-consuming).
+stage 7 - 9: Language model training (Optional).
+stage 10 - 11: ASR training (SA-ASR requires loading the pre-trained ASR model).
+stage 12: SA-ASR training.
+stage 13 - 18: Inference and evaluation.
+```
+Before running `run_m2met_2023_infer.sh`, you need to place the new test set `Test_2023_Ali_far` (to be released after the challenge starts), which contains only the raw audio, in the `./dataset` directory. Then put the provided `wav.scp`, `wav_raw.scp`, `segments`, `utt2spk` and `spk2utt` files in the `./data/Test_2023_Ali_far` directory:
+```shell
+data/Test_2023_Ali_far
+|—— wav.scp
+|—— wav_raw.scp
+|—— segments
+|—— utt2spk
+|—— spk2utt
+```
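+These are standard Kaldi-style index files in which the first column is a key and the rest of the line is the value. As a rough illustration of what the recipe expects, here is a small parsing sketch (the paths are hypothetical; this is not part of the recipe itself):
+```python
+def read_kv(path):
+    """Read a Kaldi-style file: first column is the key, the rest is the value."""
+    table = {}
+    with open(path, encoding="utf-8") as f:
+        for line in f:
+            if not line.strip():
+                continue
+            key, value = line.strip().split(maxsplit=1)
+            table[key] = value
+    return table
+
+# wav.scp maps recording-id -> wav path (or pipe command),
+# segments maps utterance-id -> "recording-id start end",
+# utt2spk maps utterance-id -> speaker-id.
+wav_scp = read_kv("data/Test_2023_Ali_far/wav.scp")
+segments = read_kv("data/Test_2023_Ali_far/segments")
+utt2spk = read_kv("data/Test_2023_Ali_far/utt2spk")
+print(len(wav_scp), "recordings,", len(segments), "segments")
+```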
+There are 4 stages in `run_m2met_2023_infer.sh`:
+```shell
+stage 1: Data preparation and processing.
+stage 2: Generate speaker profiles for inference.
+stage 3: Inference.
+stage 4: Generation of SA-ASR results required for final submission.
+```
+# Format of Final Submission
+Finally, you need to submit a file called `text_spk_merge` with the following format:
+```shell
+Meeting_1 text_spk_1_A$text_spk_1_B$text_spk_1_C ...
+Meeting_2 text_spk_2_A$text_spk_2_B$text_spk_2_C ...
+...
+```
+Here, `text_spk_1_A` represents the complete transcription of speaker A in Meeting_1 (all of that speaker's utterances merged in chronological order), and `$` is the separator symbol. You do not need to worry about the speaker permutation, because the optimal permutation is computed during scoring. For more information, please refer to the results generated after running the baseline code.
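+To illustrate why the speaker order in your submission does not matter, here is a rough sketch of a concatenated minimum-permutation CER (illustrative only; it assumes the third-party `editdistance` package and the same number of speakers on both sides, and is not the official scoring script):
+```python
+from itertools import permutations
+import editdistance  # pip install editdistance (assumed helper for Levenshtein distance)
+
+def cp_cer(ref_texts, hyp_texts):
+    """cpCER: try every speaker permutation and keep the lowest character error rate."""
+    n_ref_chars = sum(len(t) for t in ref_texts)
+    best_errors = min(
+        sum(editdistance.eval(ref, hyp_texts[p]) for ref, p in zip(ref_texts, perm))
+        for perm in permutations(range(len(hyp_texts)))
+    )
+    return best_errors / n_ref_chars
+
+refs = ["今天天气不错", "我们开始开会"]    # per-speaker reference transcriptions
+hyps = ["我们开始开会吧", "今天天气不错"]  # per-speaker hypotheses, speaker order unknown
+print(f"cpCER = {cp_cer(refs, hyps):.2%}")  # the best permutation is found automatically
+```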
+# Baseline Results
+The results of the baseline system are shown below. They include the speaker-independent character error rate (SI-CER) and the concatenated minimum-permutation character error rate (cpCER); the former ignores speaker labels while the latter is speaker-dependent. During training, the speaker profile uses the oracle speaker embedding. Since oracle speaker labels are not available during evaluation, the speaker profile produced by an additional spectral clustering step is used instead. Results with the oracle speaker profile on the Eval and Test sets are also reported to show the impact of speaker-profile accuracy.
+<table>
+ <tr >
+ <td rowspan="2"></td>
+ <td colspan="2">SI-CER(%)</td>
+ <td colspan="2">cpCER(%)</td>
+ </tr>
+ <tr>
+ <td>Eval</td>
+ <td>Test</td>
+ <td>Eval</td>
+ <td>Test</td>
+ </tr>
+ <tr>
+ <td>oracle profile</td>
+ <td>31.93</td>
+ <td>32.75</td>
+ <td>48.56</td>
+ <td>53.33</td>
+ </tr>
+ <tr>
+ <td>cluster profile</td>
+ <td>31.94</td>
+ <td>32.77</td>
+ <td>55.49</td>
+ <td>58.17</td>
+ </tr>
+</table>
+
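+The cluster profile above is produced by the recipe's own spectral-clustering script (`local/gen_cluster_profile_infer.py`). Purely to illustrate the idea behind it, a minimal sketch using scikit-learn (an assumption for illustration, not the code used in this recipe) could look like:
+```python
+import numpy as np
+from sklearn.cluster import SpectralClustering
+
+# Toy speaker embeddings for the segments of one meeting (hypothetical shapes).
+rng = np.random.default_rng(0)
+embeddings = rng.normal(size=(20, 192))          # 20 segments, 192-dim embeddings
+embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
+
+# Map cosine similarity from [-1, 1] to [0, 1] to obtain a valid affinity matrix,
+# then cluster the segments into an assumed 4 speakers.
+affinity = (embeddings @ embeddings.T + 1.0) / 2.0
+labels = SpectralClustering(n_clusters=4, affinity="precomputed",
+                            random_state=0).fit_predict(affinity)
+
+# The speaker "profile" is then the mean embedding of each cluster.
+profile = np.stack([embeddings[labels == k].mean(axis=0) for k in range(4)])
+print(profile.shape)   # (4, 192)
+```
+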
+# Reference
+N. Kanda, G. Ye, Y. Gaur, X. Wang, Z. Meng, Z. Chen, and T. Yoshioka, "End-to-end speaker-attributed ASR with transformer," in Proc. Interspeech, ISCA, 2021, pp. 4413–4417.
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh
index 419e341..f8cdcd3 100755
--- a/egs/alimeeting/sa-asr/asr_local.sh
+++ b/egs/alimeeting/sa-asr/asr_local.sh
@@ -475,7 +475,9 @@
fi
local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
- cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+ if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
+ cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
+ fi
rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
_opts=
@@ -568,8 +570,11 @@
# generate uttid
cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
- # filter utt2spk_all_fifo
- python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+
+ if [ "${dset}" = "Train_Ali_far" ] || [ "${dset}" = "Eval_Ali_far" ] || [ "${dset}" = "Test_Ali_far" ]; then
+ # filter utt2spk_all_fifo
+ python local/filter_utt2spk_all_fifo.py ${data_feats}/${dset}/uttid ${data_feats}/org/${dset} ${data_feats}/${dset}
+ fi
done
# shellcheck disable=SC2002
@@ -585,7 +590,7 @@
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- local/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${data_feats}/lm_train.txt | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
@@ -603,6 +608,7 @@
python local/process_text_id.py ${data_feats}/${dset}
log "Successfully generate ${data_feats}/${dset}/text_id_train"
# generate oracle_embedding from single-speaker audio segment
+ log "oracle_embedding is being generated in the background, and the log is profile_log/gen_oracle_embedding_${dset}.log"
python local/gen_oracle_embedding.py "${data_feats}/${dset}" "data/local/${dset}_correct_single_speaker" &> "profile_log/gen_oracle_embedding_${dset}.log"
log "Successfully generate oracle embedding for ${dset} (${data_feats}/${dset}/oracle_embedding.scp)"
# generate oracle_profile and cluster_profile from oracle_embedding and cluster_embedding (padding the speaker during training)
@@ -615,6 +621,7 @@
fi
# generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
if [ "${dset}" = "${valid_set}" ] || [ "${dset}" = "${test_sets}" ]; then
+ log "cluster_profile is being generated in the background, and the log is profile_log/gen_cluster_profile_infer_${dset}.log"
python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
fi
diff --git a/egs/alimeeting/sa-asr/asr_local_infer.sh b/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh
similarity index 98%
rename from egs/alimeeting/sa-asr/asr_local_infer.sh
rename to egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh
index b7a9289..a23215c 100755
--- a/egs/alimeeting/sa-asr/asr_local_infer.sh
+++ b/egs/alimeeting/sa-asr/asr_local_m2met_2023_infer.sh
@@ -449,7 +449,7 @@
_opts+="--segments data/${dset}/segments "
fi
# shellcheck disable=SC2086
- scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
--audio-format "${audio_format}" --fs "${fs}" ${_opts} \
"data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
@@ -467,7 +467,7 @@
mkdir -p "profile_log"
for dset in "${test_sets}"; do
# generate cluster_profile with spectral-cluster directly (for infering and without oracle information)
- python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/local/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
+ python local/gen_cluster_profile_infer.py "${data_feats}/${dset}" "data/${dset}" 0.996 0.815 &> "profile_log/gen_cluster_profile_infer_${dset}.log"
log "Successfully generate cluster profile for ${dset} (${data_feats}/${dset}/cluster_profile_infer.scp)"
done
fi
diff --git a/egs/alimeeting/sa-asr/local/compute_wer.py b/egs/alimeeting/sa-asr/local/compute_wer.py
deleted file mode 100755
index 349a3f6..0000000
--- a/egs/alimeeting/sa-asr/local/compute_wer.py
+++ /dev/null
@@ -1,157 +0,0 @@
-import os
-import numpy as np
-import sys
-
-def compute_wer(ref_file,
- hyp_file,
- cer_detail_file):
- rst = {
- 'Wrd': 0,
- 'Corr': 0,
- 'Ins': 0,
- 'Del': 0,
- 'Sub': 0,
- 'Snt': 0,
- 'Err': 0.0,
- 'S.Err': 0.0,
- 'wrong_words': 0,
- 'wrong_sentences': 0
- }
-
- hyp_dict = {}
- ref_dict = {}
- with open(hyp_file, 'r') as hyp_reader:
- for line in hyp_reader:
- key = line.strip().split()[0]
- value = line.strip().split()[1:]
- hyp_dict[key] = value
- with open(ref_file, 'r') as ref_reader:
- for line in ref_reader:
- key = line.strip().split()[0]
- value = line.strip().split()[1:]
- ref_dict[key] = value
-
- cer_detail_writer = open(cer_detail_file, 'w')
- for hyp_key in hyp_dict:
- if hyp_key in ref_dict:
- out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
- rst['Wrd'] += out_item['nwords']
- rst['Corr'] += out_item['cor']
- rst['wrong_words'] += out_item['wrong']
- rst['Ins'] += out_item['ins']
- rst['Del'] += out_item['del']
- rst['Sub'] += out_item['sub']
- rst['Snt'] += 1
- if out_item['wrong'] > 0:
- rst['wrong_sentences'] += 1
- cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
- cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
- cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
-
- if rst['Wrd'] > 0:
- rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
- if rst['Snt'] > 0:
- rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
-
- cer_detail_writer.write('\n')
- cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
- ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
- cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
- cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
-
-
-def compute_wer_by_line(hyp,
- ref):
- hyp = list(map(lambda x: x.lower(), hyp))
- ref = list(map(lambda x: x.lower(), ref))
-
- len_hyp = len(hyp)
- len_ref = len(ref)
-
- cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
-
- ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
-
- for i in range(len_hyp + 1):
- cost_matrix[i][0] = i
- for j in range(len_ref + 1):
- cost_matrix[0][j] = j
-
- for i in range(1, len_hyp + 1):
- for j in range(1, len_ref + 1):
- if hyp[i - 1] == ref[j - 1]:
- cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
- else:
- substitution = cost_matrix[i - 1][j - 1] + 1
- insertion = cost_matrix[i - 1][j] + 1
- deletion = cost_matrix[i][j - 1] + 1
-
- compare_val = [substitution, insertion, deletion]
-
- min_val = min(compare_val)
- operation_idx = compare_val.index(min_val) + 1
- cost_matrix[i][j] = min_val
- ops_matrix[i][j] = operation_idx
-
- match_idx = []
- i = len_hyp
- j = len_ref
- rst = {
- 'nwords': len_ref,
- 'cor': 0,
- 'wrong': 0,
- 'ins': 0,
- 'del': 0,
- 'sub': 0
- }
- while i >= 0 or j >= 0:
- i_idx = max(0, i)
- j_idx = max(0, j)
-
- if ops_matrix[i_idx][j_idx] == 0: # correct
- if i - 1 >= 0 and j - 1 >= 0:
- match_idx.append((j - 1, i - 1))
- rst['cor'] += 1
-
- i -= 1
- j -= 1
-
- elif ops_matrix[i_idx][j_idx] == 2: # insert
- i -= 1
- rst['ins'] += 1
-
- elif ops_matrix[i_idx][j_idx] == 3: # delete
- j -= 1
- rst['del'] += 1
-
- elif ops_matrix[i_idx][j_idx] == 1: # substitute
- i -= 1
- j -= 1
- rst['sub'] += 1
-
- if i < 0 and j >= 0:
- rst['del'] += 1
- elif j < 0 and i >= 0:
- rst['ins'] += 1
-
- match_idx.reverse()
- wrong_cnt = cost_matrix[len_hyp][len_ref]
- rst['wrong'] = wrong_cnt
-
- return rst
-
-def print_cer_detail(rst):
- return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
- + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
- + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
- + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
-
-if __name__ == '__main__':
- if len(sys.argv) != 4:
- print("usage : python compute-wer.py test.ref test.hyp test.wer")
- sys.exit(0)
-
- ref_file = sys.argv[1]
- hyp_file = sys.argv[2]
- cer_detail_file = sys.argv[3]
- compute_wer(ref_file, hyp_file, cer_detail_file)
diff --git a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
index 9e08dba..1022ae6 100755
--- a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
+++ b/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
@@ -63,20 +63,20 @@
fi
-<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
- utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
+<"${srcdir}"/utt2spk local/apply_map.pl -f 1 "${destdir}"/utt_map | \
+ local/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
-utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
+local/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
if [[ -f ${srcdir}/segments ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
- utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
+ local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
+ local/apply_map.pl -f 2 "${destdir}"/reco_map | \
awk -v factor="${factor}" \
'{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
>"${destdir}"/segments
- utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+ local/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
# Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
awk -v factor="${factor}" \
'{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
@@ -84,13 +84,13 @@
else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
> "${destdir}"/wav.scp
if [[ -f ${srcdir}/reco2file_and_channel ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/reco_map \
+ local/apply_map.pl -f 1 "${destdir}"/reco_map \
<"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
fi
else # no segments->wav indexed by utterance.
if [[ -f ${srcdir}/wav.scp ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+ local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
# Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
awk -v factor="${factor}" \
'{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
@@ -101,16 +101,16 @@
fi
if [[ -f ${srcdir}/text ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
+ local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
fi
if [[ -f ${srcdir}/spk2gender ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
+ local/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
fi
if [[ -f ${srcdir}/utt2lang ]]; then
- utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
+ local/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
fi
rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
-utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
+local/validate_data_dir.sh --no-feats --no-text "${destdir}"
diff --git a/egs/alimeeting/sa-asr/local/proce_text.py b/egs/alimeeting/sa-asr/local/proce_text.py
deleted file mode 100755
index e56cc0f..0000000
--- a/egs/alimeeting/sa-asr/local/proce_text.py
+++ /dev/null
@@ -1,32 +0,0 @@
-
-import sys
-import re
-
-in_f = sys.argv[1]
-out_f = sys.argv[2]
-
-
-with open(in_f, "r", encoding="utf-8") as f:
- lines = f.readlines()
-
-with open(out_f, "w", encoding="utf-8") as f:
- for line in lines:
- outs = line.strip().split(" ", 1)
- if len(outs) == 2:
- idx, text = outs
- text = re.sub("</s>", "", text)
- text = re.sub("<s>", "", text)
- text = re.sub("@@", "", text)
- text = re.sub("@", "", text)
- text = re.sub("<unk>", "", text)
- text = re.sub(" ", "", text)
- text = re.sub("\$", "", text)
- text = text.lower()
- else:
- idx = outs[0]
- text = " "
-
- text = [x for x in text]
- text = " ".join(text)
- out = "{} {}\n".format(idx, text)
- f.write(out)
diff --git a/egs/alimeeting/sa-asr/run_m2met_2023.sh b/egs/alimeeting/sa-asr/run.sh
similarity index 97%
rename from egs/alimeeting/sa-asr/run_m2met_2023.sh
rename to egs/alimeeting/sa-asr/run.sh
index 807e499..e5297b8 100755
--- a/egs/alimeeting/sa-asr/run_m2met_2023.sh
+++ b/egs/alimeeting/sa-asr/run.sh
@@ -8,7 +8,6 @@
ngpu=4
device="0,1,2,3"
-#stage 1 creat both near and far
stage=1
stop_stage=18
diff --git a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
index d35e6a6..1967864 100755
--- a/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
+++ b/egs/alimeeting/sa-asr/run_m2met_2023_infer.sh
@@ -22,7 +22,7 @@
lm_config=conf/train_lm_transformer.yaml
use_lm=false
use_wordlm=false
-./asr_local_infer.sh \
+./asr_local_m2met_2023_infer.sh \
--device ${device} \
--ngpu ${ngpu} \
--stage ${stage} \
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 18f0add..a52e94a 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -94,7 +94,7 @@
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
if asr_train_args.frontend=='wav_frontend':
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
else:
frontend_class=frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -146,13 +146,6 @@
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
- logging.info(f"Beam_search: {beam_search}")
- logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py
index ec575df..c894f54 100644
--- a/funasr/bin/sa_asr_inference.py
+++ b/funasr/bin/sa_asr_inference.py
@@ -89,7 +89,7 @@
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
if asr_train_args.frontend=='wav_frontend':
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
else:
frontend_class=frontend_choices.get_class(asr_train_args.frontend)
frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -141,13 +141,6 @@
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
- logging.info(f"Beam_search: {beam_search}")
- logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py
index 8f63df9..3ea34c0 100644
--- a/funasr/losses/label_smoothing_loss.py
+++ b/funasr/losses/label_smoothing_loss.py
@@ -97,7 +97,7 @@
normalize_length=False,
criterion=nn.NLLLoss(reduction='none'),
):
- """Construct an LabelSmoothingLoss object."""
+ """Construct an NllLoss object."""
super(NllLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
--
Gitblit v1.9.1