嘉渊
2023-07-20 21536068b9e1d94a3c0de09b6b166a786f98361f
update
1个文件已修改
3个文件已添加
569 ■■■■■ 已修改文件
egs/callhome/eend_ola/local/dump_feature.py 127 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/callhome/eend_ola/local/split.py 117 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/callhome/eend_ola/run.sh 39 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/eend_ola/utils/feature.py 286 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/callhome/eend_ola/local/dump_feature.py
New file
@@ -0,0 +1,127 @@
import argparse
import os
import numpy as np
import funasr.modules.eend_ola.utils.feature as feature
import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data
def _count_frames(data_len, size, step):
    return int((data_len - size + step) / step)
def _gen_frame_indices(
        data_length, size=2000, step=2000,
        use_last_samples=False,
        label_delay=0,
        subsampling=1):
    i = -1
    for i in range(_count_frames(data_length, size, step)):
        yield i * step, i * step + size
    if use_last_samples and i * step + size < data_length:
        if data_length - (i + 1) * step - subsampling * label_delay > 0:
            yield (i + 1) * step, data_length
class KaldiDiarizationDataset():
    def __init__(
            self,
            data_dir,
            chunk_size=2000,
            context_size=0,
            frame_size=1024,
            frame_shift=256,
            subsampling=1,
            rate=16000,
            input_transform=None,
            use_last_samples=False,
            label_delay=0,
            n_speakers=None,
    ):
        self.data_dir = data_dir
        self.chunk_size = chunk_size
        self.context_size = context_size
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.subsampling = subsampling
        self.input_transform = input_transform
        self.n_speakers = n_speakers
        self.chunk_indices = []
        self.label_delay = label_delay
        self.data = kaldi_data.KaldiData(self.data_dir)
        # make chunk indices: filepath, start_frame, end_frame
        for rec, path in self.data.wavs.items():
            data_len = int(self.data.reco2dur[rec] * rate / frame_shift)
            data_len = int(data_len / self.subsampling)
            for st, ed in _gen_frame_indices(
                    data_len, chunk_size, chunk_size, use_last_samples,
                    label_delay=self.label_delay,
                    subsampling=self.subsampling):
                self.chunk_indices.append(
                    (rec, path, st * self.subsampling, ed * self.subsampling))
        print(len(self.chunk_indices), " chunks")
def convert(args):
    f = open(out_wav_file, 'w')
    dataset = KaldiDiarizationDataset(
        data_dir=args.data_dir,
        chunk_size=args.num_frames,
        context_size=args.context_size,
        input_transform=args.input_transform,
        frame_size=args.frame_size,
        frame_shift=args.frame_shift,
        subsampling=args.subsampling,
        rate=8000,
        use_last_samples=True,
    )
    length = len(dataset.chunk_indices)
    for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices):
        Y, T = feature.get_labeledSTFT(
            dataset.data,
            rec,
            st,
            ed,
            dataset.frame_size,
            dataset.frame_shift,
            dataset.n_speakers)
        Y = feature.transform(Y, dataset.input_transform)
        Y_spliced = feature.splice(Y, dataset.context_size)
        Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)
        st = '{:0>7d}'.format(st)
        ed = '{:0>7d}'.format(ed)
        suffix = '_' + st + '_' + ed
        parts = os.readlink('/'.join(path.split('/')[:-1])).split('/')
        # print('parts: ', parts)
        parts = parts[:4] + ['numpy_data'] + parts[4:]
        cur_path = '/'.join(parts)
        # print('cur path: ', cur_path)
        out_path = os.path.join(cur_path, path.split('/')[-1].split('.')[0] + suffix + '.npz')
        # print(out_path)
        # print(cur_path)
        if not os.path.exists(cur_path):
            os.makedirs(cur_path)
        np.savez(out_path, Y=Y_ss, T=T_ss)
        if idx == length - 1:
            f.write(rec + suffix + ' ' + out_path)
        else:
            f.write(rec + suffix + ' ' + out_path + '\n')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir")
    parser.add_argument("num_frames")
    parser.add_argument("context_size")
    parser.add_argument("frame_size")
    parser.add_argument("frame_shift")
    parser.add_argument("subsampling")
    args = parser.parse_args()
    convert(args)
