kongdeqiang
6 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
""" This implementation is adapted from https://github.com/wenet-e2e/wekws/blob/main/wekws/bin/compute_det.py."""
 
import os
import json
import logging
import argparse
import threading
 
import kaldiio
import torch
from funasr.utils.kws_utils import split_mixed_label
 
 
class thread_wrapper(threading.Thread):
    def __init__(self, func, args=()):
        super(thread_wrapper, self).__init__()
        self.func = func
        self.args = args
        self.result = []
 
    def run(self):
        self.result = self.func(*self.args)
 
    def get_result(self):
        try:
            return self.result
        except Exception:
            return None
 
 
def space_mixed_label(input_str):
    splits = split_mixed_label(input_str)
    space_str = ''.join(f'{sub} ' for sub in splits)
    return space_str.strip()
 
 
def read_lists(list_file):
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            if line.strip() != '':
                lists.append(line.strip())
    return lists
 
 
def make_pair(wav_lists, trans_lists):
    logging.info('make pair for wav-trans list')
 
    trans_table = {}
    for line in trans_lists:
        arr = line.strip().replace('\t', ' ').split()
        if len(arr) < 2:
            logging.debug('invalid line in trans file: {}'.format(
                line.strip()))
            continue
 
        trans_table[arr[0]] = line.replace(arr[0],'').strip()
 
    lists = []
    for line in wav_lists:
        arr = line.strip().replace('\t', ' ').split()
        if len(arr) == 2 and arr[0] in trans_table:
            lists.append(
                dict(key=arr[0],
                     txt=trans_table[arr[0]],
                     wav=arr[1],
                     sample_rate=16000))
        else:
            logging.debug("can't find corresponding trans for key: {}".format(
                arr[0]))
            continue
 
    return lists
 
 
def count_duration(tid, data_lists):
    results = []
 
    for obj in data_lists:
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']
 
        try:
            rate, waveform = kaldiio.load_mat(wav_file)
            waveform = torch.tensor(waveform, dtype=torch.float32)
            waveform = waveform.unsqueeze(0)
            frames = len(waveform[0])
            duration = frames / float(rate)
        except:
            logging.info(f'load file failed: {wav_file}')
            duration = 0.0
 
        obj['duration'] = duration
        results.append(obj)
 
    return results
 
 
