#!/usr/bin/env python
# _*_ coding: UTF-8 _*_
"""Apply WPE dereverberation to groups of wav files listed in a wav.scp file.

Each non-blank line of the scp file lists the wav paths that are processed
jointly (one selected channel per file is stacked into a multi-channel
signal).  Enhanced files are written to the path obtained by replacing
``audio_dir`` with ``output_dir`` in each input path.
"""

import argparse
import codecs
import os
import logging
from multiprocessing import Pool

import numpy as np
import scipy.io.wavfile as wf
from nara_wpe.utils import istft, stft
from nara_wpe.wpe import wpe_v8 as wpe


def _output_path(wav_path, audio_dir, output_dir):
    """Return the destination path for the enhanced version of *wav_path*."""
    return wav_path.replace(audio_dir, output_dir)


def _load_channel(wav_path, channel):
    """Read one channel of *wav_path*; return ``(sampling_rate, samples)``.

    ``channel`` is 1-based (column ``channel - 1`` is selected).  int16 data
    is scaled to float32 by dividing by 32768; other dtypes are returned
    unchanged.
    """
    rate, data = wf.read(wav_path)
    data = data[:, channel - 1]
    if data.dtype == np.int16:
        data = np.float32(data) / 32768
    return rate, data


def wpe_worker(
    wav_scp,
    audio_dir="",
    output_dir="",
    channel=0,
    processing_id=None,
    processing_num=None,
):
    """Run WPE on the scp entries assigned to this worker.

    Args:
        wav_scp: Path of the scp file; each line holds whitespace-separated
            wav paths that are enhanced together.
        audio_dir: Substring of each input path that is replaced by
            ``output_dir`` to build the output path.
        output_dir: Replacement for ``audio_dir`` in the output path.
        channel: 1-based channel to read from every input file.  The default
            ``0`` maps to index ``-1``, i.e. the last channel.
        processing_id: Index of this worker, or ``None`` to process every
            line.
        processing_num: Total number of workers; line ``i`` is handled by the
            worker for which ``i % processing_num == processing_id``.

    Returns:
        None.  Enhanced int16 wav files are written as a side effect; groups
        whose outputs all exist already are skipped.
    """
    iterations = 5
    stft_options = dict(
        size=512,
        shift=128,
        window_length=None,
        fading=True,
        pad=True,
        symmetric_window=False,
    )

    with codecs.open(wav_scp, "r") as handle:
        wav_lines = [line.rstrip("\n") for line in handle]

    for wav_idx, line in enumerate(wav_lines):
        # Round-robin split of the scp lines across workers.
        if processing_id is not None and wav_idx % processing_num != processing_id:
            continue

        # Split on any whitespace so blank lines, tabs and trailing spaces do
        # not yield empty "paths" that would crash the wav reader.
        wav_list = line.split()
        if not wav_list:
            continue

        if all(
            os.path.exists(_output_path(wav_path, audio_dir, output_dir))
            for wav_path in wav_list
        ):
            logging.info("file exist {} : {}".format(wav_idx, wav_list[0]))
            continue

        logging.info("wait to process {} : {}".format(wav_idx, wav_list[0]))
        loaded = [_load_channel(wav_path, channel) for wav_path in wav_list]
        # Write outputs at the rate of the inputs instead of assuming 16 kHz;
        # no resampling is performed anywhere in this pipeline.
        sampling_rate = loaded[0][0]
        signal_list = [samples for _, samples in loaded]

        # Truncate every signal to the shortest one so they can be stacked.
        min_len = min(len(samples) for samples in signal_list)
        y = np.stack([samples[:min_len] for samples in signal_list], axis=0)

        Y = stft(y, **stft_options).transpose(2, 0, 1)
        Z = wpe(Y, iterations=iterations, statistics_mode="full").transpose(
            1, 2, 0
        )
        z = istft(Z, size=stft_options["size"], shift=stft_options["shift"])

        for d, wav_path in enumerate(wav_list):
            store_path = _output_path(wav_path, audio_dir, output_dir)
            store_dir = os.path.dirname(store_path)
            if store_dir:
                os.makedirs(store_dir, exist_ok=True)
            # Clip before the cast: a bare int16 conversion wraps around for
            # samples at or beyond full scale.
            tmpwav = np.clip(z[d, :] * 32768, -32768, 32767).astype(np.int16)
            wf.write(store_path, sampling_rate, tmpwav)
    return None


def wpe_manager(
    wav_scp, processing_num=1, audio_dir="", output_dir="", channel=1
):
    """Run :func:`wpe_worker` in-process or across a pool of processes.

    Args:
        wav_scp: Path of the scp file to process.
        processing_num: Number of worker processes; values ``<= 1`` run the
            worker in the current process.
        audio_dir: Forwarded to :func:`wpe_worker`.
        output_dir: Forwarded to :func:`wpe_worker`.
        channel: 1-based channel to read from every input file.

    Raises:
        RuntimeError: If one or more pool workers raised an exception.
    """
    if processing_num > 1:
        pool = Pool(processes=processing_num)
        pending = [
            pool.apply_async(
                wpe_worker,
                kwds={
                    "wav_scp": wav_scp,
                    "processing_id": i,
                    "processing_num": processing_num,
                    "audio_dir": audio_dir,
                    "output_dir": output_dir,
                    # Must be forwarded explicitly, otherwise workers fall
                    # back to their own default channel.
                    "channel": channel,
                },
            )
            for i in range(processing_num)
        ]
        pool.close()
        pool.join()

        # apply_async stores worker exceptions in the result object; they are
        # lost unless get() is called.
        failures = 0
        for worker_id, result in enumerate(pending):
            try:
                result.get()
            except Exception:
                failures += 1
                logging.exception("worker %d failed", worker_id)
        if failures:
            raise RuntimeError(
                "{} of {} WPE workers failed".format(failures, processing_num)
            )
    else:
        wpe_worker(wav_scp, audio_dir=audio_dir, output_dir=output_dir, channel=channel)
    return None


def main():
    """Parse command-line arguments and launch the WPE processing."""
    parser = argparse.ArgumentParser("run_wpe")
    parser.add_argument(
        "--wav-scp",
        type=str,
        required=True,
        help="Path pf wav scp file",
    )
    parser.add_argument(
        "--audio-dir",
        type=str,
        required=True,
        help="Directory of input audio files",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output directory of WPE enhanced audio files",
    )
    parser.add_argument(
        "--channel",
        type=int,
        required=True,
        help="Channel number of input audio",
    )
    parser.add_argument("--nj", type=int, default=1, help="number of process")
    args = parser.parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info("wavfile={}".format(args.wav_scp))
    logging.info("processingnum={}".format(args.nj))

    wpe_manager(
        wav_scp=args.wav_scp,
        processing_num=args.nj,
        audio_dir=args.audio_dir,
        output_dir=args.output_dir,
        channel=args.channel,
    )


if __name__ == "__main__":
    main()