egs/callhome/eend_ola/local/split.py
New file
@@ -0,0 +1,117 @@
import argparse
import os
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('root_path', help='raw data path')
    args = parser.parse_args()
    root_path = args.root_path
    work_path = os.path.join(root_path, ".work")
    scp_files = os.listdir(work_path)
    reco2dur_dict = {}
    with open(root_path + 'reco2dur') as f:
        lines = f.readlines()
        for line in lines:
            parts = line.strip().split()
            reco2dur_dict[parts[0]] = parts[1]
    spk2utt_dict = {}
    with open(root_path + 'spk2utt') as f:
        lines = f.readlines()
        for line in lines:
            parts = line.strip().split()
            spk = parts[0]
            utts = parts[1:]
            for utt in utts:
                tmp = utt.split('data')
                rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2])
                if rec in spk2utt_dict.keys():
                    spk2utt_dict[rec].append((spk, utt))
                else:
                    spk2utt_dict[rec] = []
                    spk2utt_dict[rec].append((spk, utt))
    segment_dict = {}
    with open(root_path + 'segments') as f:
        lines = f.readlines()
        for line in lines:
            parts = line.strip().split()
            if parts[1] in segment_dict.keys():
                segment_dict[parts[1]].append((parts[0], parts[2], parts[3]))
            else:
                segment_dict[parts[1]] = []
                segment_dict[parts[1]].append((parts[0], parts[2], parts[3]))
    utt2spk_dict = {}
    with open(root_path + 'utt2spk') as f:
        lines = f.readlines()
        for line in lines:
            parts = line.strip().split()
            utt = parts[0]
            tmp = utt.split('data')
            rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2])
            if rec in utt2spk_dict.keys():
                utt2spk_dict[rec].append((parts[0], parts[1]))
            else:
                utt2spk_dict[rec] = []
                utt2spk_dict[rec].append((parts[0], parts[1]))
    for file in scp_files:
        scp_file = work_path + file
        idx = scp_file.split('.')[-2]
        reco2dur_file = work_path + 'reco2dur.' + idx
        spk2utt_file = work_path + 'spk2utt.' + idx
        segment_file = work_path + 'segments.' + idx
        utt2spk_file = work_path + 'utt2spk.' + idx
        fpp = open(scp_file)
        scp_lines = fpp.readlines()
        keys = []
        for line in scp_lines:
            name = line.strip().split()[0]
            keys.append(name)
        with open(reco2dur_file, 'w') as f:
            lines = []
            for key in keys:
                string = key + ' ' + reco2dur_dict[key]
                lines.append(string + '\n')
            lines[-1] = lines[-1][:-1]
            f.writelines(lines)
        with open(spk2utt_file, 'w') as f:
            lines = []
            for key in keys:
                items = spk2utt_dict[key]
                for item in items:
                    string = item[0]
                    for it in item[1:]:
                        string += ' '
                        string += it
                    lines.append(string + '\n')
            lines[-1] = lines[-1][:-1]
            f.writelines(lines)
        with open(segment_file, 'w') as f:
            lines = []
            for key in keys:
                items = segment_dict[key]
                for item in items:
                    string = item[0] + ' ' + key + ' ' + item[1] + ' ' + item[2]
                    lines.append(string + '\n')
            lines[-1] = lines[-1][:-1]
            f.writelines(lines)
        with open(utt2spk_file, 'w') as f:
            lines = []
            for key in keys:
                items = utt2spk_dict[key]
                for item in items:
                    string = item[0] + ' ' + item[1]
                    lines.append(string + '\n')
            lines[-1] = lines[-1][:-1]
            f.writelines(lines)
        fpp.close()
egs/callhome/eend_ola/run.sh
@@ -8,6 +8,11 @@
count=1
# general configuration
dump_cmd=utils/run.pl
nj=64
# feature configuration
data_dir="./data"
simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data"
simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data"
callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data"
@@ -62,13 +67,33 @@
    local/run_prepare_shared_eda.sh
