# Revision info (copied from a history/blame view): 志浩, 2023-02-17,
# commit ab828bcf7badb228fdc59647f5c9c75e33acce9d
import logging
import numpy as np
import soundfile
import kaldiio
from funasr.utils.job_runner import MultiProcessRunnerV3
from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
import os
import argparse
from collections import OrderedDict
import random
 
 
class MyRunner(MultiProcessRunnerV3):
    """Multi-process runner that loads the metadata needed to chunk label files.

    ``prepare`` parses the command line and loads the scp / mapping tables.
    The per-shard work is done by the worker function the runner is
    constructed with (``process`` in this script).
    """

    def prepare(self, parser: argparse.ArgumentParser):
        """Register options, parse them, create ``--out_dir`` and load inputs.

        Args:
            parser: parser to which this runner's options are added before
                ``parse_args()`` is called on it.

        Returns:
            ``(label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args)``:
            the ``--label_scp`` entries (presumably split into per-worker task
            lists by ``MultiProcessRunnerV3`` -- confirm in the base class),
            the lookup tables shared by every task, and the parsed namespace.
        """
        parser.add_argument("--label_scp", type=str, required=True)
        parser.add_argument("--wav_scp", type=str, required=True)
        parser.add_argument("--utt2spk", type=str, required=True)
        parser.add_argument("--spk2meeting", type=str, required=True)
        parser.add_argument("--utt2xvec", type=str, required=True)
        parser.add_argument("--out_dir", type=str, required=True)
        parser.add_argument("--chunk_size", type=int, default=16)
        parser.add_argument("--chunk_shift", type=int, default=4)
        parser.add_argument("--frame_shift", type=float, default=0.01)
        args = parser.parse_args()

        # exist_ok=True avoids the race between a separate existence check
        # and the creation call.
        os.makedirs(args.out_dir, exist_ok=True)

        label_list = load_scp_as_list(args.label_scp)
        wav_scp = load_scp_as_dict(args.wav_scp)
        utt2spk = load_scp_as_dict(args.utt2spk)
        utt2xvec = load_scp_as_dict(args.utt2xvec)
        spk2meeting = load_scp_as_dict(args.spk2meeting)

        # Invert the many-to-one maps. Keys and members keep the iteration
        # order of the source mapping.
        meeting2spks = OrderedDict()
        for spk, meeting in spk2meeting.items():
            meeting2spks.setdefault(meeting, []).append(spk)

        spk2utts = OrderedDict()
        for utt, spk in utt2spk.items():
            spk2utts.setdefault(spk, []).append(utt)

        return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args

    def post(self, results_list, args):
        """Do nothing: each worker call handles its own output files."""
 
 
def process(task_args):
    """Cut every label matrix in one worker's shard into sliding-window chunks.

    Args:
        task_args: tuple ``(task_idx, task_list, shared, args)``.
            ``task_list`` holds ``(key, label_path)`` pairs.
            ``shared`` is ``(wav_scp, utt2xvec, spk2utts, meeting2spks)``,
            presumably the shared data returned by ``MyRunner.prepare``.
            ``args`` is the parsed option namespace.

    ``--chunk_size`` and ``--chunk_shift`` are treated as seconds and converted
    to frames with ``--frame_shift``. This matches how the random start offset
    already converted ``chunk_shift``; the chunk count previously mixed
    seconds with frames.

    The ark/scp pairs ``wav_mix.N``, ``wav_sep.N`` and ``label.N``
    (N = task_idx + 1) are opened under ``--out_dir``. Nothing is written to
    them yet (see the TODO in the chunk loop).

    Returns:
        None.

    Raises:
        ValueError: if ``chunk_size`` or ``chunk_shift`` is shorter than one
            frame.
    """
    task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
    part = task_idx + 1

    # round() instead of int() truncation: e.g. int(0.3 / 0.1) == 2.
    chunk_frames = int(round(args.chunk_size / args.frame_shift))
    shift_frames = int(round(args.chunk_shift / args.frame_shift))
    if chunk_frames <= 0 or shift_frames <= 0:
        raise ValueError(
            "chunk_size and chunk_shift must each span at least one frame "
            "(chunk_size={}, chunk_shift={}, frame_shift={})".format(
                args.chunk_size, args.chunk_shift, args.frame_shift))

    def _open_writer(name):
        # Same "<out_dir>/<name>.<part>.{ark,scp}" layout as before.
        out_path = os.path.join(args.out_dir, "{}.{}".format(name, part))
        return kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))

    # The with-statement closes every writer even if a label file fails to load.
    with _open_writer("wav_mix") as wav_mix_writer, \
            _open_writer("wav_sep") as wav_sep_writer, \
            _open_writer("label") as label_writer:
        idx = 0
        for _, label_path in task_list:
            # Random start offset of up to one chunk shift, so chunk boundaries
            # differ between label files.
            rand_shift = random.randint(0, shift_frames)
            whole_label = kaldiio.load_mat(label_path)[rand_shift:]
            num_frames = whole_label.shape[0]
            # Number of full windows of chunk_frames that fit with stride
            # shift_frames (0 when the label is shorter than one chunk).
            num_chunk = max(0, (num_frames - chunk_frames) // shift_frames + 1)
            for i in range(num_chunk):
                # Zero-padded: "{:10d}" would pad with spaces, and whitespace
                # is not allowed inside an ark/scp key.
                utt_id = "part{}_utt{:010d}".format(part, idx + 1)
                # TODO: write the chunk covering frames
                # [i * shift_frames, i * shift_frames + chunk_frames) of
                # whole_label to label_writer, and the matching mixture and
                # separated audio to wav_mix_writer / wav_sep_writer under
                # utt_id.
                idx += 1

    return None
 
 
if __name__ == '__main__':
    # Build the runner around the per-shard worker function and start it.
    MyRunner(process).run()