志浩
2023-02-15 68b87c91405e11dca5dc64ef1e1b2fdc3e2389f4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import logging
import numpy as np
import soundfile
import kaldiio
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
from collections import OrderedDict
 
 
class MyRunner(MultiProcessRunnerV3):
    """Runner that pairs meetings' audio with their RTTM lines for label extraction.

    ``prepare`` parses the command line and builds one task per meeting. The
    function passed to the constructor (``process`` in this script) does the
    per-shard work of writing the label archives.
    """

    def prepare(self, parser: argparse.ArgumentParser):
        """Parse CLI options and build ``(meeting_id, wav_path, rttm_lines)`` tasks.

        ``--rttm_list`` and ``--wav_scp_list`` name text files that list one RTTM
        file and one wav.scp file per line, respectively. Blank lines are ignored.
        Meetings present in only one of the two sources are dropped with a
        warning.

        Returns:
            ``(task_list, None, args)``. ``process`` in this script ignores the
            middle element.
        """
        parser.add_argument("--rttm_list", type=str, required=True)
        parser.add_argument("--wav_scp_list", type=str, required=True)
        parser.add_argument("--out_dir", type=str, required=True)
        parser.add_argument("--n_spk", type=int, default=8)
        parser.add_argument("--remove_sil", default=False, action="store_true")
        parser.add_argument("--max_overlap", default=0, type=int)
        parser.add_argument("--frame_shift", type=float, default=0.01)
        args = parser.parse_args()

        meeting2rttm = OrderedDict()
        for rttm_path in self._read_path_list(args.rttm_list):
            meeting2rttm.update(self.load_rttm(rttm_path))

        meeting_scp = OrderedDict()
        for scp_path in self._read_path_list(args.wav_scp_list):
            meeting_scp.update(load_scp_as_dict(scp_path))

        # Compare the key sets, not just their sizes. Two equally sized sources can
        # still name different meetings, which would raise KeyError when building
        # task_list below.
        if set(meeting_scp) != set(meeting2rttm):
            logging.warning("Meetings of wav and rttm mismatch: {} wav vs {} rttm records".format(
                len(meeting_scp), len(meeting2rttm)))
            common_keys = set(meeting_scp) & set(meeting2rttm)
            logging.warning("Keep {} records.".format(len(common_keys)))
            meeting_scp = self._keep_common(meeting_scp, common_keys, "wav scp")
            meeting2rttm = self._keep_common(meeting2rttm, common_keys, "rttm scp")

        os.makedirs(args.out_dir, exist_ok=True)

        task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm]
        return task_list, None, args

    @staticmethod
    def _read_path_list(list_path):
        """Return the stripped, non-empty lines of ``list_path`` (one path per line)."""
        with open(list_path, "rt", encoding="utf-8") as fin:
            return [line.strip() for line in fin if line.strip()]

    @staticmethod
    def _keep_common(mapping, keys, source_name):
        """Return an OrderedDict of ``mapping`` restricted to ``keys``, logging each dropped key."""
        kept = OrderedDict()
        for key, value in mapping.items():
            if key in keys:
                kept[key] = value
            else:
                logging.warning("Pop {} from {}".format(key, source_name))
        return kept

    @staticmethod
    def load_rttm(rttm_path):
        """Group the lines of an RTTM file by meeting id (the second field).

        Blank lines are skipped. Fields may be separated by any run of whitespace.

        Returns:
            OrderedDict mapping meeting id -> list of stripped RTTM lines, in file order.
        """
        meeting2rttm = OrderedDict()
        with open(rttm_path, "rt", encoding="utf-8") as fin:
            for one_line in fin:
                one_line = one_line.strip()
                if not one_line:
                    continue
                # split() (not split(" ")) so repeated spaces/tabs don't shift field indices.
                mid = one_line.split()[1]
                meeting2rttm.setdefault(mid, []).append(one_line)
        return meeting2rttm

    def post(self, results_list, args):
        """No post-processing: ``process`` writes all outputs itself."""
        pass
 
 
