from __future__ import print_function import numpy as np import os import kaldiio from multiprocessing import Pool import argparse from tqdm import tqdm import math class MultiProcessRunner: def __init__(self, fn): self.args = None self.process = fn def run(self): parser = argparse.ArgumentParser("") # Task-independent options parser.add_argument("--nj", type=int, default=16) parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--no_pbar", action="store_true", default=False) parser.add_argument("--verbose", action="store_ture", default=False) task_list, args = self.prepare(parser) result_list = self.pool_run(task_list, args) self.post(result_list, args) def prepare(self, parser): raise NotImplementedError("Please implement the prepare function.") def post(self, result_list, args): raise NotImplementedError("Please implement the post function.") def pool_run(self, tasks, args): results = [] if args.debug: one_result = self.process(tasks[0]) results.append(one_result) else: pool = Pool(args.nj) for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar): results.append(one_result) pool.close() return results class MultiProcessRunnerV2: def __init__(self, fn): self.args = None self.process = fn def run(self): parser = argparse.ArgumentParser("") # Task-independent options parser.add_argument("--nj", type=int, default=16) parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--no_pbar", action="store_true", default=False) parser.add_argument("--verbose", action="store_true", default=False) task_list, args = self.prepare(parser) chunk_size = int(math.ceil(float(len(task_list)) / args.nj)) if args.verbose: print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size)) subtask_list = [task_list[i*chunk_size: (i+1)*chunk_size] for i in range(args.nj)] result_list = self.pool_run(subtask_list, args) self.post(result_list, args) def prepare(self, parser): raise NotImplementedError("Please implement the prepare function.") def post(self, result_list, args): raise NotImplementedError("Please implement the post function.") def pool_run(self, tasks, args): results = [] if args.debug: one_result = self.process(tasks[0]) results.append(one_result) else: pool = Pool(args.nj) for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar): results.append(one_result) pool.close() return results class MultiProcessRunnerV3(MultiProcessRunnerV2): def run(self): parser = argparse.ArgumentParser("") # Task-independent options parser.add_argument("--nj", type=int, default=16) parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--no_pbar", action="store_true", default=False) parser.add_argument("--verbose", action="store_true", default=False) parser.add_argument("--sr", type=int, default=16000) task_list, shared_param, args = self.prepare(parser) chunk_size = int(math.ceil(float(len(task_list)) / args.nj)) if args.verbose: print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size)) subtask_list = [(i, task_list[i * chunk_size: (i + 1) * chunk_size], shared_param, args) for i in range(args.nj)] result_list = self.pool_run(subtask_list, args) self.post(result_list, args) class MyRunner(MultiProcessRunnerV3): def prepare(self, parser): assert isinstance(parser, argparse.ArgumentParser) parser.add_argument("enroll_dir", type=str) parser.add_argument("trial_in", type=str) parser.add_argument("trial_out", type=str) args = parser.parse_args() if not os.path.exists(os.path.dirname(args.trial_out)): os.makedirs(os.path.dirname(args.trial_out)) flist_path = os.path.join(args.enroll_dir, "spk2xvec.flist") spk2xvec = {} for _path in open(flist_path, "r"): for key, value in kaldiio.load_ark(_path.strip()): if "-enroll" in key: key = key.replace("-enroll", "") spk2xvec[key] = value flist_path = os.path.join(args.enroll_dir, "utt2xvec.flist") utt2xvec = {} for _path in open(flist_path, 'r'): for key, value in kaldiio.load_ark(_path.strip()): utt2xvec[key] = value task_list = [one_line.strip().split(" ") for one_line in open(args.trial_in, "rt")] return task_list, [spk2xvec, utt2xvec], args def post(self, results_list, args): with open(args.trial_out, "wt") as fs: for results in results_list: for one_item in results: fs.write(one_item+"\n") def process(task_args): task_id, task_list, [spk2xvec, utt2xvec], args = task_args results = [] for spk, utt, _ in task_list: xvec = utt2xvec[utt] normed_x = xvec / np.linalg.norm(xvec) normed_y = spk2xvec[spk] / np.linalg.norm(spk2xvec[spk]) score = np.sum(normed_x * normed_y) results.append("{} {} {:.5f}".format(spk, utt, score)) return results if __name__ == '__main__': my_runner = MyRunner(process) my_runner.run()