import random import numpy as np import os import sys if __name__=="__main__": path = sys.argv[1] # dump2/raw/Train_Ali_far wav_scp_file = open(path+"/wav.scp", 'r') wav_scp = wav_scp_file.readlines() wav_scp_file.close() spk2id_file = open(path+"/spk2id", 'r') spk2id = spk2id_file.readlines() spk2id_file.close() embedding_scp_file = open(path + "/oracle_embedding.scp", 'r') embedding_scp = embedding_scp_file.readlines() embedding_scp_file.close() embedding_map = {} for line in embedding_scp: spk = line.strip().split(' ')[0] if spk not in embedding_map.keys(): emb = np.load(line.strip().split(' ')[1]) embedding_map[spk] = emb meeting_map_tmp = {} global_spk_list = [] for line in spk2id: line_list = line.strip().split(' ') meeting = line_list[0].split('-')[0] spk_id = line_list[0].split('-')[-1].split('_')[-1] spk = meeting+'_' + spk_id global_spk_list.append(spk) if meeting in meeting_map_tmp.keys(): meeting_map_tmp[meeting].append(spk) else: meeting_map_tmp[meeting] = [spk] for meeting in meeting_map_tmp.keys(): num = len(meeting_map_tmp[meeting]) if num < 4: global_spk_list_tmp = global_spk_list[: ] for spk in meeting_map_tmp[meeting]: global_spk_list_tmp.remove(spk) padding_spk = random.sample(global_spk_list_tmp, 4 - num) meeting_map_tmp[meeting] = meeting_map_tmp[meeting] + padding_spk meeting_map = {} os.system('mkdir -p ' + path + '/oracle_profile_padding') for meeting in meeting_map_tmp.keys(): emb_list = [] for i in range(len(meeting_map_tmp[meeting])): spk = meeting_map_tmp[meeting][i] emb_list.append(embedding_map[spk]) profile = np.vstack(emb_list) np.save(path + '/oracle_profile_padding/' + meeting + '.npy',profile) meeting_map[meeting] = path + '/oracle_profile_padding/' + meeting + '.npy' profile_scp = open(path + '/oracle_profile_padding.scp', 'w') profile_map_scp = open(path + '/oracle_profile_padding_spk_list', 'w') for line in wav_scp: uttid = line.strip().split(' ')[0] meeting = uttid.split('-')[0] profile_scp.write(uttid+' ' + meeting_map[meeting] + '\n') profile_map_scp.write(uttid+' ' + '$'.join(meeting_map_tmp[meeting]) + '\n') profile_scp.close() profile_map_scp.close()