"""Split meeting-level features and speaker labels into fixed-size chunks.

For every meeting listed in ``feats.scp`` that also has an entry in
``labels.scp``, the script cuts the feature matrix and the frame-level speaker
label matrix into overlapping chunks and writes, per chunk, the features, one
speaker embedding per speaker slot and the (slot-permuted) labels as Kaldi
ark/scp files.
"""
import argparse
import logging
import os
import random

import kaldiio
import numpy as np
import scipy.io as sio
from tqdm import tqdm

from funasr.utils.misc import load_scp_as_list, load_scp_as_dict

logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)

# Speakers already reported as having too few frames; used so that each one is
# warned about only once per run.
short_spk_list = []

# Dimension of the all-zero embedding used for empty / unknown speaker slots.
_IVC_DIM = 256


def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
    """Average the embeddings of randomly chosen utterances of one speaker.

    Utterances of ``spk`` are visited in random order and collected until their
    accumulated frame count reaches ``total_len`` (or the speaker runs out of
    utterances). The embeddings of the collected utterances are then averaged.

    Args:
        spk: Speaker id; must be a key of ``spk2utt``.
        spk2utt: Mapping from speaker id to a list of utterance ids.
        utt2ivc: Mapping from utterance id to a path loadable by
            ``kaldiio.load_mat``.
        utt2frames: Mapping from utterance id to its frame count (anything
            accepted by ``int``).
        total_len: Number of frames to accumulate before stopping.

    Returns:
        The mean of the loaded embeddings over the selected utterances.
    """
    # Shuffling a copy draws the same random numbers as shuffling an index list
    # of equal length, so the selection order is unchanged.
    shuffled_utts = list(spk2utt[spk])
    random.shuffle(shuffled_utts)

    count = 0
    utt_list = []
    for utt_id in shuffled_utts:
        utt_list.append(utt_id)
        count += int(utt2frames[utt_id])
        if count >= total_len:
            break

    if count < 300 and spk not in short_spk_list:
        logging.warning("{} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300))
        short_spk_list.append(spk)

    ivc_list = [kaldiio.load_mat(utt2ivc[utt])[np.newaxis, :] for utt in utt_list]
    ivc = np.concatenate(ivc_list, axis=0)
    ivc = np.mean(ivc, axis=0, keepdims=False)
    return ivc


def _build_speaker_list(meeting_spk_list, train_spk_list, args, n_fillers):
    """Return the speaker id assigned to each of the speaker slots of a chunk.

    In test mode the meeting speakers are followed by ``"None"`` placeholders
    up to ``args.n_spk``. In train mode a random number of extra speakers drawn
    from ``train_spk_list`` is inserted between the meeting speakers and the
    placeholders.

    Args:
        meeting_spk_list: Speaker ids of the current meeting.
        train_spk_list: Speaker ids that may be used as fillers.
        args: Parsed command line arguments (reads ``mode`` and ``n_spk``).
        n_fillers: Number of entries of ``train_spk_list`` that are not meeting
            speakers, i.e. how many fillers can be drawn at most.

    Raises:
        ValueError: In train mode, if the meeting has more speakers than
            ``args.n_spk``.
    """
    n_meeting = len(meeting_spk_list)
    if args.mode != "train":
        return list(meeting_spk_list) + ["None"] * max(args.n_spk - n_meeting, 0)

    # Exactly n_spk meeting speakers is fine; only strictly fewer slots than
    # speakers is an error (the original rejected the equal case as well).
    if args.n_spk < n_meeting:
        raise ValueError("Argument n_spk is too small ({} < {}).".format(args.n_spk, n_meeting))

    # Cap the target by what can actually be drawn, otherwise the rejection
    # sampling loop below would never terminate for a small speaker pool.
    n = random.randint(n_meeting, min(args.n_spk, n_meeting + n_fillers))
    spk_list = list(meeting_spk_list)
    while len(spk_list) < n:
        spk = random.choice(train_spk_list)
        if spk not in spk_list:
            spk_list.append(spk)
    spk_list.extend(["None"] * (args.n_spk - n))
    return spk_list


def _speaker_embedding(spk, mid, spk2utt, utt2xvec, utt2frames, non_present_spks):
    """Return the embedding for one speaker slot.

    ``"None"`` slots and speakers missing from ``spk2utt`` get an all-zero
    vector; a missing speaker is reported once and remembered in
    ``non_present_spks``. Other speakers get a random averaged embedding from
    ``calc_rand_ivc``.
    """
    if spk == "None":
        return np.zeros((_IVC_DIM,), dtype=np.float32)
    if spk not in spk2utt:
        # Speaker filtered out of spk2utt (e.g. very short duration).
        if spk not in non_present_spks:
            logging.warning("speaker {}/{} does not appear in spk2utt, since it has very short duration.".format(mid, spk))
            non_present_spks.add(spk)
        return np.zeros((_IVC_DIM,), dtype=np.float32)
    return calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 3000)[np.newaxis, :]


