"""Dump per-utterance speaker-profile matrices as Kaldi ark/scp files.

The script reads ``idx``, ``sys.rttm``, ``utt2spk`` and ``utt2xvec`` from
``--dir`` and, for every utterance id listed in ``idx``, writes one matrix
made of the embeddings of the speakers of the utterance's meeting, padded
with zero rows up to ``--n_spk``.
"""
import argparse
import os

import kaldiio
import numpy as np
from kaldiio import WriteHelper
from tqdm import tqdm

from funasr.utils.misc import load_scp_as_list, load_scp_as_dict


def calc_global_ivc(spk, spk2utt, utt2ivc):
    """Return the mean of all embeddings that belong to speaker ``spk``.

    Args:
        spk: Speaker id, a key of ``spk2utt``.
        spk2utt: Mapping from speaker id to a list of utterance ids.
        utt2ivc: Mapping from utterance id to a location that
            ``kaldiio.load_mat`` can read.

    Returns:
        The element-wise mean over the speaker's loaded embeddings
        (averaged along a new leading axis, which is then dropped).
    """
    ivc_list = [
        kaldiio.load_mat(utt2ivc[utt])[np.newaxis, :] for utt in spk2utt[spk]
    ]
    stacked = np.concatenate(ivc_list, axis=0)
    return np.mean(stacked, axis=0, keepdims=False)


def process(idx_scp, spk2utt, utt2xvec, meeting2spk_list, args):
    """Write one speaker-profile matrix per utterance id in ``idx_scp``.

    Three files are produced next to ``args.out``:
    ``<out>_parts00_xvec.ark``, ``<out>_parts00_xvec.scp`` and
    ``<out>_parts00_xvec.idx`` (the utterance ids, one per line).

    Args:
        idx_scp: Sequence of utterance ids; the text before the first ``-``
            is taken as the meeting id.
        spk2utt: Mapping from speaker id to a list of utterance ids.
        utt2xvec: Mapping from utterance id to an embedding location.
        meeting2spk_list: Mapping from meeting id to its speaker ids.
        args: Parsed arguments; ``out``, ``emb_type``, ``n_spk`` and
            ``no_pbar`` are read.

    Returns:
        A list with one ``(meeting_id, 1)`` tuple per processed utterance.

    Raises:
        KeyError: If a meeting id is missing from ``meeting2spk_list`` or a
            speaker has no embedding (embeddings are only computed when
            ``args.emb_type == "global"``).
    """
    # NOTE(review): the padding rows assume 256-dimensional embeddings;
    # stacking fails if the loaded embeddings have a different width.
    ivc_dim = 256
    print("ivc_dim = {}".format(ivc_dim))
    out_prefix = args.out + "_parts00_xvec"

    frames_list = []
    # Context managers guarantee both outputs are flushed and closed, even
    # when an exception interrupts the loop.
    with WriteHelper(
        f"ark,scp,f:{out_prefix}.ark,{out_prefix}.scp"
    ) as ivc_writer, open(out_prefix + ".idx", "wt") as idx_writer:
        spk2xvec = {}
        if args.emb_type == "global":
            print("Use global speaker embedding.")
            for spk in spk2utt:
                # Keep a leading axis so each speaker contributes one row.
                spk2xvec[spk] = calc_global_ivc(spk, spk2utt, utt2xvec)[
                    np.newaxis, :
                ]

        for utt_id in tqdm(
            idx_scp, total=len(idx_scp), ascii=True, disable=args.no_pbar
        ):
            mid = utt_id.split("-")[0]
            idx_writer.write(utt_id + "\n")

            xvec_list = []
            for spk in meeting2spk_list[mid]:
                try:
                    xvec_list.append(spk2xvec[spk])
                except KeyError:
                    raise KeyError(
                        "no embedding available for speaker {!r} "
                        "(emb_type={!r}; embeddings are only computed when "
                        "emb_type is 'global')".format(spk, args.emb_type)
                    ) from None

            # Pad with zero rows up to n_spk; nothing is added (and nothing
            # is truncated) when the meeting already has n_spk or more.
            for _ in range(args.n_spk - len(xvec_list)):
                xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))

            ivc_writer(utt_id, np.vstack(xvec_list))
            frames_list.append((mid, 1))
    return frames_list