def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0,
                sr=None, frame_shift=0.01):
    """Convert speaker turns into a frame-level speaker-activity matrix.

    Args:
        spk_turns: iterable of ``(meeting_id, start, duration, speaker)`` tuples;
            start and duration are in seconds (they are multiplied by ``sr``).
        spk_list: speakers in label-column order; each turn's speaker must be in it.
        length: audio length in samples.
        n_spk: number of label columns.
        remove_sil: if True, drop frames in which no speaker is active.
        max_overlap: if > 0, drop frames with more than this many active speakers.
        sr: sample rate in Hz (required despite the ``None`` default).
        frame_shift: frame shift in seconds.

    Returns:
        int ndarray of shape ``(T, n_spk)``; 1 marks an active speaker.

    Raises:
        ValueError: if ``sr`` is None, ``frame_shift`` is shorter than one sample,
            or a turn's speaker sits at an index >= ``n_spk`` in ``spk_list``.
    """
    if sr is None:
        raise ValueError("calc_labels() requires the sample rate `sr`")
    # Frame shift in samples. Use round() instead of int(): float products such as
    # 0.29 * 100 == 28.999999999999996 would otherwise truncate to one sample short.
    shift = int(round(frame_shift * sr))
    if shift < 1:
        raise ValueError("frame_shift={} s is shorter than one sample at sr={}".format(frame_shift, sr))
    half_shift = float(shift) / 2
    # Adding half a shift before flooring maps a sample position to the nearest frame index.
    num_frame = int((float(length) + half_shift) / shift)
    multi_label = np.zeros([n_spk, num_frame], dtype=int)
    for _, st, dur, spk in spk_turns:
        idx = spk_list.index(spk)
        if idx >= n_spk:
            raise ValueError("speaker {!r} needs label column {} but only n_spk={} columns exist".format(
                spk, idx, n_spk))

        st, dur = int(st * sr), int(dur * sr)
        frame_st = int((float(st) + half_shift) / shift)
        frame_ed = int((float(st + dur) + half_shift) / shift)
        multi_label[idx, frame_st:frame_ed] = 1

    if remove_sil:
        speech_count = np.sum(multi_label, axis=0)
        idx = np.nonzero(speech_count)[0]
        multi_label = multi_label[:, idx]

    if max_overlap > 0:
        speech_count = np.sum(multi_label, axis=0)
        idx = np.nonzero(speech_count <= max_overlap)[0]
        multi_label = multi_label[:, idx]

    label = multi_label.T
    return label  # (T, N)
 
 
def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0,
                 sr=16000, frame_shift=0.01):
    """Build frame-level speaker labels for one meeting.

    Args:
        wav_path: path to the meeting audio, readable by ``soundfile``.
        rttms: RTTM lines for this meeting. Fields 2, 4, 5 and 8 (1-based) are read
            as meeting id, start (s), duration (s) and speaker id. Blank lines are
            skipped.
        n_spk, remove_sil, max_overlap, frame_shift: forwarded to ``calc_labels``.
        sr: kept only for backward compatibility and ignored. The sample rate
            stored in the audio file is always used.

    Returns:
        ``(labels, spk_list)``: the ``(T, n_spk)`` array from ``calc_labels`` and the
        speakers in order of first appearance, which is the label-column order.
    """
    # Only the length and sample rate are needed, so read the file header instead
    # of decoding the whole (possibly hour-long) recording into memory.
    info = soundfile.info(wav_path)
    wav_len, sr = info.frames, info.samplerate
    spk_turns = []
    spk_list = []
    for one_line in rttms:
        # split() tolerates runs of whitespace between RTTM fields.
        parts = one_line.split()
        if not parts:
            continue
        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7]
        if spk not in spk_list:
            spk_list.append(spk)
        spk_turns.append((mid, st, dur, spk))
    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, max_overlap, sr, frame_shift)
    return labels, spk_list
 
 
def process(task_args):
    """Write speaker lists and label archives for one shard of meetings.

    Args:
        task_args: ``(task_idx, task_list, _, args)``. ``task_list`` holds
            ``(meeting_id, wav_path, rttm_lines)`` tuples. ``args`` is the parsed
            CLI namespace; it must provide ``out_dir``, ``n_spk``, ``remove_sil``,
            ``max_overlap``, ``sr`` and ``frame_shift``.

    Writes to ``args.out_dir``, with shard number ``task_idx + 1``:
        ``spk_list.<k>.txt``: one ``"<meeting_id> <spk> <spk> ..."`` line per meeting.
        ``labels.<k>.ark`` / ``labels.<k>.scp``: the label matrices via ``kaldiio.WriteHelper``.

    Returns:
        None.
    """
    task_idx, task_list, _, args = task_args
    shard = task_idx + 1
    spk_list_path = os.path.join(args.out_dir, "spk_list.{}.txt".format(shard))
    out_path = os.path.join(args.out_dir, "labels.{}".format(shard))
    # Context managers guarantee both outputs are flushed and closed even if a meeting fails.
    with open(spk_list_path, "wt", encoding="utf-8") as spk_list_writer, \
            kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path)) as label_writer:
        for mid, wav_path, rttms in task_list:
            meeting_labels, spk_list = build_labels(wav_path, rttms, args.n_spk, args.remove_sil,
                                                    args.max_overlap, args.sr, args.frame_shift)
            label_writer(mid, meeting_labels)
            spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list)))
    return None
 
 
if __name__ == '__main__':
    # Wire the per-shard worker into the runner and start it.
    MyRunner(process).run()