From 3cd3473bf7a3b41484baa86d9092248d78e7af39 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 21 四月 2023 17:17:37 +0800
Subject: [PATCH] docs

---
 egs/mars/sd/scripts/dump_rttm_to_labels.py |   26 +++++++++++++++++++++++---
 1 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/egs/mars/sd/scripts/dump_rttm_to_labels.py b/egs/mars/sd/scripts/dump_rttm_to_labels.py
index b4508ec..ec1c765 100644
--- a/egs/mars/sd/scripts/dump_rttm_to_labels.py
+++ b/egs/mars/sd/scripts/dump_rttm_to_labels.py
@@ -31,9 +31,29 @@
         for scp_path in wav_scp_list:
             meeting_scp.update(load_scp_as_dict(scp_path))
 
-        assert len(meeting_scp) == len(meeting2rttm), \
-            "Number of wav and rttm mismatch {} != {}".format(len(meeting_scp), len(meeting2rttm))
+        if len(meeting_scp) != len(meeting2rttm):
+            logging.warning("Number of wav and rttm mismatch {} != {}".format(
+                len(meeting_scp), len(meeting2rttm)))
+            common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
+            logging.warning("Keep {} records.".format(len(common_keys)))
+            new_meeting_scp = OrderedDict()
+            rm_keys = []
+            for key in meeting_scp:
+                if key not in common_keys:
+                    rm_keys.append(key)
+                else:
+                    new_meeting_scp[key] = meeting_scp[key]
+            logging.warning("Keys are removed from wav scp: {}".format(" ".join(rm_keys)))
 
+            new_meeting2rttm = OrderedDict()
+            rm_keys = []
+            for key in meeting2rttm:
+                if key not in common_keys:
+                    rm_keys.append(key)
+                else:
+                    new_meeting2rttm[key] = meeting2rttm[key]
+            logging.warning("Keys are removed from rttm scp: {}".format(" ".join(rm_keys)))
+            meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm
         if not os.path.exists(args.out_dir):
             os.makedirs(args.out_dir)
 
@@ -59,7 +79,7 @@
                 sr=None, frame_shift=0.01):
     frame_shift = int(frame_shift * sr)
     num_frame = int((float(length) + (float(frame_shift) / 2)) / frame_shift)
-    multi_label = np.zeros([n_spk, num_frame], dtype=int)
+    multi_label = np.zeros([n_spk, num_frame], dtype=np.float32)
     for _, st, dur, spk in spk_turns:
         idx = spk_list.index(spk)
 

--
Gitblit v1.9.1