志浩
2023-02-15 0917d4ee0edd9c61a53406ed44eaa37e457480fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
import numpy as np
import soundfile
import kaldiio
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
from collections import OrderedDict
 
 
class MyRunner(MultiProcessRunnerV3):
 
    def prepare(self, parser: argparse.ArgumentParser):
        parser.add_argument("--rttm_list", type=str, required=True)
        parser.add_argument("--wav_scp_list", type=str, required=True)
        parser.add_argument("--out_dir", type=str, required=True)
        parser.add_argument("--n_spk", type=int, default=8)
        parser.add_argument("--remove_sil", default=False, action="store_true")
        parser.add_argument("--max_overlap", default=0, type=int)
        parser.add_argument("--frame_shift", type=float, default=0.01)
        args = parser.parse_args()
 
        rttm_list = [x.strip() for x in open(args.rttm_list, "rt", encoding="utf-8").readlines()]
        meeting2rttm = OrderedDict()
        for rttm_path in rttm_list:
            meeting2rttm.update(self.load_rttm(rttm_path))
 
        wav_scp_list = [x.strip() for x in open(args.wav_scp_list, "rt", encoding="utf-8").readlines()]
        meeting_scp = OrderedDict()
        for scp_path in wav_scp_list:
            meeting_scp.update(load_scp_as_dict(scp_path))
 
        assert len(meeting_scp) == len(meeting2rttm), \
            "Number of wav and rttm mismatch {} != {}".format(len(meeting_scp), len(meeting2rttm))
 
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
 
        task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
        return task_list, None, args
 
    @staticmethod
    def load_rttm(rttm_path):
        meeting2rttm = OrderedDict()
        for one_line in open(rttm_path, "rt", encoding="utf-8"):
            mid = one_line.strip().split(" ")[1]
            if mid not in meeting2rttm:
                meeting2rttm[mid] = []
            meeting2rttm[mid].append(one_line.strip())
 
        return meeting2rttm
 
    def post(self, results_list, args):
        pass
 
 
def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0,
                sr=None, frame_shift=0.01):
    frame_shift = int(frame_shift * sr)
    num_frame = int((float(length) + (float(frame_shift) / 2)) / frame_shift)
    multi_label = np.zeros([n_spk, num_frame], dtype=int)
    for _, st, dur, spk in spk_turns:
        idx = spk_list.index(spk)
 
        st, dur = int(st * sr), int(dur * sr)
        frame_st = int((float(st) + (float(frame_shift) / 2)) / frame_shift)
        frame_ed = int((float(st+dur) + (float(frame_shift) / 2)) / frame_shift)
        multi_label[idx, frame_st:frame_ed] = 1
 
    if remove_sil:
        speech_count = np.sum(multi_label, axis=0)
        idx = np.nonzero(speech_count)[0]
        multi_label = multi_label[:, idx]
 
    if max_overlap > 0:
        speech_count = np.sum(multi_label, axis=0)
        idx = np.nonzero(speech_count <= max_overlap)[0]
        multi_label = multi_label[:, idx]
 
    label = multi_label.T
    return label  # (T, N)
 
 
def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0,
                 sr=16000, frame_shift=0.01):
    wav, sr = soundfile.read(wav_path)
    wav_len = len(wav)
    spk_turns = []
    spk_list = []
    for one_line in rttms:
        parts = one_line.strip().split(" ")
        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7]
        if spk not in spk_list:
            spk_list.append(spk)
        spk_turns.append((mid, st, dur, spk))
    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, max_overlap, sr, frame_shift)
    return labels, spk_list
 
 
def process(task_args):
    task_idx, task_list, _, args = task_args
    spk_list_writer = open(os.path.join(args.out_dir, "spk_list.{}.txt".format(task_idx+1)),
                           "wt", encoding="utf-8")
    out_path = os.path.join(args.out_dir, "labels.{}".format(task_idx + 1))
    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
    for mid, wav_path, rttms in task_list:
        meeting_labels, spk_list = build_labels(wav_path, rttms, args.n_spk, args.remove_sil, args.max_overlap,
                                                args.sr, args.frame_shift)
        label_writer(mid, meeting_labels)
        spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list)))
 
    spk_list_writer.close()
    label_writer.close()
    return None
 
 
if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()