志浩
2023-02-16 842df33fa23331e819965324df5d9e790eccbf9f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import numpy as np
import os
import sys
import argparse
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import librosa
import soundfile as sf
from copy import deepcopy
import json
from tqdm import tqdm
 
 
class MyRunner(MultiProcessRunnerV3):
    def prepare(self, parser):
        assert isinstance(parser, argparse.ArgumentParser)
        parser.add_argument("wav_scp", type=str)
        parser.add_argument("rttm", type=str)
        parser.add_argument("out_dir", type=str)
        parser.add_argument("--min_dur", type=float, default=2.0)
        parser.add_argument("--max_spk_num", type=int, default=4)
        args = parser.parse_args()
 
        if not os.path.exists(args.out_dir):
            os.makedirs(args.out_dir)
 
        wav_scp = load_scp_as_list(args.wav_scp)
        meeting2rttms = {}
        for one_line in open(args.rttm, "rt"):
            parts = [x for x in one_line.strip().split(" ") if x != ""]
            mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
            if mid not in meeting2rttms:
                meeting2rttms[mid] = []
            meeting2rttms[mid].append(one_line)
 
        task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
        return task_list, None, args
 
    def post(self, result_list, args):
        count = [0, 0]
        for result in result_list:
            count[0] += result[0]
            count[1] += result[1]
        print("Found {} speakers, extracted {}.".format(count[1], count[0]))
 
 
# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
    labels = np.zeros([max_spk_num, length], int)
    spk_list = []
    for one_line in rttms:
        parts = [x for x in one_line.strip().split(" ") if x != ""]
        mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
        spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "")
        if spk_name.isdigit():
            spk_name = "{}_S{:03d}".format(mid, int(spk_name))
        else:
            spk_name = "{}_{}".format(mid, spk_name)
        if spk_name not in spk_list:
            spk_list.append(spk_name)
        st, dur = int(st*sr), int(dur*sr)
        idx = spk_list.index(spk_name)
        labels[idx, st:st+dur] = 1
    return labels, spk_list
 
 
def get_nonoverlap_turns(multi_label, spk_list):
    turns = []
    label = np.sum(multi_label, axis=0) == 1
    spk, in_turn, st = None, False, 0
    for i in range(len(label)):
        if not in_turn and label[i]:
            st, in_turn = i, True
            spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
        if in_turn:
            if not label[i]:
                in_turn = False
                turns.append([st, i, spk])
            elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
                turns.append([st, i, spk])
                st, in_turn = i, True
                spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
    if in_turn:
        turns.append([st, len(label), spk])
    return turns
 
 
def process(task_args):
    task_id, task_list, _, args = task_args
    spk_count = [0, 0]
    for mid, wav_path, rttms in task_list:
        wav, sr = sf.read(wav_path, dtype="int16")
        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
        multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
        turns = get_nonoverlap_turns(multi_label, spk_list)
        extracted_spk = []
        count = 1
        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
            if (ed - st) >= args.min_dur * args.sr:
                seg = wav[st: ed]
                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
                if not os.path.exists(os.path.dirname(save_path)):
                    os.makedirs(os.path.dirname(save_path))
                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
                count += 1
                if spk not in extracted_spk:
                    extracted_spk.append(spk)
        if len(extracted_spk) != len(spk_list):
            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
            ))
        spk_count[0] += len(extracted_spk)
        spk_count[1] += len(spk_list)
    return spk_count
 
 
if __name__ == '__main__':
    my_runner = MyRunner(process)
    my_runner.run()