import argparse
import json
import logging
import os
from contextlib import ExitStack
from itertools import zip_longest

import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from kaldiio import WriteHelper

logger = logging.getLogger(__name__)


def compute_fbank(wav_file,
                  num_mel_bins=80,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0,
                  resample_rate=16000,
                  speed=1.0):
    """Load an audio file and compute its log-mel filterbank features.

    Args:
        wav_file: Path of the audio file, passed to ``torchaudio.load``.
        num_mel_bins: Number of mel bins, forwarded to ``kaldi.fbank``.
        frame_length: Frame length forwarded to ``kaldi.fbank``
            (milliseconds, per the torchaudio documentation).
        frame_shift: Frame shift forwarded to ``kaldi.fbank``
            (milliseconds, per the torchaudio documentation).
        dither: Dithering constant forwarded to ``kaldi.fbank``.
        resample_rate: Sample rate the audio is converted to before feature
            extraction; also passed as ``sample_frequency``.
        speed: Speed-perturbation factor. ``1.0`` leaves the audio untouched.

    Returns:
        The feature matrix as a NumPy array. ``main`` treats axis 0 as the
        frame axis and axis 1 as the feature dimension.
    """
    waveform, sample_rate = torchaudio.load(wav_file)

    if resample_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate,
                                                   new_freq=resample_rate)
        waveform = resampler(waveform)

    if speed != 1.0:
        # The sox "speed" effect changes the effective sample rate, so the
        # "rate" effect is chained to bring it back to ``resample_rate``.
        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
            waveform,
            resample_rate,
            [['speed', str(speed)], ['rate', str(resample_rate)]],
        )

    # Scale by 2**15 before feature extraction. Presumably this maps the
    # normalized floats from torchaudio.load to the 16-bit integer range
    # that Kaldi-style features expect -- confirm against the training recipe.
    waveform = waveform * (1 << 15)

    mat = kaldi.fbank(waveform,
                      num_mel_bins=num_mel_bins,
                      frame_length=frame_length,
                      frame_shift=frame_shift,
                      dither=dither,
                      energy_floor=0.0,
                      window_type='hamming',
                      sample_frequency=resample_rate)
    return mat.numpy()


