hnluo
2023-04-17 24f73665e2d8ea8e4de2fe4f900bc539d7f7b989
egs/mars/sd/scripts/dump_rttm_to_labels.py
@@ -36,14 +36,24 @@
                len(meeting_scp), len(meeting2rttm)))
            common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
            logging.warning("Keep {} records.".format(len(common_keys)))
            new_meeting_scp = OrderedDict()
            rm_keys = []
            for key in meeting_scp:
                if key not in common_keys:
                    meeting_scp.pop(key)
                    logging.warning("Pop {} from wav scp".format(key))
                if key not in meeting2rttm:
                    meeting2rttm.pop(key)
                    logging.warning("Pop {} from rttm scp".format(key))
                    rm_keys.append(key)
                else:
                    new_meeting_scp[key] = meeting_scp[key]
            logging.warning("Keys are removed from wav scp: {}".format(" ".join(rm_keys)))
            new_meeting2rttm = OrderedDict()
            rm_keys = []
            for key in meeting2rttm:
                if key not in common_keys:
                    rm_keys.append(key)
                else:
                    new_meeting2rttm[key] = meeting2rttm[key]
            logging.warning("Keys are removed from rttm scp: {}".format(" ".join(rm_keys)))
            meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
@@ -69,7 +79,7 @@
                sr=None, frame_shift=0.01):
    frame_shift = int(frame_shift * sr)
    num_frame = int((float(length) + (float(frame_shift) / 2)) / frame_shift)
    multi_label = np.zeros([n_spk, num_frame], dtype=int)
    multi_label = np.zeros([n_spk, num_frame], dtype=np.float32)
    for _, st, dur, spk in spk_turns:
        idx = spk_list.index(spk)