| | |
| | | import argparse |
| | | import os |
| | | |
| | | import numpy as np |
| | | from kaldiio import WriteHelper |
| | | |
| | | import funasr.modules.eend_ola.utils.feature as feature |
| | | import funasr.modules.eend_ola.utils.kaldi_data as kaldi_data |
| | | from funasr.modules.eend_ola.utils.kaldi_data import load_segments_rechash, load_utt2spk, load_wav_scp, load_reco2dur, \ |
| | | load_spk2utt, load_wav |
| | | |
| | | |
| | | def _count_frames(data_len, size, step): |
| | |
| | | yield (i + 1) * step, data_length |
| | | |
| | | |
class KaldiData:
    """Index files of one shard of a Kaldi-style data directory.

    On construction the ``segments``, ``utt2spk``, ``wav.scp``, ``reco2dur``
    and ``spk2utt`` files carrying the ``.<idx>`` suffix are read from
    ``data_dir`` and kept as attributes.
    """

    def __init__(self, data_dir, idx):
        """Load every per-shard index file.

        Args:
            data_dir: Directory holding the sharded Kaldi files.
            idx: Shard suffix appended to each file name (``segments.<idx>``, ...).
        """
        self.data_dir = data_dir

        def shard_path(template):
            # All index files sit directly in data_dir and share the shard suffix.
            return os.path.join(self.data_dir, template.format(idx))

        self.segments = load_segments_rechash(shard_path('segments.{}'))
        self.utt2spk = load_utt2spk(shard_path('utt2spk.{}'))
        self.wavs = load_wav_scp(shard_path('wav.scp.{}'))
        self.reco2dur = load_reco2dur(shard_path('reco2dur.{}'))
        self.spk2utt = load_spk2utt(shard_path('spk2utt.{}'))

    def load_wav(self, recid, start=0, end=None):
        """Load audio for recording ``recid``.

        ``start`` and ``end`` are forwarded unchanged to the module-level
        ``load_wav`` together with the ``wav.scp`` entry of the recording.

        Returns:
            The ``(data, rate)`` pair produced by the module-level ``load_wav``.
        """
        samples, sample_rate = load_wav(self.wavs[recid], start, end)
        return samples, sample_rate
| | | |
| | | |
| | | class KaldiDiarizationDataset(): |
| | | def __init__( |
| | | self, |
| | | data_dir, |
| | | index, |
| | | chunk_size=2000, |
| | | context_size=0, |
| | | frame_size=1024, |
| | |
| | | n_speakers=None, |
| | | ): |
| | | self.data_dir = data_dir |
| | | self.index = index |
| | | self.chunk_size = chunk_size |
| | | self.context_size = context_size |
| | | self.frame_size = frame_size |
| | |
| | | self.chunk_indices = [] |
| | | self.label_delay = label_delay |
| | | |
| | | self.data = kaldi_data.KaldiData(self.data_dir) |
| | | self.data = KaldiData(self.data_dir, index) |
| | | |
| | | # make chunk indices: filepath, start_frame, end_frame |
| | | for rec, path in self.data.wavs.items(): |
| | | data_len = int(self.data.reco2dur[rec] * rate / frame_shift) |
| | | data_len = int(data_len / self.subsampling) |
| | |
| | | |
| | | |
def convert(args):
    """Dump chunked diarization features and labels as Kaldi ark/scp pairs.

    A ``KaldiDiarizationDataset`` is built over ``args.data_dir`` for shard
    ``args.index``. For every ``(rec, path, st, ed)`` entry of
    ``dataset.chunk_indices`` the labeled STFT is computed, transformed,
    spliced with context and subsampled. The result is written to
    ``feature.ark.<index>`` / ``feature.scp.<index>`` and the flattened labels
    to ``label.ark.<index>`` / ``label.scp.<index>`` inside ``args.output_dir``.

    Args:
        args: Namespace providing ``data_dir``, ``output_dir``, ``index``,
            ``num_frames``, ``context_size``, ``frame_size``, ``frame_shift``
            and ``subsampling``.
    """
    dataset = KaldiDiarizationDataset(
        data_dir=args.data_dir,
        index=args.index,
        chunk_size=args.num_frames,
        context_size=args.context_size,
        # Passed exactly once: a repeated keyword argument is a SyntaxError.
        input_transform="logmel23_mn",
        frame_size=args.frame_size,
        frame_shift=args.frame_shift,
        subsampling=args.subsampling,
        rate=8000,
        use_last_samples=True,
    )

    # The writers open their target files directly, so the directory must exist.
    os.makedirs(args.output_dir, exist_ok=True)
    feature_ark_file = os.path.join(args.output_dir, "feature.ark.{}".format(args.index))
    feature_scp_file = os.path.join(args.output_dir, "feature.scp.{}".format(args.index))
    label_ark_file = os.path.join(args.output_dir, "label.ark.{}".format(args.index))
    label_scp_file = os.path.join(args.output_dir, "label.scp.{}".format(args.index))

    with WriteHelper('ark,scp:{},{}'.format(feature_ark_file, feature_scp_file)) as feature_writer, \
            WriteHelper('ark,scp:{},{}'.format(label_ark_file, label_scp_file)) as label_writer:
        for rec, _path, st, ed in dataset.chunk_indices:
            Y, T = feature.get_labeledSTFT(
                dataset.data,
                rec,
                st,
                ed,
                dataset.frame_size,
                dataset.frame_shift,
                dataset.n_speakers)
            Y = feature.transform(Y, dataset.input_transform)
            Y_spliced = feature.splice(Y, dataset.context_size)
            Y_ss, T_ss = feature.subsample(Y_spliced, T, dataset.subsampling)
            # Key: recording id plus the 7-digit zero-padded chunk boundaries.
            key = "{}_{:0>7d}_{:0>7d}".format(rec, st, ed)
            feature_writer(key, Y_ss)
            # Labels are stored flattened to one dimension.
            label_writer(key, T_ss.reshape(-1))
| | | |
| | | |
def get_parser():
    """Build the command-line parser of the feature dump script.

    Positional arguments, in order: ``data_dir``, ``output_dir``, ``index``,
    then the optional integers ``num_frames`` (500), ``context_size`` (7),
    ``frame_size`` (200), ``frame_shift`` (80) and ``subsampling`` (10).

    Returns:
        argparse.ArgumentParser: Parser with each argument declared once.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir")
    parser.add_argument("output_dir")
    parser.add_argument("index")
    # nargs="?" is what makes `default` take effect on a positional argument;
    # type=int because these values are used in frame arithmetic downstream.
    parser.add_argument("num_frames", type=int, nargs="?", default=500)
    parser.add_argument("context_size", type=int, nargs="?", default=7)
    parser.add_argument("frame_size", type=int, nargs="?", default=200)
    parser.add_argument("frame_shift", type=int, nargs="?", default=80)
    parser.add_argument("subsampling", type=int, nargs="?", default=10)
    return parser


if __name__ == '__main__':
    args = get_parser().parse_args()
    convert(args)