def _dump_meetings(feat_scp, labels_scp, spk2utt, utt2xvec, utt2frames, args, logger,
                   feat_writer, ivc_writer, label_writer):
    """Chunk every meeting of ``feat_scp`` and write it through the writers.

    Returns:
        A list of ``(meeting_id, num_frames)`` tuples, one per dumped meeting.
    """
    chunk_size, chunk_shift = args.chunk_size, args.chunk_shift
    train_spk_list = list(spk2utt.keys())
    frames_list = []
    non_present_spks = set()

    for mid, feat_path in tqdm(feat_scp, total=len(feat_scp), ascii=True, disable=args.no_pbar):
        if mid not in labels_scp:
            continue
        feat = kaldiio.load_mat(feat_path)
        data = sio.loadmat(labels_scp[mid])
        labels = data["labels"].astype(int)
        meeting_spk_list = [x.strip() for x in data["spk_list"]]
        if args.add_mid_to_speaker:
            meeting_spk_list = ["{}_{}".format(mid, x) if not x.startswith(mid) else x for x in meeting_spk_list]

        # Record both lengths before clipping so the warning reports the real
        # mismatch instead of the already-clipped length twice.
        label_len, feat_len = labels.shape[0], feat.shape[0]
        if label_len != feat_len:
            min_len = min(label_len, feat_len)
            labels, feat = labels[:min_len], feat[:min_len]
            logger.warning("{}: The expected frame_len is {}, but got {}, clip both to {}".format(
                mid, label_len, feat_len, min_len))

        n_fillers = 0
        if args.mode == "train":
            meeting_spk_set = set(meeting_spk_list)
            n_fillers = sum(1 for spk in train_spk_list if spk not in meeting_spk_set)

        num_frame = feat.shape[0]
        # NOTE: meetings with num_frame <= chunk_size - chunk_shift give
        # num_chunk <= 0 and therefore produce no chunk at all.
        num_chunk = int(np.ceil(float(num_frame - chunk_size) / chunk_shift)) + 1
        for chunk_idx in range(num_chunk):
            st = chunk_idx * chunk_shift
            ed = st + chunk_size
            utt_id = "{}-{:05d}-{:05d}".format(mid, st, ed)

            chunk_feat = feat[st: ed, :]
            chunk_label = labels[st: ed, :]
            # Zero-pad the last chunk in time and the labels up to n_spk columns.
            frame_pad = chunk_size - chunk_label.shape[0]
            spk_pad = args.n_spk - chunk_label.shape[1]
            chunk_feat = np.pad(chunk_feat, [(0, frame_pad), (0, 0)], "constant", constant_values=0)
            chunk_label = np.pad(chunk_label, [(0, frame_pad), (0, spk_pad)], "constant", constant_values=0)
            feat_writer(utt_id, chunk_feat)

            # Slot permutation applied to both embeddings and label columns;
            # it is only randomised in train mode.
            spk_idx = list(range(max(args.n_spk, len(meeting_spk_list))))
            if args.mode == "train":
                random.shuffle(spk_idx)
            spk_list = _build_speaker_list(meeting_spk_list, train_spk_list, args, n_fillers)

            xvec_list = [
                _speaker_embedding(spk, mid, spk2utt, utt2xvec, utt2frames, non_present_spks)
                for spk in spk_list
            ]
            # np.vstack replaces the deprecated alias np.row_stack.
            xvec = np.vstack(xvec_list)
            ivc_writer(utt_id, xvec[spk_idx, :])

            feat_label = chunk_label[:, spk_idx]
            label_writer(utt_id, feat_label.astype(np.float32))

        frames_list.append((mid, feat.shape[0]))
        logger.info("{:30s}: {:6d} frames split into {:3d} chunks.".format(mid, num_frame, num_chunk))
    return frames_list


