yhliang
2023-08-10 08ee9e6aacc2e306211d393f6e8ce3a7f3620102
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python
# -- coding: UTF-8
 
import argparse
import codecs
import glob
import logging
import os
from nara_wpe.utils import stft, istft
import numpy as np
import scipy.io.wavfile as wf
from tqdm import tqdm
 
from test_gss import *
 
 
def get_parser():
    parser = argparse.ArgumentParser("Doing GSS based enhancement.")
    parser.add_argument(
        "--wav-scp",
        type=str,
        required=True,
        help="Wav scp file for enhancement.",
    )
    parser.add_argument(
        "--segments",
        type=str,
        required=True,
        help="Wav scp file for enhancement.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output directory of GSS enhanced data.",
    )
 
    return parser
 
 
def wfread(f):
    fs, data = wf.read(f)
    if data.dtype == np.int16:
        data = np.float32(data) / 32768
    return data, fs
 
 
def wfwrite(z, fs, store_path):
    tmpwav = np.int16(z * 32768)
    wf.write(store_path, fs, tmpwav)
 
 
def main():
    args = get_parser().parse_args()
 
    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
 
    stft_window, stft_shift = 512, 256
    gss = GSS(iterations=20, iterations_post=1)
    bf = Beamformer("mvdrSouden_ban", "mask_mul")
 
    with codecs.open(args.wav_scp, "r") as handle:
        lines_content = handle.readlines()
    wav_lines = [*map(lambda x: x[:-1] if x[-1] in ["\n"] else x, lines_content)]
 
    cnt = 0
    
    session2spk2dur = {}
    with codecs.open(args.segments, "r") as handle:
        for line in handle.readlines():
            uttid, spkid, stime, etime = line.strip().split(" ")
            sessionid = spkid.split("-")[0]
            if sessionid not in session2spk2dur.keys():
                session2spk2dur[sessionid] = {}
            if spkid not in session2spk2dur[sessionid].keys():
                session2spk2dur[sessionid][spkid] = []
            session2spk2dur[sessionid][spkid].append((float(stime), float(etime)))
    # import pdb;pdb.set_trace()
    
    for wav_idx in tqdm(range(len(wav_lines)), leave=True, desc="0"):
        # get wav files from scp file
        file_list = wav_lines[wav_idx].split(" ")
        sessionid, wav_list = file_list[0], file_list[1:]
 
        signal_list = []
        time_activity = []
        cnt += 1
        logging.info("Processing {}: {}".format(cnt, wav_list[0]))
 
        # read all wavs
        for wav in wav_list:
            data, fs = wfread(wav)
            signal_list.append(data)
        try:
            obstft = np.stack(signal_list, axis=0)
        except:
            minlen = min([len(s) for s in signal_list])
            obstft = np.stack([s[:minlen] for s in signal_list])
        wavlen = obstft.shape[1]
        obstft = stft(obstft, stft_window, stft_shift)
 
        # get activated timestamps and frequencies
        speaker_list = []
        for spk, dur in session2spk2dur[sessionid].items():
            speaker_list.append(spk.split("-")[-1])
            time_activity.append(get_time_activity(dur, wavlen, fs))
        time_activity.append([True] * wavlen)
        frequency_activity = get_frequency_activity(
            time_activity, stft_window, stft_shift
        )
        # import pdb;pdb.set_trace()
 
        # generate mask
        masks = gss(obstft, frequency_activity)
        masks_bak = masks
 
        for i in range(masks.shape[0] - 1):
            target_mask = masks[i]
            distortion_mask = np.sum(np.delete(masks, i, axis=0), axis=0)
            xhat = bf(obstft, target_mask=target_mask, distortion_mask=distortion_mask)
            xhat = istft(xhat, stft_window, stft_shift)
            audio_dir = "/".join(wav_list[0].split("/")[:-1])
            store_path = (
                wav_list[0]
                .replace(audio_dir, args.output_dir)
                .replace(".wav", "-{}.wav".format(speaker_list[i]))
            )
            if not os.path.exists(os.path.split(store_path)[0]):
                os.makedirs(os.path.split(store_path)[0], exist_ok=True)
 
            logging.info("Save wav file {}.".format(store_path))
            wfwrite(xhat, fs, store_path)
            masks = masks_bak
 
 
if __name__ == "__main__":
    main()