游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import kaldiio
from tqdm import tqdm
import os
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import numpy as np
import argparse
from kaldiio import WriteHelper
 
 
def calc_global_ivc(spk, spk2utt, utt2ivc):
    ivc_list = [kaldiio.load_mat(utt2ivc[utt])[np.newaxis, :] for utt in spk2utt[spk]]
    ivc = np.concatenate(ivc_list, axis=0)
    ivc = np.mean(ivc, axis=0, keepdims=False)
    return ivc
 
 
def process(idx_scp, spk2utt, utt2xvec, meeting2spk_list, args):
    out_prefix = args.out
 
    ivc_dim = 256
    print("ivc_dim = {}".format(ivc_dim))
    out_prefix = out_prefix + "_parts00_xvec"
    ivc_writer = WriteHelper(f"ark,scp,f:{out_prefix}.ark,{out_prefix}.scp")
    idx_writer = open(out_prefix + ".idx", "wt")
    spk2xvec = {}
    if args.emb_type == "global":
        print("Use global speaker embedding.")
        for spk in spk2utt.keys():
            spk2xvec[spk] = calc_global_ivc(spk, spk2utt, utt2xvec)[np.newaxis, :]
 
    frames_list = []
    for utt_id in tqdm(idx_scp, total=len(idx_scp), ascii=True, disable=args.no_pbar):
        mid = utt_id.split("-")[0]
        idx_writer.write(utt_id+"\n")
 
        xvec_list = []
        for spk in meeting2spk_list[mid]:
            spk_xvec = spk2xvec[spk]
            xvec_list.append(spk_xvec)
        for _ in range(args.n_spk - len(xvec_list)):
            xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
        xvec = np.row_stack(xvec_list)
        ivc_writer(utt_id, xvec)
 
        frames_list.append((mid, 1))
    return frames_list
 
 
def calc_spk_list(rttms):
    spk_list = []
    for one_line in rttms:
        parts = [x for x in one_line.strip().split(" ") if x != ""]
        mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
        spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "")
        if spk_name.isdigit():
            spk_name = "{}_S{:03d}".format(mid, int(spk_name))
        else:
            spk_name = "{}_{}".format(mid, spk_name)
        if spk_name not in spk_list:
            spk_list.append(spk_name)
 
    return spk_list
 
 
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", required=True, type=str, default=None,
                        help="feats.scp")
    parser.add_argument("--out", required=True, type=str, default=None,
                        help="The prefix of dumpped files.")
    parser.add_argument("--n_spk", type=int, default=4)
    parser.add_argument("--no_pbar", default=False, action="store_true")
    parser.add_argument("--sr", type=int, default=16000)
    parser.add_argument("--emb_type", type=str, default="rand")
    args = parser.parse_args()
 
    if not os.path.exists(os.path.dirname(args.out)):
        os.makedirs(os.path.dirname(args.out))
 
    idx_scp = open(os.path.join(args.dir, "idx"), "r").readlines()
    idx_scp = [x.strip() for x in idx_scp]
    meeting2rttms = {}
    for one_line in open(os.path.join(args.dir, "sys.rttm"), "rt"):
        parts = [x for x in one_line.strip().split(" ") if x != ""]
        mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
        if mid not in meeting2rttms:
            meeting2rttms[mid] = []
        meeting2rttms[mid].append(one_line)
 
    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
 
    spk2utt = {}
    for utt, spk in utt2spk.items():
        if utt in utt2xvec:
            if spk not in spk2utt:
                spk2utt[spk] = []
            spk2utt[spk].append(utt)
 
    meeting2spk_list = {}
    for mid, rttms in meeting2rttms.items():
        meeting2spk_list[mid] = calc_spk_list(rttms)
        new_spk_list = []
        for spk in meeting2spk_list[mid]:
            if spk in spk2utt:
                new_spk_list.append(spk)
        if len(new_spk_list) != len(meeting2spk_list[mid]):
            print("{}: Reduce speaker number from {}(according rttm) to {}(according x-vectors)".format(
                mid, len(meeting2spk_list[mid]), len(new_spk_list)))
        meeting2spk_list[mid] = new_spk_list
 
    meeting_lens = process(idx_scp, spk2utt, utt2xvec, meeting2spk_list, args)
    print("Total meetings: {:6d}".format(len(meeting_lens)))
 
 
if __name__ == '__main__':
    main()