From f5bd371837cc3b89e6d387ecc84469a0e513fbd6 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 19 七月 2023 22:34:52 +0800
Subject: [PATCH] update
---
egs/callhome/eend_ola/local/infer.py | 4 ++--
egs/callhome/eend_ola/run.sh | 24 ++++++++++++++----------
egs/callhome/eend_ola/run_test.sh | 5 ++++-
funasr/models/e2e_diar_eend_ola.py | 3 +--
4 files changed, 21 insertions(+), 15 deletions(-)
diff --git a/egs/callhome/eend_ola/local/infer.py b/egs/callhome/eend_ola/local/infer.py
index 1322468..23e1d52 100644
--- a/egs/callhome/eend_ola/local/infer.py
+++ b/egs/callhome/eend_ola/local/infer.py
@@ -54,7 +54,7 @@
parser.add_argument(
"--sampling_rate",
type=int,
- default=10,
+ default=8000,
help="sampling rate",
)
parser.add_argument(
@@ -104,7 +104,7 @@
print("Start inference")
with open(args.output_rttm_file, "w") as wf:
for wav_id in wav_items.keys():
- print("Process wav: {}\n".format(wav_id))
+ print("Process wav: {}".format(wav_id))
data, rate = sf.read(wav_items[wav_id])
speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift)
speech = eend_ola_feature.transform(speech)
diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh
index b4f2739..40fb041 100644
--- a/egs/callhome/eend_ola/run.sh
+++ b/egs/callhome/eend_ola/run.sh
@@ -245,13 +245,17 @@
python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
fi
-## inference
-#if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-# echo "Inference"
-# mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
-# CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \
-# --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \
-# --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \
-# --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
-# --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
-#fi
\ No newline at end of file
+# inference and compute DER
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "Inference"
+ mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
+ CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \
+ --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \
+ --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \
+ --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
+ --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \
+ 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
+ md-eval.pl -c 0.25 \
+ -r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \
+ -s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit
+fi
\ No newline at end of file
diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh
index c198e73..9173e6f 100644
--- a/egs/callhome/eend_ola/run_test.sh
+++ b/egs/callhome/eend_ola/run_test.sh
@@ -245,7 +245,7 @@
python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
fi
-# inference
+# inference and compute DER
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
echo "Inference"
mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
@@ -255,4 +255,7 @@
--output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
--wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \
1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
+ md-eval.pl -c 0.25 \
+ -r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \
+ -s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit
fi
\ No newline at end of file
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index fda24e2..0225a7a 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -157,12 +157,11 @@
def estimate_sequential(self,
speech: torch.Tensor,
- speech_lengths: torch.Tensor,
n_speakers: int = None,
shuffle: bool = True,
threshold: float = 0.5,
**kwargs):
- speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
+ speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
emb = self.forward_encoder(speech, speech_lengths)
if shuffle:
orders = [np.arange(e.shape[0]) for e in emb]
--
Gitblit v1.9.1