def process(feat_scp, labels_scp, spk2utt, utt2xvec, utt2frames, args):
    """Dump chunked features, speaker embeddings and labels for one task.

    Output files are named ``{args.out}_parts{task_id:02d}_{feat,xvec,label}``
    with ``.ark``/``.scp`` extensions; a log is written to
    ``{args.out}_parts{task_id:02d}.log``.

    Args:
        feat_scp: Sequence of ``(meeting_id, feature_path)`` pairs.
        labels_scp: Mapping from meeting id to a ``.mat`` file holding the
            ``labels`` and ``spk_list`` entries.
        spk2utt: Mapping from speaker id to a list of utterance ids.
        utt2xvec: Mapping from utterance id to an embedding path.
        utt2frames: Mapping from utterance id to its frame count.
        args: Parsed command line arguments.

    Returns:
        A list of ``(meeting_id, num_frames)`` tuples for the dumped meetings.
    """
    out_prefix = "{}_parts{:02d}".format(args.out, args.task_id)

    logger = logging.Logger(out_prefix, logging.INFO)
    file_handler = logging.FileHandler(out_prefix + ".log", mode="w")
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s"))
    logger.addHandler(file_handler)

    try:
        # The context managers guarantee the ark/scp files are closed even if
        # a meeting fails halfway through.
        with kaldiio.WriteHelper(f"ark,scp,f:{out_prefix}_feat.ark,{out_prefix}_feat.scp") as feat_writer, \
                kaldiio.WriteHelper(f"ark,scp,f:{out_prefix}_xvec.ark,{out_prefix}_xvec.scp") as ivc_writer, \
                kaldiio.WriteHelper(f"ark,scp,f:{out_prefix}_label.ark,{out_prefix}_label.scp") as label_writer:
            frames_list = _dump_meetings(
                feat_scp, labels_scp, spk2utt, utt2xvec, utt2frames, args, logger,
                feat_writer, ivc_writer, label_writer,
            )
    finally:
        logger.removeHandler(file_handler)
        file_handler.close()
    return frames_list


def _str2bool(value):
    """Parse a command line boolean; ``type=bool`` would treat "False" as True."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value, but got {!r}.".format(value))


def main():
    """Parse the command line, build the speaker-to-utterance map and dump chunks."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", required=True, type=str, default=None,
                        help="Directory containing feats.scp, labels.scp, utt2spk, utt2xvec and utt2num_frames.")
    parser.add_argument("--out", required=True, type=str, default=None,
                        help="The prefix of dumpped files.")
    parser.add_argument("--n_spk", type=int, default=16)
    parser.add_argument("--use_lfr", default=False, action="store_true")
    parser.add_argument("--no_pbar", default=False, action="store_true")
    parser.add_argument("--sr", type=int, default=8000)
    parser.add_argument("--chunk_size", type=int, default=1600)
    parser.add_argument("--chunk_shift", type=int, default=400)
    parser.add_argument("--mode", type=str, default="train", choices=["train", "test"])
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--task_size", type=int, default=-1)
    # Accepts "--add_mid_to_speaker", "--add_mid_to_speaker true" and
    # "--add_mid_to_speaker false".
    parser.add_argument("--add_mid_to_speaker", type=_str2bool, nargs="?", const=True, default=False)
    args = parser.parse_args()

    # Explicit check instead of assert, which is stripped under "python -O".
    if args.sr != 8000:
        parser.error("For callhome dataset, the sample rate should be 8000, use --sr 8000.")

    # dirname is empty when --out has no directory part; makedirs("") would fail.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    feat_scp = load_scp_as_list(os.path.join(args.dir, "feats.scp"))
    if args.task_size > 0:
        feat_scp = feat_scp[args.task_size*args.task_id: args.task_size*(args.task_id+1)]
    labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
    utt2frames = load_scp_as_dict(os.path.join(args.dir, "utt2num_frames"))

    # Keep only utterances that have an embedding and more than 25 frames.
    spk2utt = {}
    for utt, spk in utt2spk.items():
        if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
            spk2utt.setdefault(spk, []).append(utt)
    logging.info("Obtain {} speakers.".format(len(spk2utt)))

    logging.info("Task {:02d}: start dump {} meetings.".format(args.task_id, len(feat_scp)))
    meeting_lens = process(feat_scp, labels_scp, spk2utt, utt2xvec, utt2frames, args)
    total_frames = sum(x[1] for x in meeting_lens)
    logging.info("Total meetings: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))


if __name__ == '__main__':
    main()