wuhongsheng
2024-07-05 3a4281f4959534b1bf5d01acf0085f4f8e6f2ec8
funasr/models/eend/utils/feature.py
@@ -8,13 +8,13 @@
def get_input_dim(
        frame_size,
        context_size,
        transform_type,
    frame_size,
    context_size,
    transform_type,
):
    if transform_type.startswith('logmel23'):
    if transform_type.startswith("logmel23"):
        frame_size = 23
    elif transform_type.startswith('logmel'):
    elif transform_type.startswith("logmel"):
        frame_size = 40
    else:
        fft_size = 1 << (frame_size - 1).bit_length()
@@ -23,11 +23,8 @@
    return input_dim
def transform(
        Y,
        transform_type=None,
        dtype=np.float32):
    """ Transform STFT feature
def transform(Y, transform_type=None, dtype=np.float32):
    """Transform STFT feature
    Args:
        Y: STFT
@@ -42,37 +39,37 @@
    Y = np.abs(Y)
    if not transform_type:
        pass
    elif transform_type == 'log':
    elif transform_type == "log":
        Y = np.log(np.maximum(Y, 1e-10))
    elif transform_type == 'logmel':
    elif transform_type == "logmel":
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 16000
        n_mels = 40
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.dot(Y**2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
    elif transform_type == 'logmel23':
    elif transform_type == "logmel23":
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.dot(Y**2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
    elif transform_type == 'logmel23_mn':
    elif transform_type == "logmel23_mn":
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.dot(Y**2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
        mean = np.mean(Y, axis=0)
        Y = Y - mean
    elif transform_type == 'logmel23_swn':
    elif transform_type == "logmel23_swn":
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.dot(Y**2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
        # b = np.ones(300)/300
        # mean = scipy.signal.convolve2d(Y, b[:, None], mode='same')
@@ -84,32 +81,31 @@
            th = (np.mean(powers[powers >= th]) + np.mean(powers[powers < th])) / 2
        mean = np.mean(Y[powers > th, :], axis=0)
        Y = Y - mean
    elif transform_type == 'logmel23_mvn':
    elif transform_type == "logmel23_mvn":
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.dot(Y**2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
        mean = np.mean(Y, axis=0)
        Y = Y - mean
        std = np.maximum(np.std(Y, axis=0), 1e-10)
        Y = Y / std
    else:
        raise ValueError('Unknown transform_type: %s' % transform_type)
        raise ValueError("Unknown transform_type: %s" % transform_type)
    return Y.astype(dtype)
def subsample(Y, T, subsampling=1):
    """ Frame subsampling
    """
    """Frame subsampling"""
    Y_ss = Y[::subsampling]
    T_ss = T[::subsampling]
    return Y_ss, T_ss
def splice(Y, context_size=0):
    """ Frame splicing
    """Frame splicing
    Args:
        Y: feature
@@ -122,22 +118,18 @@
        Y_spliced: spliced feature
            (n_frames, n_featdim * (2 * context_size + 1))-shaped
    """
    Y_pad = np.pad(
        Y,
        [(context_size, context_size), (0, 0)],
        'constant')
    Y_pad = np.pad(Y, [(context_size, context_size), (0, 0)], "constant")
    Y_spliced = np.lib.stride_tricks.as_strided(
        np.ascontiguousarray(Y_pad),
        (Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
        (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
        (Y.itemsize * Y.shape[1], Y.itemsize),
        writeable=False,
    )
    return Y_spliced
def stft(
        data,
        frame_size=1024,
        frame_shift=256):
    """ Compute STFT features
def stft(data, frame_size=1024, frame_shift=256):
    """Compute STFT features
    Args:
        data: audio signal
@@ -154,11 +146,11 @@
    # HACK: The last frame is ommited
    #       as librosa.stft produces such an excessive frame
    if len(data) % frame_shift == 0:
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                            hop_length=frame_shift).T[:-1]
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift).T[
            :-1
        ]
    else:
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                            hop_length=frame_shift).T
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift).T
def _count_frames(data_len, size, shift):
@@ -170,14 +162,9 @@
def get_frame_labels(
        kaldi_obj,
        rec,
        start=0,
        end=None,
        frame_size=1024,
        frame_shift=256,
        n_speakers=None):
    """ Get frame-aligned labels of given recording
    kaldi_obj, rec, start=0, end=None, frame_size=1024, frame_shift=256, n_speakers=None
):
    """Get frame-aligned labels of given recording
    Args:
        kaldi_obj (KaldiData)
        rec (str): recording id
@@ -192,10 +179,8 @@
        T: label
            (n_frames, n_speakers)-shaped np.int32 array
    """
    filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
    speakers = np.unique(
        [kaldi_obj.utt2spk[seg['utt']] for seg
         in filtered_segments]).tolist()
    filtered_segments = kaldi_obj.segments[kaldi_obj.segments["rec"] == rec]
    speakers = np.unique([kaldi_obj.utt2spk[seg["utt"]] for seg in filtered_segments]).tolist()
    if n_speakers is None:
        n_speakers = len(speakers)
    es = end * frame_shift if end is not None else None
@@ -206,11 +191,9 @@
        end = n_frames
    for seg in filtered_segments:
        speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
        start_frame = np.rint(
            seg['st'] * rate / frame_shift).astype(int)
        end_frame = np.rint(
            seg['et'] * rate / frame_shift).astype(int)
        speaker_index = speakers.index(kaldi_obj.utt2spk[seg["utt"]])
        start_frame = np.rint(seg["st"] * rate / frame_shift).astype(int)
        end_frame = np.rint(seg["et"] * rate / frame_shift).astype(int)
        rel_start = rel_end = None
        if start <= start_frame and start_frame < end:
            rel_start = start_frame - start
@@ -222,11 +205,9 @@
def get_labeledSTFT(
        kaldi_obj,
        rec, start, end, frame_size, frame_shift,
        n_speakers=None,
        use_speaker_id=False):
    """ Extracts STFT and corresponding labels
    kaldi_obj, rec, start, end, frame_size, frame_shift, n_speakers=None, use_speaker_id=False
):
    """Extracts STFT and corresponding labels
    Extracts STFT and corresponding diarization labels for
    given recording id and start/end times
@@ -246,14 +227,11 @@
        T: label
            (n_frmaes, n_speakers)-shaped np.int32 array.
    """
    data, rate = kaldi_obj.load_wav(
        rec, start * frame_shift, end * frame_shift)
    data, rate = kaldi_obj.load_wav(rec, start * frame_shift, end * frame_shift)
    Y = stft(data, frame_size, frame_shift)
    filtered_segments = kaldi_obj.segments[rec]
    # filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
    speakers = np.unique(
        [kaldi_obj.utt2spk[seg['utt']] for seg
         in filtered_segments]).tolist()
    speakers = np.unique([kaldi_obj.utt2spk[seg["utt"]] for seg in filtered_segments]).tolist()
    if n_speakers is None:
        n_speakers = len(speakers)
    T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32)
@@ -263,13 +241,11 @@
        S = np.zeros((Y.shape[0], len(all_speakers)), dtype=np.int32)
    for seg in filtered_segments:
        speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
        speaker_index = speakers.index(kaldi_obj.utt2spk[seg["utt"]])
        if use_speaker_id:
            all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg['utt']])
        start_frame = np.rint(
            seg['st'] * rate / frame_shift).astype(int)
        end_frame = np.rint(
            seg['et'] * rate / frame_shift).astype(int)
            all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg["utt"]])
        start_frame = np.rint(seg["st"] * rate / frame_shift).astype(int)
        end_frame = np.rint(seg["et"] * rate / frame_shift).astype(int)
        rel_start = rel_end = None
        if start <= start_frame and start_frame < end:
            rel_start = start_frame - start