def load_data_and_score(keywords_list, data_file, trans_file, score_file):
    # score_table: {uttid: [keywordlist]}
    score_table = {}
    with open(score_file, 'r', encoding='utf8') as fin:
        # read score file and store in table
        for line in fin:
            arr = line.strip().split()
            key = arr[0]
            is_detected = arr[1]
            if is_detected == 'detected':
                if key not in score_table:
                    score_table.update(
                        {key: {
                            'kw': space_mixed_label(arr[2]),
                            'confi': float(arr[3])
                        }})
            else:
                if key not in score_table:
                    score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})
 
    wav_lists = read_lists(data_file)
    trans_lists = read_lists(trans_file)
    data_lists = make_pair(wav_lists, trans_lists)
    logging.info(f'origin list samples: {len(data_lists)}')
 
    # count duration for each wave
    num_workers = 8
    start = 0
    step = int(len(data_lists) / num_workers)
    tasks = []
    for idx in range(num_workers):
        if idx != num_workers - 1:
            task = thread_wrapper(count_duration,
                                  (idx, data_lists[start:start + step]))
        else:
            task = thread_wrapper(count_duration, (idx, data_lists[start:]))
        task.start()
        tasks.append(task)
        start += step
 
    duration_lists = []
    for task in tasks:
        task.join()
        duration_lists += task.get_result()
    logging.info(f'after list samples: {len(duration_lists)}')
 
    # build empty structure for keyword-filler infos
    keyword_filler_table = {}
    for keyword in keywords_list:
        keyword = space_mixed_label(keyword)
        keyword_filler_table[keyword] = {}
        keyword_filler_table[keyword]['keyword_table'] = {}
        keyword_filler_table[keyword]['keyword_duration'] = 0.0
        keyword_filler_table[keyword]['filler_table'] = {}
        keyword_filler_table[keyword]['filler_duration'] = 0.0
 
    for obj in duration_lists:
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        assert 'duration' in obj
 
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']
        txt = space_mixed_label(txt)
        txt_regstr_lrblk = ' ' + txt + ' '
        duration = obj['duration']
        assert key in score_table
 
        for keyword in keywords_list:
            keyword = space_mixed_label(keyword)
            keyword_regstr_lrblk = ' ' + keyword + ' '
            if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
                if keyword == score_table[key]['kw']:
                    keyword_filler_table[keyword]['keyword_table'].update(
                        {key: score_table[key]['confi']})
                else:
                    # uttrance detected but not match this keyword
                    keyword_filler_table[keyword]['keyword_table'].update(
                        {key: -1.0})
                keyword_filler_table[keyword]['keyword_duration'] += duration
            else:
                if keyword == score_table[key]['kw']:
                    keyword_filler_table[keyword]['filler_table'].update(
                        {key: score_table[key]['confi']})
                else:
                    # uttrance if detected, which is not FA for this keyword
                    keyword_filler_table[keyword]['filler_table'].update(
                        {key: -1.0})
                keyword_filler_table[keyword]['filler_duration'] += duration
 
    return keyword_filler_table
 
 
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='compute det curve')
    parser.add_argument('--keywords',
                        type=str,
                        required=True,
                        help='preset keyword str, input all keywords')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--trans_data',
                        required=True,
                        default='',
                        help='transcription of test data')
    parser.add_argument('--score_file', required=True, help='score file')
    parser.add_argument('--step',
                        type=float,
                        default=0.001,
                        help='threshold step')
    parser.add_argument('--stats_dir',
                        required=True,
                        help='to save det stats files')
    args = parser.parse_args()
 
    root_logger = logging.getLogger()
    handlers = root_logger.handlers[:]
    for handler in handlers:
        root_logger.removeHandler(handler)
        handler.close()
 
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
 
    keywords_list = args.keywords.strip().split(',')
    keyword_filler_table = load_data_and_score(keywords_list, args.test_data,
                                               args.trans_data,
                                               args.score_file)
 
    stats_files = {}
    for keyword in keywords_list:
        keyword = space_mixed_label(keyword)
        keyword_dur = keyword_filler_table[keyword]['keyword_duration']
        keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
        filler_dur = keyword_filler_table[keyword]['filler_duration']
        filler_num = len(keyword_filler_table[keyword]['filler_table'])
        if keyword_num <= 0:
            print('Can\'t compute det for {} without positive sample'.format(keyword))
            continue
        if filler_num <= 0:
            print('Can\'t compute det for {} without negative sample'.format(keyword))
            continue
 
        logging.info('Computing det for {}'.format(keyword))
        logging.info('  Keyword duration: {} Hours, wave number: {}'.format(
            keyword_dur / 3600.0, keyword_num))
        logging.info('  Filler duration: {} Hours'.format(filler_dur / 3600.0))
 
        stats_file = os.path.join(args.stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
        with open(stats_file, 'w', encoding='utf8') as fout:
            threshold = 0.0
            while threshold <= 1.0:
                num_false_reject = 0
                num_true_detect = 0
                # transverse the all keyword_table
                for key, confi in keyword_filler_table[keyword][
                        'keyword_table'].items():
                    if confi < threshold:
                        num_false_reject += 1
                    else:
                        num_true_detect += 1
 
                num_false_alarm = 0
                # transverse the all filler_table
                for key, confi in keyword_filler_table[keyword][
                        'filler_table'].items():
                    if confi >= threshold:
                        num_false_alarm += 1
                        # print(f'false alarm: {keyword}, {key}, {confi}')
 
                # false_reject_rate = num_false_reject / keyword_num
                true_detect_rate = num_true_detect / keyword_num
 
                num_false_alarm = max(num_false_alarm, 1e-6)
                false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
                false_alarm_rate = num_false_alarm / filler_num
 
                fout.write('{:.3f} {:.6f} {:.6f} {:.6f}\n'.format(
                    threshold, true_detect_rate, false_alarm_rate,
                    false_alarm_per_hour))
                threshold += args.step
 
        stats_files[keyword] = stats_file