def _parse_rttm_line(one_line):
    """Split one RTTM line into ``(meeting_id, start, duration, speaker)``.

    Fields are space separated; runs of spaces are collapsed. The start and
    duration fields are converted to ``float`` so that malformed lines are
    rejected with ``ValueError``; too-short lines raise ``IndexError``.
    """
    parts = [tok for tok in one_line.strip().split(" ") if tok != ""]
    return parts[1], float(parts[3]), float(parts[4]), parts[7]


def calc_spk_list(rttms):
    """Return the normalized speaker ids found in ``rttms``, without repeats.

    The raw speaker field is stripped of ``"spk"``, the meeting id and
    ``"-"``. A purely numeric remainder becomes ``<mid>_S<3-digit number>``;
    anything else becomes ``<mid>_<remainder>``. Order of first appearance is
    preserved. Blank lines are ignored.

    Args:
        rttms: Iterable of RTTM lines.

    Returns:
        A list of unique speaker ids.
    """
    spk_list = []
    seen = set()
    for one_line in rttms:
        if not one_line.strip():
            continue
        mid, _, _, spk_name = _parse_rttm_line(one_line)
        spk_name = (
            spk_name.replace("spk", "").replace(mid, "").replace("-", "")
        )
        if spk_name.isdigit():
            spk_name = "{}_S{:03d}".format(mid, int(spk_name))
        else:
            spk_name = "{}_{}".format(mid, spk_name)
        if spk_name not in seen:
            seen.add(spk_name)
            spk_list.append(spk_name)
    return spk_list


def main():
    """Parse the command line, load the inputs and dump speaker profiles."""
    parser = argparse.ArgumentParser()
    # NOTE(review): despite its help text, --dir is used as a directory that
    # holds the files "idx", "sys.rttm", "utt2spk" and "utt2xvec".
    parser.add_argument("--dir", required=True, type=str, default=None,
                        help="feats.scp")
    parser.add_argument("--out", required=True, type=str, default=None,
                        help="The prefix of dumpped files.")
    parser.add_argument("--n_spk", type=int, default=4)
    parser.add_argument("--no_pbar", default=False, action="store_true")
    parser.add_argument("--sr", type=int, default=16000)
    parser.add_argument("--emb_type", type=str, default="rand")
    args = parser.parse_args()

    # A bare prefix such as "profiles" has an empty dirname, and
    # os.makedirs("") raises, so only create a directory when there is one.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(os.path.join(args.dir, "idx"), "r") as idx_file:
        idx_scp = [line.strip() for line in idx_file if line.strip()]

    # Group the raw RTTM lines by meeting id.
    meeting2rttms = {}
    with open(os.path.join(args.dir, "sys.rttm"), "rt") as rttm_file:
        for one_line in rttm_file:
            if not one_line.strip():
                continue
            mid = _parse_rttm_line(one_line)[0]
            meeting2rttms.setdefault(mid, []).append(one_line)

    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))

    # Only utterances that have an embedding are attributed to a speaker.
    spk2utt = {}
    for utt, spk in utt2spk.items():
        if utt in utt2xvec:
            spk2utt.setdefault(spk, []).append(utt)

    # Drop speakers that appear in the RTTM but have no embedding at all.
    meeting2spk_list = {}
    for mid, rttms in meeting2rttms.items():
        rttm_spks = calc_spk_list(rttms)
        new_spk_list = [spk for spk in rttm_spks if spk in spk2utt]
        if len(new_spk_list) != len(rttm_spks):
            print("{}: Reduce speaker number from {}(according rttm) to {}(according x-vectors)".format(
                mid, len(rttm_spks), len(new_spk_list)))
        meeting2spk_list[mid] = new_spk_list

    meeting_lens = process(idx_scp, spk2utt, utt2xvec, meeting2spk_list, args)
    print("Total meetings: {:6d}".format(len(meeting_lens)))


if __name__ == '__main__':
    main()