fi
## Prepare data for training and inference
#if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
#    echo "stage 0: Prepare data for training and inference"
#    echo "The detail can be found in https://github.com/hitachi-speech/EEND"
#    . ./local/
#fi
#
# Prepare data for training and inference
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Prepare data for training and inference"
    simu_opts_num_speaker_array=(1 2 3 4)
    simu_opts_sil_scale_array=(2 2 5 9)
    simu_opts_num_speaker=${simu_opts_num_speaker_array[i]}
    simu_opts_sil_scale=${simu_opts_sil_scale_array[i]}
    simu_opts_num_train=100000
    # for simulated data of chunk500
    for dset in swb_sre_tr swb_sre_cv; do
        if [ "$dset" == "swb_sre_tr" ]; then
            n_mixtures=${simu_opts_num_train}
        else
            n_mixtures=500
        fi
        simu_data_dir=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures}
        mkdir ${data_dir}/simu/data/${simu_data_dir}/.work
        split_scps=
        for n in $(seq $nj); do
            split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp"
        done
        utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1
        python local/split.py ${data_dir}/simu/data/${simu_data_dir}
    done
fi
# Training on simulated two-speaker data
world_size=$gpu_num
funasr/modules/eend_ola/utils/feature.py
New file
@@ -0,0 +1,286 @@
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
# Licensed under the MIT license.
#
# This module is for computing audio features
import numpy as np
import librosa
def get_input_dim(
        frame_size,
        context_size,
        transform_type,
):
    if transform_type.startswith('logmel23'):
        frame_size = 23
    elif transform_type.startswith('logmel'):
        frame_size = 40
    else:
        fft_size = 1 << (frame_size - 1).bit_length()
        frame_size = int(fft_size / 2) + 1
    input_dim = (2 * context_size + 1) * frame_size
    return input_dim
def transform(
        Y,
        transform_type=None,
        dtype=np.float32):
    """ Transform STFT feature
    Args:
        Y: STFT
            (n_frames, n_bins)-shaped np.complex array
        transform_type:
            None, "log"
        dtype: output data type
            np.float32 is expected
    Returns:
        Y (numpy.array): transformed feature
    """
    Y = np.abs(Y)
    if not transform_type:
        pass
    elif transform_type == 'log':
        Y = np.log(np.maximum(Y, 1e-10))
    elif transform_type == 'logmel':
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 16000
        n_mels = 40
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
    elif transform_type == 'logmel23':
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
    elif transform_type == 'logmel23_mn':
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
        mean = np.mean(Y, axis=0)
        Y = Y - mean
    elif transform_type == 'logmel23_swn':
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
        # b = np.ones(300)/300
        # mean = scipy.signal.convolve2d(Y, b[:, None], mode='same')
        #  simple 2-means based threshoding for mean calculation
        powers = np.sum(Y, axis=1)
        th = (np.max(powers) + np.min(powers)) / 2.0
        for i in range(10):
            th = (np.mean(powers[powers >= th]) + np.mean(powers[powers < th])) / 2
        mean = np.mean(Y[powers > th, :], axis=0)
        Y = Y - mean
    elif transform_type == 'logmel23_mvn':
        n_fft = 2 * (Y.shape[1] - 1)
        sr = 8000
        n_mels = 23
        mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
        Y = np.dot(Y ** 2, mel_basis.T)
        Y = np.log10(np.maximum(Y, 1e-10))
        mean = np.mean(Y, axis=0)
        Y = Y - mean
        std = np.maximum(np.std(Y, axis=0), 1e-10)
        Y = Y / std
    else:
        raise ValueError('Unknown transform_type: %s' % transform_type)
    return Y.astype(dtype)
def subsample(Y, T, subsampling=1):
    """ Frame subsampling
    """
    Y_ss = Y[::subsampling]
    T_ss = T[::subsampling]
    return Y_ss, T_ss
