From f5bd371837cc3b89e6d387ecc84469a0e513fbd6 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 19 七月 2023 22:34:52 +0800
Subject: [PATCH] update

---
 egs/callhome/eend_ola/local/infer.py |    4 ++--
 egs/callhome/eend_ola/run.sh         |   24 ++++++++++++++----------
 egs/callhome/eend_ola/run_test.sh    |    5 ++++-
 funasr/models/e2e_diar_eend_ola.py   |    3 +--
 4 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/egs/callhome/eend_ola/local/infer.py b/egs/callhome/eend_ola/local/infer.py
index 1322468..23e1d52 100644
--- a/egs/callhome/eend_ola/local/infer.py
+++ b/egs/callhome/eend_ola/local/infer.py
@@ -54,7 +54,7 @@
     parser.add_argument(
         "--sampling_rate",
         type=int,
-        default=10,
+        default=8000,
         help="sampling rate",
     )
     parser.add_argument(
@@ -104,7 +104,7 @@
     print("Start inference")
     with open(args.output_rttm_file, "w") as wf:
         for wav_id in wav_items.keys():
-            print("Process wav: {}\n".format(wav_id))
+            print("Process wav: {}".format(wav_id))
             data, rate = sf.read(wav_items[wav_id])
             speech = eend_ola_feature.stft(data, args.frame_size, args.frame_shift)
             speech = eend_ola_feature.transform(speech)
diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh
index b4f2739..40fb041 100644
--- a/egs/callhome/eend_ola/run.sh
+++ b/egs/callhome/eend_ola/run.sh
@@ -245,13 +245,17 @@
     python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
 fi
 
-## inference
-#if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-#    echo "Inference"
-#    mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
-#    CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \
-#        --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \
-#        --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \
-#        --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
-#        --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} 1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
-#fi
\ No newline at end of file
+# inference and compute DER
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+    echo "Inference"
+    mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
+    CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python local/infer.py \
+        --config_file ${exp_dir}/exp/${callhome_model_dir}/config.yaml \
+        --model_file ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb \
+        --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
+        --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \
+        1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
+    md-eval.pl -c 0.25 \
+          -r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \
+          -s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit
+fi
\ No newline at end of file
diff --git a/egs/callhome/eend_ola/run_test.sh b/egs/callhome/eend_ola/run_test.sh
index c198e73..9173e6f 100644
--- a/egs/callhome/eend_ola/run_test.sh
+++ b/egs/callhome/eend_ola/run_test.sh
@@ -245,7 +245,7 @@
     python local/model_averaging.py ${exp_dir}/exp/${callhome_model_dir}/$callhome_ave_id.pb $models
 fi
 
-# inference
+# inference and compute DER
 if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     echo "Inference"
     mkdir -p ${exp_dir}/exp/${callhome_model_dir}/inference/log
@@ -255,4 +255,7 @@
         --output_rttm_file ${exp_dir}/exp/${callhome_model_dir}/inference/rttm \
         --wav_scp_file ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/${callhome2_wav_scp_file} \
         1> ${exp_dir}/exp/${callhome_model_dir}/inference/log/infer.log 2>&1
+    md-eval.pl -c 0.25 \
+          -r ${callhome_feats_dir_chunk2000}/${callhome_valid_dataset}/rttm \
+          -s ${exp_dir}/exp/${callhome_model_dir}/inference/rttm > ${exp_dir}/exp/${callhome_model_dir}/inference/result_med11_collar0.25 2>/dev/null || exit
 fi
\ No newline at end of file
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index fda24e2..0225a7a 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -157,12 +157,11 @@
 
     def estimate_sequential(self,
                             speech: torch.Tensor,
-                            speech_lengths: torch.Tensor,
                             n_speakers: int = None,
                             shuffle: bool = True,
                             threshold: float = 0.5,
                             **kwargs):
-        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
+        speech_lengths = torch.tensor([len(sph) for sph in speech]).to(torch.int64)
         emb = self.forward_encoder(speech, speech_lengths)
         if shuffle:
             orders = [np.arange(e.shape[0]) for e in emb]

--
Gitblit v1.9.1