def get_parser():
    """Build the command-line argument parser for the feature extractor.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--wav-lists``,
        ``--text-files``, ``--dims``, ``--sample-frequency``,
        ``--speed-perturb``, ``--ark-index`` and ``--output-dir``.
    """
    parser = argparse.ArgumentParser(
        description="computer features",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # NOTE: the ``default`` of every ``required=True`` option below is never
    # used for parsing; it is kept so the generated --help text is unchanged.
    parser.add_argument(
        "--wav-lists",
        "-w",
        default=False,
        required=True,
        type=str,
        help="input wav lists",
    )
    parser.add_argument(
        "--text-files",
        "-t",
        default=False,
        required=True,
        type=str,
        help="input text files",
    )
    parser.add_argument(
        "--dims",
        "-d",
        default=80,
        type=int,
        help="feature dims",
    )
    parser.add_argument(
        "--sample-frequency",
        "-s",
        default=16000,
        type=int,
        help="sample frequency",
    )
    parser.add_argument(
        "--speed-perturb",
        "-p",
        default="1.0",
        type=str,
        help="speed perturb",
    )
    parser.add_argument(
        "--ark-index",
        "-a",
        default=1,
        required=True,
        type=int,
        help="ark index",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        default=False,
        required=True,
        type=str,
        help="output dir",
    )
    return parser


def _iter_utterances(wav_list_path, text_path):
    """Yield ``(wav_id, wav_path, tokens)`` for each paired line of the lists.

    Line ``i`` of the wav list (``<id> <path>``) is paired with line ``i`` of
    the text file (``<id> <token> ...``). Iteration stops at the end of the
    shorter file, with a warning when the lengths differ.

    Raises:
        ValueError: If a wav line has fewer than two fields or a text line
            is empty.
    """
    with open(wav_list_path, 'r', encoding='utf-8') as wavfile, \
            open(text_path, 'r', encoding='utf-8') as textfile:
        pairs = zip_longest(wavfile, textfile)
        for lineno, (wav_line, text_line) in enumerate(pairs, start=1):
            if wav_line is None or text_line is None:
                # Keep the historical "stop at the shorter list" behaviour,
                # but make the dropped tail visible instead of silent.
                logger.warning(
                    "%s and %s differ in length; stopping at line %d",
                    wav_list_path, text_path, lineno)
                return

            wav_fields = wav_line.strip().split()
            text_fields = text_line.strip().split()
            if len(wav_fields) < 2:
                raise ValueError(
                    "{}:{}: expected '<id> <wav path>', got {!r}".format(
                        wav_list_path, lineno, wav_line))
            if not text_fields:
                raise ValueError(
                    "{}:{}: expected '<id> <text...>', got an empty line".format(
                        text_path, lineno))

            wav_id, wav_path = wav_fields[0], wav_fields[1]
            text_id, tokens = text_fields[0], text_fields[1:]
            if text_id != wav_id:
                # Outputs are keyed by the wav id; a differing text id usually
                # means the two lists are not aligned.
                logger.warning(
                    "line %d: wav id %r does not match text id %r",
                    lineno, wav_id, text_id)

            yield wav_id, wav_path, tokens


def main():
    """Extract fbank features for every listed utterance and write them out.

    For each speed in ``--speed-perturb`` (comma separated) every utterance
    of ``--wav-lists`` / ``--text-files`` is processed and written to, under
    ``--output-dir``:

    * ``ark/feats.<idx>.ark`` / ``ark/feats.<idx>.scp`` -- feature matrices
    * ``ark/len.<idx>`` -- ``<utt> <num_frames>,<feat_dim>`` per line
    * ``txt/text.<idx>.txt`` -- ``<utt> <tokens...>`` per line
    * ``txt/len.<idx>`` -- ``<utt> <num_tokens>`` per line

    Utterances at a speed other than 1.0 get the suffix ``_sp<speed>``.
    """
    parser = get_parser()
    args = parser.parse_args()

    ark_dir = args.output_dir + "/ark"
    txt_dir = args.output_dir + "/txt"
    os.makedirs(ark_dir, exist_ok=True)
    os.makedirs(txt_dir, exist_ok=True)

    index = str(args.ark_index)
    ark_file = ark_dir + "/feats." + index + ".ark"
    scp_file = ark_dir + "/feats." + index + ".scp"
    text_file = txt_dir + "/text." + index + ".txt"
    feats_shape_file = ark_dir + "/len." + index
    text_shape_file = txt_dir + "/len." + index

    # Strip each token so "0.9, 1.0" does not leak a space into utterance ids.
    speed_perturb_list = [s.strip() for s in args.speed_perturb.split(',')]

    # ExitStack closes (and therefore flushes) every writer even when feature
    # extraction fails part-way through.
    with ExitStack() as stack:
        ark_writer = stack.enter_context(
            WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file)))
        # Inputs are read as UTF-8, so write UTF-8 regardless of the locale.
        text_writer = stack.enter_context(
            open(text_file, 'w', encoding='utf-8'))
        feats_shape_writer = stack.enter_context(
            open(feats_shape_file, 'w', encoding='utf-8'))
        text_shape_writer = stack.enter_context(
            open(text_shape_file, 'w', encoding='utf-8'))

        for speed in speed_perturb_list:
            speed_value = float(speed)
            # Compare numerically so "1", "1.0" and "1.00" all count as the
            # unperturbed pass and keep the plain utterance id.
            suffix = "" if speed_value == 1.0 else "_sp" + speed

            for wav_id, wav_file, txt in _iter_utterances(args.wav_lists,
                                                          args.text_files):
                fbank = compute_fbank(wav_file,
                                      num_mel_bins=args.dims,
                                      resample_rate=args.sample_frequency,
                                      speed=speed_value)
                feats_lens, feats_dims = fbank.shape[0], fbank.shape[1]
                wav_id_sp = wav_id + suffix

                feats_shape_writer.write(
                    "{} {},{}\n".format(wav_id_sp, feats_lens, feats_dims))
                text_shape_writer.write(
                    "{} {}\n".format(wav_id_sp, len(txt)))
                text_writer.write(wav_id_sp + " " + " ".join(txt) + '\n')
                ark_writer(wav_id_sp, fbank)


if __name__ == '__main__':
    main()