import io
import functools
import logging

# import soundfile as sf
import numpy as np
import matplotlib
import matplotlib.pylab as plt

# from IPython.display import display, Audio
from nara_wpe.utils import stft, istft
from pb_bss.distribution import CACGMMTrainer
from pb_bss.evaluation import InputMetrics, OutputMetrics
from dataclasses import dataclass
from typing import Optional

# from beamforming_wrapper import beamform_mvdr_souden_from_masks
from pb_chime5.utils.numpy_utils import segment_axis_v2
from textgrid_processor import read_textgrid_from_file


def get_time_activity(dur_list, wavlen, sr):
    """Build a per-sample activity mask from a list of time segments.

    Args:
        dur_list: Iterable of segments; ``dur[0]`` is the start and ``dur[1]``
            the end of a segment. Both are multiplied by ``sr`` to obtain
            sample indices (presumably seconds -- confirm against caller).
        wavlen: Number of samples in the signal; length of the returned mask.
        sr: Factor converting segment boundaries to sample indices.

    Returns:
        A list of ``wavlen`` booleans, ``True`` for every sample covered by at
        least one accepted segment.
    """
    time_activity = [False] * wavlen
    for dur in dur_list:
        xmin = int(dur[0] * sr)
        xmax = int(dur[1] * sr)
        if xmax > wavlen:
            # Segments that run past the end of the signal are dropped as a
            # whole (not clamped); report it so the loss is visible.
            # NOTE(review): confirm that dropping, rather than clamping to
            # wavlen, is the intended behavior.
            logging.warning(
                "Skipping segment %s: end sample %d exceeds signal length %d",
                dur,
                xmax,
                wavlen,
            )
            continue
        # A negative start would otherwise index from the end of the list and
        # mark samples at the wrong end of the signal.
        xmin = max(xmin, 0)
        if xmin < xmax:
            time_activity[xmin:xmax] = [True] * (xmax - xmin)
    logging.info("Num of actived samples {}".format(time_activity.count(True)))
    return time_activity


def get_frequency_activity(
    time_activity,
    stft_window_length,
    stft_shift,
    stft_fading=True,
    stft_pad=True,
):
    """Convert a per-sample activity mask into a per-frame activity mask.

    The mask is cut into windows of ``stft_window_length`` samples taken every
    ``stft_shift`` samples (via ``segment_axis_v2``); a frame is active when
    any sample inside its window is active.

    Args:
        time_activity: Boolean activity; the last axis is the sample axis.
        stft_window_length: Window length in samples.
        stft_shift: Hop between consecutive windows in samples.
        stft_fading: When true, pad ``stft_window_length - stft_shift`` inactive
            samples on both sides of the last axis before segmenting.
        stft_pad: Selects ``end="pad"`` (true) or ``end="cut"`` (false) for
            ``segment_axis_v2``; presumably controls whether a trailing
            partial window is kept -- see that function for the exact rule.

    Returns:
        Boolean array with the sample axis replaced by a frame axis.
    """
    time_activity = np.asarray(time_activity)

    if stft_fading:
        pad_width = np.array([(0, 0)] * time_activity.ndim)
        pad_width[-1, :] = stft_window_length - stft_shift  # Consider fading
        time_activity = np.pad(time_activity, pad_width, mode="constant")

    return segment_axis_v2(
        time_activity,
        length=stft_window_length,
        shift=stft_shift,
        end="pad" if stft_pad else "cut",
    ).any(axis=-1)


