#!/usr/bin/env python # -- coding: UTF-8 import argparse import codecs import glob import logging import os from nara_wpe.utils import stft, istft import numpy as np import scipy.io.wavfile as wf from tqdm import tqdm from test_gss import * def get_parser(): parser = argparse.ArgumentParser("Doing GSS based enhancement.") parser.add_argument( "--wav-scp", type=str, required=True, help="Wav scp file for enhancement.", ) parser.add_argument( "--segments", type=str, required=True, help="Wav scp file for enhancement.", ) parser.add_argument( "--output-dir", type=str, required=True, help="Output directory of GSS enhanced data.", ) return parser def wfread(f): fs, data = wf.read(f) if data.dtype == np.int16: data = np.float32(data) / 32768 return data, fs def wfwrite(z, fs, store_path): tmpwav = np.int16(z * 32768) wf.write(store_path, fs, tmpwav) def main(): args = get_parser().parse_args() # logging info logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" logging.basicConfig(level=logging.INFO, format=logfmt) stft_window, stft_shift = 512, 256 gss = GSS(iterations=20, iterations_post=1) bf = Beamformer("mvdrSouden_ban", "mask_mul") with codecs.open(args.wav_scp, "r") as handle: lines_content = handle.readlines() wav_lines = [*map(lambda x: x[:-1] if x[-1] in ["\n"] else x, lines_content)] cnt = 0 session2spk2dur = {} with codecs.open(args.segments, "r") as handle: for line in handle.readlines(): uttid, spkid, stime, etime = line.strip().split(" ") sessionid = spkid.split("-")[0] if sessionid not in session2spk2dur.keys(): session2spk2dur[sessionid] = {} if spkid not in session2spk2dur[sessionid].keys(): session2spk2dur[sessionid][spkid] = [] session2spk2dur[sessionid][spkid].append((float(stime), float(etime))) # import pdb;pdb.set_trace() for wav_idx in tqdm(range(len(wav_lines)), leave=True, desc="0"): # get wav files from scp file file_list = wav_lines[wav_idx].split(" ") sessionid, wav_list = file_list[0], file_list[1:] signal_list = [] time_activity = [] cnt += 1 logging.info("Processing {}: {}".format(cnt, wav_list[0])) # read all wavs for wav in wav_list: data, fs = wfread(wav) signal_list.append(data) try: obstft = np.stack(signal_list, axis=0) except: minlen = min([len(s) for s in signal_list]) obstft = np.stack([s[:minlen] for s in signal_list]) wavlen = obstft.shape[1] obstft = stft(obstft, stft_window, stft_shift) # get activated timestamps and frequencies speaker_list = [] for spk, dur in session2spk2dur[sessionid].items(): speaker_list.append(spk.split("-")[-1]) time_activity.append(get_time_activity(dur, wavlen, fs)) time_activity.append([True] * wavlen) frequency_activity = get_frequency_activity( time_activity, stft_window, stft_shift ) # import pdb;pdb.set_trace() # generate mask masks = gss(obstft, frequency_activity) masks_bak = masks for i in range(masks.shape[0] - 1): target_mask = masks[i] distortion_mask = np.sum(np.delete(masks, i, axis=0), axis=0) xhat = bf(obstft, target_mask=target_mask, distortion_mask=distortion_mask) xhat = istft(xhat, stft_window, stft_shift) audio_dir = "/".join(wav_list[0].split("/")[:-1]) store_path = ( wav_list[0] .replace(audio_dir, args.output_dir) .replace(".wav", "-{}.wav".format(speaker_list[i])) ) if not os.path.exists(os.path.split(store_path)[0]): os.makedirs(os.path.split(store_path)[0], exist_ok=True) logging.info("Save wav file {}.".format(store_path)) wfwrite(xhat, fs, store_path) masks = masks_bak if __name__ == "__main__": main()