From 21536068b9e1d94a3c0de09b6b166a786f98361f Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 20 七月 2023 17:09:45 +0800
Subject: [PATCH] update
---
egs/callhome/eend_ola/run.sh | 39 +++
funasr/modules/eend_ola/utils/feature.py | 286 ++++++++++++++++++++++++++++
egs/callhome/eend_ola/local/split.py | 117 +++++++++++
egs/callhome/eend_ola/local/dump_feature.py | 127 ++++++++++++
4 files changed, 562 insertions(+), 7 deletions(-)
diff --git a/egs/callhome/eend_ola/local/dump_feature.py b/egs/callhome/eend_ola/local/dump_feature.py
new file mode 100644
index 0000000..169615e
--- /dev/null
+++ b/egs/callhome/eend_ola/local/dump_feature.py
@@ -0,0 +1,127 @@
+import argparse
+import os
+
+import numpy as np
+
+import funasr.modules.eend_ola.utils.feature as feature
+import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data
+
+
+def _count_frames(data_len, size, step):
+ return int((data_len - size + step) / step)
+
+
+def _gen_frame_indices(
+ data_length, size=2000, step=2000,
+ use_last_samples=False,
+ label_delay=0,
+ subsampling=1):
+ i = -1
+ for i in range(_count_frames(data_length, size, step)):
+ yield i * step, i * step + size
+ if use_last_samples and i * step + size < data_length:
+ if data_length - (i + 1) * step - subsampling * label_delay > 0:
+ yield (i + 1) * step, data_length
+
+
+class KaldiDiarizationDataset():
+ def __init__(
+ self,
+ data_dir,
+ chunk_size=2000,
+ context_size=0,
+ frame_size=1024,
+ frame_shift=256,
+ subsampling=1,
+ rate=16000,
+ input_transform=None,
+ use_last_samples=False,
+ label_delay=0,
+ n_speakers=None,
+ ):
+ self.data_dir = data_dir
+ self.chunk_size = chunk_size
+ self.context_size = context_size
+ self.frame_size = frame_size
+ self.frame_shift = frame_shift
+ self.subsampling = subsampling
+ self.input_transform = input_transform
+ self.n_speakers = n_speakers
+ self.chunk_indices = []
+ self.label_delay = label_delay
+
+ self.data = kaldi_data.KaldiData(self.data_dir)
+
+ # make chunk indices: filepath, start_frame, end_frame
+ for rec, path in self.data.wavs.items():
+ data_len = int(self.data.reco2dur[rec] * rate / frame_shift)
+ data_len = int(data_len / self.subsampling)
+ for st, ed in _gen_frame_indices(
+ data_len, chunk_size, chunk_size, use_last_samples,
+ label_delay=self.label_delay,
+ subsampling=self.subsampling):
+ self.chunk_indices.append(
+ (rec, path, st * self.subsampling, ed * self.subsampling))
+ print(len(self.chunk_indices), " chunks")
+
+
+def convert(args):
+ f = open(out_wav_file, 'w')
+ dataset = KaldiDiarizationDataset(
+ data_dir=args.data_dir,
+ chunk_size=args.num_frames,
+ context_size=args.context_size,
+ input_transform=args.input_transform,
+ frame_size=args.frame_size,
+ frame_shift=args.frame_shift,
+ subsampling=args.subsampling,
+ rate=8000,
+ use_last_samples=True,
+ )
+ length = len(dataset.chunk_indices)
+ for idx, (rec, path, st, ed) in enumerate(dataset.chunk_indices):
+ Y, T = feature.get_labeledSTFT(
+ dataset.data,
+ rec,
+ st,
+ ed,
+ dataset.frame_size,
+ dataset.frame_shift,
+ dataset.n_speakers)
+ Y = feature.transform(Y, dataset.input_transform)
+ Y_spliced = feature.splice(Y, dataset.context_size)
+ Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)
+ st = '{:0>7d}'.format(st)
+ ed = '{:0>7d}'.format(ed)
+ suffix = '_' + st + '_' + ed
+
+ parts = os.readlink('/'.join(path.split('/')[:-1])).split('/')
+ # print('parts: ', parts)
+ parts = parts[:4] + ['numpy_data'] + parts[4:]
+ cur_path = '/'.join(parts)
+ # print('cur path: ', cur_path)
+ out_path = os.path.join(cur_path, path.split('/')[-1].split('.')[0] + suffix + '.npz')
+ # print(out_path)
+ # print(cur_path)
+ if not os.path.exists(cur_path):
+ os.makedirs(cur_path)
+ np.savez(out_path, Y=Y_ss, T=T_ss)
+ if idx == length - 1:
+ f.write(rec + suffix + ' ' + out_path)
+ else:
+ f.write(rec + suffix + ' ' + out_path + '\n')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("data_dir")
+ parser.add_argument("num_frames")
+ parser.add_argument("context_size")
+ parser.add_argument("frame_size")
+ parser.add_argument("frame_shift")
+ parser.add_argument("subsampling")
+
+
+
+ args = parser.parse_args()
+ convert(args)
diff --git a/egs/callhome/eend_ola/local/split.py b/egs/callhome/eend_ola/local/split.py
new file mode 100644
index 0000000..6f313cc
--- /dev/null
+++ b/egs/callhome/eend_ola/local/split.py
@@ -0,0 +1,117 @@
+import argparse
+import os
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('root_path', help='raw data path')
+ args = parser.parse_args()
+
+ root_path = args.root_path
+ work_path = os.path.join(root_path, ".work")
+ scp_files = os.listdir(work_path)
+
+ reco2dur_dict = {}
+ with open(root_path + 'reco2dur') as f:
+ lines = f.readlines()
+ for line in lines:
+ parts = line.strip().split()
+ reco2dur_dict[parts[0]] = parts[1]
+
+ spk2utt_dict = {}
+ with open(root_path + 'spk2utt') as f:
+ lines = f.readlines()
+ for line in lines:
+ parts = line.strip().split()
+ spk = parts[0]
+ utts = parts[1:]
+ for utt in utts:
+ tmp = utt.split('data')
+ rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2])
+ if rec in spk2utt_dict.keys():
+ spk2utt_dict[rec].append((spk, utt))
+ else:
+ spk2utt_dict[rec] = []
+ spk2utt_dict[rec].append((spk, utt))
+
+ segment_dict = {}
+ with open(root_path + 'segments') as f:
+ lines = f.readlines()
+ for line in lines:
+ parts = line.strip().split()
+ if parts[1] in segment_dict.keys():
+ segment_dict[parts[1]].append((parts[0], parts[2], parts[3]))
+ else:
+ segment_dict[parts[1]] = []
+ segment_dict[parts[1]].append((parts[0], parts[2], parts[3]))
+
+ utt2spk_dict = {}
+ with open(root_path + 'utt2spk') as f:
+ lines = f.readlines()
+ for line in lines:
+ parts = line.strip().split()
+ utt = parts[0]
+ tmp = utt.split('data')
+ rec = 'data_' + '_'.join(tmp[1][1:].split('_')[:-2])
+ if rec in utt2spk_dict.keys():
+ utt2spk_dict[rec].append((parts[0], parts[1]))
+ else:
+ utt2spk_dict[rec] = []
+ utt2spk_dict[rec].append((parts[0], parts[1]))
+
+ for file in scp_files:
+ scp_file = work_path + file
+ idx = scp_file.split('.')[-2]
+ reco2dur_file = work_path + 'reco2dur.' + idx
+ spk2utt_file = work_path + 'spk2utt.' + idx
+ segment_file = work_path + 'segments.' + idx
+ utt2spk_file = work_path + 'utt2spk.' + idx
+
+ fpp = open(scp_file)
+ scp_lines = fpp.readlines()
+ keys = []
+ for line in scp_lines:
+ name = line.strip().split()[0]
+ keys.append(name)
+
+ with open(reco2dur_file, 'w') as f:
+ lines = []
+ for key in keys:
+ string = key + ' ' + reco2dur_dict[key]
+ lines.append(string + '\n')
+ lines[-1] = lines[-1][:-1]
+ f.writelines(lines)
+
+ with open(spk2utt_file, 'w') as f:
+ lines = []
+ for key in keys:
+ items = spk2utt_dict[key]
+ for item in items:
+ string = item[0]
+ for it in item[1:]:
+ string += ' '
+ string += it
+ lines.append(string + '\n')
+ lines[-1] = lines[-1][:-1]
+ f.writelines(lines)
+
+ with open(segment_file, 'w') as f:
+ lines = []
+ for key in keys:
+ items = segment_dict[key]
+ for item in items:
+ string = item[0] + ' ' + key + ' ' + item[1] + ' ' + item[2]
+ lines.append(string + '\n')
+ lines[-1] = lines[-1][:-1]
+ f.writelines(lines)
+
+ with open(utt2spk_file, 'w') as f:
+ lines = []
+ for key in keys:
+ items = utt2spk_dict[key]
+ for item in items:
+ string = item[0] + ' ' + item[1]
+ lines.append(string + '\n')
+ lines[-1] = lines[-1][:-1]
+ f.writelines(lines)
+
+ fpp.close()
diff --git a/egs/callhome/eend_ola/run.sh b/egs/callhome/eend_ola/run.sh
index 40fb041..cd246fe 100644
--- a/egs/callhome/eend_ola/run.sh
+++ b/egs/callhome/eend_ola/run.sh
@@ -8,6 +8,11 @@
count=1
# general configuration
+dump_cmd=utils/run.pl
+nj=64
+
+# feature configuration
+data_dir="./data"
simu_feats_dir="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data/data"
simu_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/simu_data_chunk2000/data"
callhome_feats_dir_chunk2000="/nfs/wangjiaming.wjm/EEND_ARK_DATA/dump/callhome_chunk2000/data"
@@ -62,13 +67,33 @@
local/run_prepare_shared_eda.sh
fi
-## Prepare data for training and inference
-#if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-# echo "stage 0: Prepare data for training and inference"
-# echo "The detail can be found in https://github.com/hitachi-speech/EEND"
-# . ./local/
-#fi
-#
+# Prepare data for training and inference
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Prepare data for training and inference"
+ simu_opts_num_speaker_array=(1 2 3 4)
+ simu_opts_sil_scale_array=(2 2 5 9)
+ simu_opts_num_speaker=${simu_opts_num_speaker_array[i]}
+ simu_opts_sil_scale=${simu_opts_sil_scale_array[i]}
+ simu_opts_num_train=100000
+
+ # for simulated data of chunk500
+ for dset in swb_sre_tr swb_sre_cv; do
+ if [ "$dset" == "swb_sre_tr" ]; then
+ n_mixtures=${simu_opts_num_train}
+ else
+ n_mixtures=500
+ fi
+ simu_data_dir=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures}
+ mkdir ${data_dir}/simu/data/${simu_data_dir}/.work
+ split_scps=
+ for n in $(seq $nj); do
+ split_scps="$split_scps ${data_dir}/simu/data/${simu_data_dir}/.work/wav.$n.scp"
+ done
+ utils/split_scp.pl "${data_dir}/simu/data/${simu_data_dir}/wav.scp" $split_scps || exit 1
+ python local/split.py ${data_dir}/simu/data/${simu_data_dir}
+ done
+fi
+
# Training on simulated two-speaker data
world_size=$gpu_num
diff --git a/funasr/modules/eend_ola/utils/feature.py b/funasr/modules/eend_ola/utils/feature.py
new file mode 100644
index 0000000..544a352
--- /dev/null
+++ b/funasr/modules/eend_ola/utils/feature.py
@@ -0,0 +1,286 @@
+# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
+# Licensed under the MIT license.
+#
+# This module is for computing audio features
+
+import numpy as np
+import librosa
+
+
+def get_input_dim(
+ frame_size,
+ context_size,
+ transform_type,
+):
+ if transform_type.startswith('logmel23'):
+ frame_size = 23
+ elif transform_type.startswith('logmel'):
+ frame_size = 40
+ else:
+ fft_size = 1 << (frame_size - 1).bit_length()
+ frame_size = int(fft_size / 2) + 1
+ input_dim = (2 * context_size + 1) * frame_size
+ return input_dim
+
+
+def transform(
+ Y,
+ transform_type=None,
+ dtype=np.float32):
+ """ Transform STFT feature
+
+ Args:
+ Y: STFT
+ (n_frames, n_bins)-shaped np.complex array
+ transform_type:
+ None, "log"
+ dtype: output data type
+ np.float32 is expected
+ Returns:
+ Y (numpy.array): transformed feature
+ """
+ Y = np.abs(Y)
+ if not transform_type:
+ pass
+ elif transform_type == 'log':
+ Y = np.log(np.maximum(Y, 1e-10))
+ elif transform_type == 'logmel':
+ n_fft = 2 * (Y.shape[1] - 1)
+ sr = 16000
+ n_mels = 40
+ mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
+ Y = np.dot(Y ** 2, mel_basis.T)
+ Y = np.log10(np.maximum(Y, 1e-10))
+ elif transform_type == 'logmel23':
+ n_fft = 2 * (Y.shape[1] - 1)
+ sr = 8000
+ n_mels = 23
+ mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
+ Y = np.dot(Y ** 2, mel_basis.T)
+ Y = np.log10(np.maximum(Y, 1e-10))
+ elif transform_type == 'logmel23_mn':
+ n_fft = 2 * (Y.shape[1] - 1)
+ sr = 8000
+ n_mels = 23
+ mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
+ Y = np.dot(Y ** 2, mel_basis.T)
+ Y = np.log10(np.maximum(Y, 1e-10))
+ mean = np.mean(Y, axis=0)
+ Y = Y - mean
+ elif transform_type == 'logmel23_swn':
+ n_fft = 2 * (Y.shape[1] - 1)
+ sr = 8000
+ n_mels = 23
+ mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
+ Y = np.dot(Y ** 2, mel_basis.T)
+ Y = np.log10(np.maximum(Y, 1e-10))
+ # b = np.ones(300)/300
+ # mean = scipy.signal.convolve2d(Y, b[:, None], mode='same')
+
+ # simple 2-means based threshoding for mean calculation
+ powers = np.sum(Y, axis=1)
+ th = (np.max(powers) + np.min(powers)) / 2.0
+ for i in range(10):
+ th = (np.mean(powers[powers >= th]) + np.mean(powers[powers < th])) / 2
+ mean = np.mean(Y[powers > th, :], axis=0)
+ Y = Y - mean
+ elif transform_type == 'logmel23_mvn':
+ n_fft = 2 * (Y.shape[1] - 1)
+ sr = 8000
+ n_mels = 23
+ mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
+ Y = np.dot(Y ** 2, mel_basis.T)
+ Y = np.log10(np.maximum(Y, 1e-10))
+ mean = np.mean(Y, axis=0)
+ Y = Y - mean
+ std = np.maximum(np.std(Y, axis=0), 1e-10)
+ Y = Y / std
+ else:
+ raise ValueError('Unknown transform_type: %s' % transform_type)
+ return Y.astype(dtype)
+
+
+def subsample(Y, T, subsampling=1):
+ """ Frame subsampling
+ """
+ Y_ss = Y[::subsampling]
+ T_ss = T[::subsampling]
+ return Y_ss, T_ss
+
+
+def splice(Y, context_size=0):
+ """ Frame splicing
+
+ Args:
+ Y: feature
+ (n_frames, n_featdim)-shaped numpy array
+ context_size:
+ number of frames concatenated on left-side
+ if context_size = 5, 11 frames are concatenated.
+
+ Returns:
+ Y_spliced: spliced feature
+ (n_frames, n_featdim * (2 * context_size + 1))-shaped
+ """
+ Y_pad = np.pad(
+ Y,
+ [(context_size, context_size), (0, 0)],
+ 'constant')
+ Y_spliced = np.lib.stride_tricks.as_strided(
+ np.ascontiguousarray(Y_pad),
+ (Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
+ (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
+ return Y_spliced
+
+
+def stft(
+ data,
+ frame_size=1024,
+ frame_shift=256):
+ """ Compute STFT features
+
+ Args:
+ data: audio signal
+ (n_samples,)-shaped np.float32 array
+ frame_size: number of samples in a frame (must be a power of two)
+ frame_shift: number of samples between frames
+
+ Returns:
+ stft: STFT frames
+ (n_frames, n_bins)-shaped np.complex64 array
+ """
+ # round up to nearest power of 2
+ fft_size = 1 << (frame_size - 1).bit_length()
+ # HACK: The last frame is ommited
+ # as librosa.stft produces such an excessive frame
+ if len(data) % frame_shift == 0:
+ return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
+ hop_length=frame_shift).T[:-1]
+ else:
+ return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
+ hop_length=frame_shift).T
+
+
+def _count_frames(data_len, size, shift):
+ # HACK: Assuming librosa.stft(..., center=True)
+ n_frames = 1 + int(data_len / shift)
+ if data_len % shift == 0:
+ n_frames = n_frames - 1
+ return n_frames
+
+
+def get_frame_labels(
+ kaldi_obj,
+ rec,
+ start=0,
+ end=None,
+ frame_size=1024,
+ frame_shift=256,
+ n_speakers=None):
+ """ Get frame-aligned labels of given recording
+ Args:
+ kaldi_obj (KaldiData)
+ rec (str): recording id
+ start (int): start frame index
+ end (int): end frame index
+ None means the last frame of recording
+ frame_size (int): number of frames in a frame
+ frame_shift (int): number of shift samples
+ n_speakers (int): number of speakers
+ if None, the value is given from data
+ Returns:
+ T: label
+ (n_frames, n_speakers)-shaped np.int32 array
+ """
+ filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
+ speakers = np.unique(
+ [kaldi_obj.utt2spk[seg['utt']] for seg
+ in filtered_segments]).tolist()
+ if n_speakers is None:
+ n_speakers = len(speakers)
+ es = end * frame_shift if end is not None else None
+ data, rate = kaldi_obj.load_wav(rec, start * frame_shift, es)
+ n_frames = _count_frames(len(data), frame_size, frame_shift)
+ T = np.zeros((n_frames, n_speakers), dtype=np.int32)
+ if end is None:
+ end = n_frames
+
+ for seg in filtered_segments:
+ speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
+ start_frame = np.rint(
+ seg['st'] * rate / frame_shift).astype(int)
+ end_frame = np.rint(
+ seg['et'] * rate / frame_shift).astype(int)
+ rel_start = rel_end = None
+ if start <= start_frame and start_frame < end:
+ rel_start = start_frame - start
+ if start < end_frame and end_frame <= end:
+ rel_end = end_frame - start
+ if rel_start is not None or rel_end is not None:
+ T[rel_start:rel_end, speaker_index] = 1
+ return T
+
+
+def get_labeledSTFT(
+ kaldi_obj,
+ rec, start, end, frame_size, frame_shift,
+ n_speakers=None,
+ use_speaker_id=False):
+ """ Extracts STFT and corresponding labels
+
+ Extracts STFT and corresponding diarization labels for
+ given recording id and start/end times
+
+ Args:
+ kaldi_obj (KaldiData)
+ rec (str): recording id
+ start (int): start frame index
+ end (int): end frame index
+ frame_size (int): number of samples in a frame
+ frame_shift (int): number of shift samples
+ n_speakers (int): number of speakers
+ if None, the value is given from data
+ Returns:
+ Y: STFT
+ (n_frames, n_bins)-shaped np.complex64 array,
+ T: label
+ (n_frmaes, n_speakers)-shaped np.int32 array.
+ """
+ data, rate = kaldi_obj.load_wav(
+ rec, start * frame_shift, end * frame_shift)
+ Y = stft(data, frame_size, frame_shift)
+ filtered_segments = kaldi_obj.segments[rec]
+ # filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec]
+ speakers = np.unique(
+ [kaldi_obj.utt2spk[seg['utt']] for seg
+ in filtered_segments]).tolist()
+ if n_speakers is None:
+ n_speakers = len(speakers)
+ T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32)
+
+ if use_speaker_id:
+ all_speakers = sorted(kaldi_obj.spk2utt.keys())
+ S = np.zeros((Y.shape[0], len(all_speakers)), dtype=np.int32)
+
+ for seg in filtered_segments:
+ speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']])
+ if use_speaker_id:
+ all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg['utt']])
+ start_frame = np.rint(
+ seg['st'] * rate / frame_shift).astype(int)
+ end_frame = np.rint(
+ seg['et'] * rate / frame_shift).astype(int)
+ rel_start = rel_end = None
+ if start <= start_frame and start_frame < end:
+ rel_start = start_frame - start
+ if start < end_frame and end_frame <= end:
+ rel_end = end_frame - start
+ if rel_start is not None or rel_end is not None:
+ T[rel_start:rel_end, speaker_index] = 1
+ if use_speaker_id:
+ S[rel_start:rel_end, all_speaker_index] = 1
+
+ if use_speaker_id:
+ return Y, T, S
+ else:
+ return Y, T
--
Gitblit v1.9.1