speech_asr
2023-03-10 6052e1e7c23c43d495cb5689d6d17450d2d8eb8b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from __future__ import print_function
from multiprocessing import Pool
import argparse
from tqdm import tqdm
import math
 
 
class MultiProcessRunner:
    def __init__(self, fn):
        self.args = None
        self.process = fn
 
    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_ture", default=False)
 
        task_list, args = self.prepare(parser)
        result_list = self.pool_run(task_list, args)
        self.post(result_list, args)
 
    def prepare(self, parser):
        raise NotImplementedError("Please implement the prepare function.")
 
    def post(self, result_list, args):
        raise NotImplementedError("Please implement the post function.")
 
    def pool_run(self, tasks, args):
        results = []
        if args.debug:
            one_result = self.process(tasks[0])
            results.append(one_result)
        else:
            pool = Pool(args.nj)
            for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
                results.append(one_result)
            pool.close()
 
        return results
 
 
class MultiProcessRunnerV2:
    def __init__(self, fn):
        self.args = None
        self.process = fn
 
    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_true", default=False)
 
        task_list, args = self.prepare(parser)
        chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
        if args.verbose:
            print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
        subtask_list = [task_list[i*chunk_size: (i+1)*chunk_size] for i in range(args.nj)]
        result_list = self.pool_run(subtask_list, args)
        self.post(result_list, args)
 
    def prepare(self, parser):
        raise NotImplementedError("Please implement the prepare function.")
 
    def post(self, result_list, args):
        raise NotImplementedError("Please implement the post function.")
 
    def pool_run(self, tasks, args):
        results = []
        if args.debug:
            one_result = self.process(tasks[0])
            results.append(one_result)
        else:
            pool = Pool(args.nj)
            for one_result in tqdm(pool.imap(self.process, tasks), total=len(tasks), ascii=True, disable=args.no_pbar):
                results.append(one_result)
            pool.close()
 
        return results
 
 
class MultiProcessRunnerV3(MultiProcessRunnerV2):
    def run(self):
        parser = argparse.ArgumentParser("")
        # Task-independent options
        parser.add_argument("--nj", type=int, default=16)
        parser.add_argument("--debug", action="store_true", default=False)
        parser.add_argument("--no_pbar", action="store_true", default=False)
        parser.add_argument("--verbose", action="store_true", default=False)
        parser.add_argument("--sr", type=int, default=16000)
 
        task_list, shared_param, args = self.prepare(parser)
        chunk_size = int(math.ceil(float(len(task_list)) / args.nj))
        if args.verbose:
            print("Split {} tasks into {} sub-tasks with chunk_size {}".format(len(task_list), args.nj, chunk_size))
        subtask_list = [(i, task_list[i * chunk_size: (i + 1) * chunk_size], shared_param, args)
                        for i in range(args.nj)]
        result_list = self.pool_run(subtask_list, args)
        self.post(result_list, args)