From fe5c955063e561eb4cd14e050f15dc43e79f20ec Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期四, 23 二月 2023 22:28:45 +0800
Subject: [PATCH] sond pipeline

---
 egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py |   34 ++++++++++++++++++++++++++++------
 1 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
index ff72086..1d6f53e 100644
--- a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
+++ b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
@@ -31,7 +31,11 @@
         return task_list, None, args
 
     def post(self, result_list, args):
-        pass
+        # Each result is indexed as [extracted, found], matching the message below.
+        n_extracted, n_found = 0, 0
+        for pair in result_list:
+            n_extracted, n_found = n_extracted + pair[0], n_found + pair[1]
+        print("Found {} speakers, extracted {}.".format(n_found, n_extracted))
 
 
 # SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
@@ -59,18 +63,28 @@
         if not in_turn and label[i]:
             st, in_turn = i, True
             spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
-        if in_turn and not label[i]:
-            in_turn = False
-            turns.append([st, i, spk])
+        if in_turn:
+            if not label[i]:
+                in_turn = False
+                turns.append([st, i, spk])
+            elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
+                turns.append([st, i, spk])
+                st, in_turn = i, True
+                spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
+    if in_turn:
+        turns.append([st, len(label), spk])
     return turns
 
 
 def process(task_args):
     task_id, task_list, _, args = task_args
+    spk_count = [0, 0]
     for mid, wav_path, rttm_path in task_list:
-        wav = librosa.load(wav_path, args.sr)[0] * 32767
+        wav, sr = sf.read(wav_path, dtype="int16")
+        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
         multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num)
         turns = get_nonoverlap_turns(multi_label, spk_list)
+        extracted_spk = []
         count = 1
         for st, ed, spk in tqdm(turns, total=len(turns), ascii=True):
             if (ed - st) >= args.min_dur * args.sr:
@@ -80,7 +94,15 @@
                     os.makedirs(os.path.dirname(save_path))
                 sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
                 count += 1
-    return None
+                if spk not in extracted_spk:
+                    extracted_spk.append(spk)
+        if len(extracted_spk) != len(spk_list):
+            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
+                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
+            ))
+        spk_count[0] += len(extracted_spk)
+        spk_count[1] += len(spk_list)
+    return spk_count
 
 
 if __name__ == '__main__':

--
Gitblit v1.9.1