@dataclass
class Beamformer:
    """Callable that extracts a target estimate from a multi-channel observation.

    ``type`` selects the extraction method ("mvdrSouden_ban", "ch0" or "sum");
    ``postfilter`` is either ``None`` (no post-processing) or "mask_mul"
    (multiply the result by the target mask).
    """

    type: str
    postfilter: Optional[str]

    def __call__(self, Obs, target_mask, distortion_mask, debug=False):
        """Apply the configured beamformer and optional postfilter.

        Args:
            Obs: Multi-channel observation; axis 0 is indexed/summed by the
                "ch0"/"sum" variants (presumably the channel axis).
            target_mask: Mask of the target source; passed to the MVDR
                beamformer as ``X_mask`` and used by the "mask_mul" postfilter.
            distortion_mask: Mask of everything else; passed as ``N_mask``.
            debug: Unused; kept for interface compatibility.

        Returns:
            The enhanced signal ``X_hat``.

        Raises:
            NotImplementedError: For an unknown ``type`` or ``postfilter``.
        """
        bf = self.type

        if bf == "mvdrSouden_ban":
            # Imported lazily so the simple variants work without this module.
            from pb_chime5.speech_enhancement.beamforming_wrapper import (
                beamform_mvdr_souden_from_masks,
            )

            X_hat = beamform_mvdr_souden_from_masks(
                Y=Obs,
                X_mask=target_mask,
                N_mask=distortion_mask,
                ban=True,
            )
        elif bf == "ch0":
            X_hat = Obs[0]
        elif bf == "sum":
            X_hat = np.sum(Obs, axis=0)
        else:
            raise NotImplementedError(bf)

        if self.postfilter is None:
            pass
        elif self.postfilter == "mask_mul":
            X_hat = X_hat * target_mask
        else:
            raise NotImplementedError(self.postfilter)

        return X_hat


@dataclass
class GSS:
    """Guided source separation: fit one cACGMM per frequency bin.

    Attributes are configuration only:
    ``iterations`` is forwarded to ``CACGMMTrainer.fit``; ``iterations_post``
    is not used by this implementation; ``verbose`` enables progress logging.
    """

    iterations: int = 20
    iterations_post: int = 0

    verbose: bool = True

    # use_pinv: bool = False
    # stable: bool = True

    def __call__(self, Obs, acitivity_freq=None, debug=False):
        """Estimate source posteriors for every frequency bin of ``Obs``.

        Args:
            Obs: Observation whose last axis is the frequency axis
                (presumably shaped (channels, frames, bins) -- confirm).
            acitivity_freq: Per-source, per-frame activity. It is normalized
                over axis 0 to form the initialization and, cast to bool, is
                used as the source activity mask. Required despite the
                ``None`` default in the signature.
            debug: Unused; kept for interface compatibility.

        Returns:
            The per-bin predictions stacked and transposed with
            ``transpose(1, 2, 0)``, i.e. the frequency axis moved last.

        Raises:
            ValueError: If ``acitivity_freq`` is ``None``.
        """
        if acitivity_freq is None:
            raise ValueError(
                "acitivity_freq must be provided; it initializes and "
                "constrains the mixture model."
            )

        F = Obs.shape[-1]
        T = Obs.T.shape[-2]

        initialization = np.asarray(acitivity_freq, dtype=np.float64)
        # Avoid exact zeros so the normalization below cannot divide by zero.
        initialization = np.where(initialization == 0, 1e-10, initialization)
        initialization = initialization / np.sum(
            initialization, keepdims=True, axis=0
        )
        # One copy per frequency bin. The bin count follows the observation
        # instead of a fixed 257, so any STFT size works.
        initialization = np.repeat(initialization[None, ...], F, axis=0)

        source_active_mask = np.asarray(acitivity_freq, dtype=bool)
        source_active_mask = np.repeat(source_active_mask[None, ...], F, axis=0)

        cacGMM = CACGMMTrainer()

        all_affiliations = []
        for f in range(F):
            if self.verbose and f % 50 == 0:
                logging.info(f"{f}/{F}")

            # T: Consider end of signal.
            # This should not be necessary, but the activity is for the in-ear
            # recording and not for the array.
            cur = cacGMM.fit(
                y=Obs.T[f, ...],
                initialization=initialization[f, ..., :T],
                iterations=self.iterations,
                source_activity_mask=source_active_mask[f, ..., :T],
            )
            affiliation = cur.predict(
                Obs.T[f, ...],
                source_activity_mask=source_active_mask[f, ..., :T],
            )
            all_affiliations.append(affiliation)

        posterior = np.array(all_affiliations).transpose(1, 2, 0)
        return posterior