def splice(Y, context_size=0):
    """ Frame splicing
    Args:
        Y: feature
            (n_frames, n_featdim)-shaped numpy array
        context_size:
            number of frames concatenated on left-side
            if context_size = 5, 11 frames are concatenated.
    Returns:
        Y_spliced: spliced feature
            (n_frames, n_featdim * (2 * context_size + 1))-shaped
    """
    Y_pad = np.pad(
        Y,
        [(context_size, context_size), (0, 0)],
        'constant')
    Y_spliced = np.lib.stride_tricks.as_strided(
        np.ascontiguousarray(Y_pad),
        (Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
        (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
    return Y_spliced
def stft(
        data,
        frame_size=1024,
        frame_shift=256):
    """ Compute STFT features
    Args:
        data: audio signal
            (n_samples,)-shaped np.float32 array
        frame_size: number of samples in a frame (must be a power of two)
        frame_shift: number of samples between frames
    Returns:
        stft: STFT frames
            (n_frames, n_bins)-shaped np.complex64 array
    """
    # round up to nearest power of 2
    fft_size = 1 << (frame_size - 1).bit_length()
    # HACK: The last frame is ommited
    #       as librosa.stft produces such an excessive frame
    if len(data) % frame_shift == 0:
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                            hop_length=frame_shift).T[:-1]
    else:
        return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
                            hop_length=frame_shift).T
def _count_frames(data_len, size, shift):
    # HACK: Assuming librosa.stft(..., center=True)
    n_frames = 1 + int(data_len / shift)
    if data_len % shift == 0:
        n_frames = n_frames - 1
    return n_frames
def get_frame_labels(
        kaldi_obj,
        rec,
        start=0,
        end=None,
        frame_size=1024,
        frame_shift=256,
        n_speakers=None):
    """ Get frame-aligned labels of given recording
    Args:
        kaldi_obj (KaldiData)
        rec (str): recording id
        start (int): start frame index
        end (int): end frame index
            None means the last frame of recording
        frame_size (int): number of frames in a frame
        frame_shift (int): number of shift samples
        n_speakers (int): number of speakers
            if None, the value is given from data
    Returns:
        T: label
            (n_frames, n_speakers)-shaped np.int32 array
    """
    filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
    speakers = np.unique(
        [kaldi_obj.utt2spk[seg['utt']] for seg
         in filtered_segments]).tolist()
    if n_speakers is None:
        n_speakers = len(speakers)
    es = end * frame_shift if end is not None else None
    data, rate = kaldi_obj.load_wav(rec, start * frame_shift, es)
    n_frames = _count_frames(len(data), frame_size, frame_shift)
    T = np.zeros((n_frames, n_speakers), dtype=np.int32)
    if end is None:
        end = n_frames
    for seg in filtered_segments:
        speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
        start_frame = np.rint(
            seg['st'] * rate / frame_shift).astype(int)
        end_frame = np.rint(
            seg['et'] * rate / frame_shift).astype(int)
        rel_start = rel_end = None
        if start <= start_frame and start_frame < end:
            rel_start = start_frame - start
        if start < end_frame and end_frame <= end:
            rel_end = end_frame - start
        if rel_start is not None or rel_end is not None:
            T[rel_start:rel_end, speaker_index] = 1
    return T
def get_labeledSTFT(
        kaldi_obj,
        rec, start, end, frame_size, frame_shift,
        n_speakers=None,
        use_speaker_id=False):
    """ Extracts STFT and corresponding labels
    Extracts STFT and corresponding diarization labels for
    given recording id and start/end times
    Args:
        kaldi_obj (KaldiData)
        rec (str): recording id
        start (int): start frame index
        end (int): end frame index
        frame_size (int): number of samples in a frame
        frame_shift (int): number of shift samples
        n_speakers (int): number of speakers
            if None, the value is given from data
    Returns:
        Y: STFT
            (n_frames, n_bins)-shaped np.complex64 array,
        T: label
            (n_frmaes, n_speakers)-shaped np.int32 array.
    """
    data, rate = kaldi_obj.load_wav(
        rec, start * frame_shift, end * frame_shift)
    Y = stft(data, frame_size, frame_shift)
    filtered_segments = kaldi_obj.segments[rec]
    # filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
    speakers = np.unique(
        [kaldi_obj.utt2spk[seg['utt']] for seg
         in filtered_segments]).tolist()
    if n_speakers is None:
        n_speakers = len(speakers)
    T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32)
    if use_speaker_id:
        all_speakers = sorted(kaldi_obj.spk2utt.keys())
        S = np.zeros((Y.shape[0], len(all_speakers)), dtype=np.int32)
    for seg in filtered_segments:
        speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
        if use_speaker_id:
            all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg['utt']])
        start_frame = np.rint(
            seg['st'] * rate / frame_shift).astype(int)
        end_frame = np.rint(
            seg['et'] * rate / frame_shift).astype(int)
        rel_start = rel_end = None
        if start <= start_frame and start_frame < end:
            rel_start = start_frame - start
        if start < end_frame and end_frame <= end:
            rel_end = end_frame - start
        if rel_start is not None or rel_end is not None:
            T[rel_start:rel_end, speaker_index] = 1
            if use_speaker_id:
                S[rel_start:rel_end, all_speaker_index] = 1
    if use_speaker_id:
        return Y, T, S
    else:
        return Y, T