import argparse
import os

import numpy as np

import funasr.modules.eend_ola.utils.feature as feature
import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data

def _count_frames(data_len, size, step):
|
return int((data_len - size + step) / step)
|
|
|
def _gen_frame_indices(
|
data_length, size=2000, step=2000,
|
use_last_samples=False,
|
label_delay=0,
|
subsampling=1):
|
i = -1
|
for i in range(_count_frames(data_length, size, step)):
|
yield i * step, i * step + size
|
if use_last_samples and i * step + size < data_length:
|
if data_length - (i + 1) * step - subsampling * label_delay > 0:
|
yield (i + 1) * step, data_length
|
|
|
class KaldiDiarizationDataset():
    """Chunk-index builder over a Kaldi-style data directory.

    On construction this reads the Kaldi data dir and precomputes
    ``chunk_indices``: a list of (recording id, wav path, start, end)
    tuples, where start/end are frame offsets scaled back to the
    pre-subsampling frame rate.
    """

    def __init__(
            self,
            data_dir,
            chunk_size=2000,
            context_size=0,
            frame_size=1024,
            frame_shift=256,
            subsampling=1,
            rate=16000,
            input_transform=None,
            use_last_samples=False,
            label_delay=0,
            n_speakers=None,
            ):
        self.data_dir = data_dir
        self.chunk_size = chunk_size
        self.context_size = context_size
        self.frame_size = frame_size
        self.frame_shift = frame_shift
        self.subsampling = subsampling
        self.input_transform = input_transform
        self.n_speakers = n_speakers
        self.chunk_indices = []
        self.label_delay = label_delay

        self.data = kaldi_data.KaldiData(self.data_dir)

        # Build chunk tuples: (recording, wav path, start frame, end frame).
        for recording, wav_path in self.data.wavs.items():
            # Recording length in STFT frames, then in subsampled frames.
            n_frames = int(self.data.reco2dur[recording] * rate / frame_shift)
            n_frames = int(n_frames / self.subsampling)
            chunk_args = dict(label_delay=self.label_delay,
                              subsampling=self.subsampling)
            for begin, end in _gen_frame_indices(
                    n_frames, chunk_size, chunk_size, use_last_samples,
                    **chunk_args):
                # Scale offsets back to pre-subsampling frame units.
                self.chunk_indices.append(
                    (recording, wav_path,
                     begin * self.subsampling, end * self.subsampling))
        print(len(self.chunk_indices), " chunks")


def convert(args):
    """Chunk every recording in a Kaldi data dir into .npz feature files.

    For each chunk, computes the labeled STFT features Y and speaker
    activity labels T, applies the input transform / context splicing /
    subsampling, saves a compressed ``<wav stem>_<start>_<end>.npz`` under
    a mirrored ``numpy_data`` directory, and writes one
    ``<rec>_<start>_<end> <npz path>`` line per chunk to a list file.

    Args:
        args: parsed CLI namespace with ``data_dir``, ``num_frames``,
            ``context_size``, ``frame_size``, ``frame_shift``,
            ``subsampling`` and optionally ``input_transform`` and
            ``out_wav_file``.
    """
    # BUG FIX: `out_wav_file` was referenced without ever being defined,
    # which raised NameError immediately.  Take it from args when given,
    # otherwise fall back to a list file inside the data directory.
    out_wav_file = getattr(args, 'out_wav_file', None)
    if not out_wav_file:
        out_wav_file = os.path.join(args.data_dir, 'chunk_npz.list')

    dataset = KaldiDiarizationDataset(
        data_dir=args.data_dir,
        # int(...) keeps this robust when argparse delivers the values as
        # strings (positionals declared without type= return str).
        chunk_size=int(args.num_frames),
        context_size=int(args.context_size),
        # BUG FIX: the CLI never defined input_transform, so attribute
        # access raised AttributeError; default to None when absent.
        input_transform=getattr(args, 'input_transform', None),
        frame_size=int(args.frame_size),
        frame_shift=int(args.frame_shift),
        subsampling=int(args.subsampling),
        rate=8000,  # NOTE(review): hard-coded 8 kHz despite the 16 kHz class default — confirm
        use_last_samples=True,
    )

    lines = []
    for rec, path, st, ed in dataset.chunk_indices:
        Y, T = feature.get_labeledSTFT(
            dataset.data,
            rec,
            st,
            ed,
            dataset.frame_size,
            dataset.frame_shift,
            dataset.n_speakers)
        Y = feature.transform(Y, dataset.input_transform)
        Y_spliced = feature.splice(Y, dataset.context_size)
        Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)

        # Zero-padded frame offsets keep the output names lexically sortable.
        suffix = '_{:0>7d}_{:0>7d}'.format(st, ed)

        # Mirror the wav directory under an inserted 'numpy_data' component.
        # NOTE(review): assumes the wav's parent directory is a symlink with
        # at least 4 leading path components — confirm against the data layout.
        parts = os.readlink('/'.join(path.split('/')[:-1])).split('/')
        parts = parts[:4] + ['numpy_data'] + parts[4:]
        cur_path = '/'.join(parts)
        out_path = os.path.join(
            cur_path, path.split('/')[-1].split('.')[0] + suffix + '.npz')
        # exist_ok avoids the TOCTOU race in the old exists()+makedirs() pair.
        os.makedirs(cur_path, exist_ok=True)
        np.savez(out_path, Y=Y_ss, T=T_ss)
        lines.append(rec + suffix + ' ' + out_path)

    # BUG FIX: the file handle was never closed; use a context manager.
    # Output bytes are unchanged: '\n' between entries, none trailing.
    with open(out_wav_file, 'w') as f:
        f.write('\n'.join(lines))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Dump chunked diarization features to .npz files.")
    parser.add_argument("data_dir")
    # BUG FIX: argparse positionals without type= are delivered as str,
    # which broke the arithmetic in KaldiDiarizationDataset; parse as int.
    parser.add_argument("num_frames", type=int)
    parser.add_argument("context_size", type=int)
    parser.add_argument("frame_size", type=int)
    parser.add_argument("frame_shift", type=int)
    parser.add_argument("subsampling", type=int)
    # BUG FIX: convert() reads these two; they were never defined on the
    # parser.  Optional, so existing command lines keep working.
    parser.add_argument("--input_transform", default=None)
    parser.add_argument("--out_wav_file", default=None)

    args = parser.parse_args()
    convert(args)