From e81eef957f624cbcbbcbdf682fcdd456671d0684 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 25 Apr 2023 19:12:52 +0800
Subject: [PATCH] update
---
egs/aishell/transformer/utils/compute_cmvn.sh | 9 ++--
egs/aishell/transformer/utils/compute_cmvn.py | 84 ++++++++++++++++++++++++++++-------------
2 files changed, 61 insertions(+), 32 deletions(-)
diff --git a/egs/aishell/transformer/utils/compute_cmvn.py b/egs/aishell/transformer/utils/compute_cmvn.py
index 2b96e26..c57239a 100755
--- a/egs/aishell/transformer/utils/compute_cmvn.py
+++ b/egs/aishell/transformer/utils/compute_cmvn.py
@@ -1,8 +1,10 @@
-from kaldiio import ReadHelper
-
import argparse
-import numpy as np
import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
def get_parser():
@@ -11,55 +13,83 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--dims",
- "-d",
+ "--dim", dest="dims",  # flag is --dim, but main() reads args.dims
default=80,
type=int,
- help="feature dims",
+ help="feature dimension",
)
parser.add_argument(
- "--ark-file",
- "-a",
+ "--wav_path",
default=False,
required=True,
type=str,
- help="fbank ark file",
+ help="the path of wav scps",
)
parser.add_argument(
- "--ark-index",
- "-i",
+ "--idx",
default=1,
required=True,
type=int,
- help="ark index",
- )
- parser.add_argument(
- "--output-dir",
- "-o",
- default=False,
- required=True,
- type=str,
- help="output dir",
+ help="index",
)
return parser
+
+
+def compute_fbank(wav_file,  # load one wav file and return its fbank matrix as a NumPy array
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ resample_rate=16000,
+ speed=1.0,
+ window_type="hamming"):
+ waveform, sample_rate = torchaudio.load(wav_file)
+ if resample_rate != sample_rate:  # bring the audio to resample_rate before feature extraction
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+ new_freq=resample_rate)(waveform)
+ if speed != 1.0:  # optional sox speed change; the 'rate' effect restores resample_rate
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, resample_rate,
+ [['speed', str(speed)], ['rate', str(resample_rate)]]
+ )
+
+ waveform = waveform * (1 << 15)  # scale by 2**15; presumably load() gave floats in [-1, 1] - confirm
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type=window_type,
+ sample_frequency=resample_rate)
+
+ return mat.numpy()  # main() treats axis 0 as frames and axis 1 as feature bins
def main():
parser = get_parser()
args = parser.parse_args()
- ark_file = args.ark_file + "/feats." + str(args.ark_index) + ".ark"
- cmvn_file = args.output_dir + "/cmvn." + str(args.ark_index) + ".json"
+ wav_scp_file = os.path.join(args.wav_path + "{}/wav.scp".format(args.idx))  # NOTE(review): idx is appended to wav_path with no separator - confirm split dir naming
+ cmvn_file = os.path.join(args.wav_path + "{}/cmvn.json".format(args.idx))
mean_stats = np.zeros(args.dims)
var_stats = np.zeros(args.dims)
total_frames = 0
- with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
- for key, mat in ark_reader:
- mean_stats += np.sum(mat, axis=0)
- var_stats += np.sum(np.square(mat), axis=0)
- total_frames += mat.shape[0]
+ # with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
+ # for key, mat in ark_reader:
+ # mean_stats += np.sum(mat, axis=0)
+ # var_stats += np.sum(np.square(mat), axis=0)
+ # total_frames += mat.shape[0]
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ _, wav_file = line.strip().split()  # two whitespace-separated fields per line; the first is unused
+ fbank = compute_fbank(wav_file, num_mel_bins=args.dims)
+ mean_stats += np.sum(fbank, axis=0)  # running per-bin sum over all frames
+ var_stats += np.sum(np.square(fbank), axis=0)  # running per-bin sum of squares
+ total_frames += fbank.shape[0]
cmvn_info = {
'mean_stats': list(mean_stats.tolist()),
diff --git a/egs/aishell/transformer/utils/compute_cmvn.sh b/egs/aishell/transformer/utils/compute_cmvn.sh
index b6443cf..5809b77 100755
--- a/egs/aishell/transformer/utils/compute_cmvn.sh
+++ b/egs/aishell/transformer/utils/compute_cmvn.sh
@@ -23,9 +23,8 @@
output_dir=${fbankdir}/cmvn
logdir=${fbankdir}/cmvn/log
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
- python utils/compute_cmvn.py -d ${feats_dim} -a $fbankdir/ark -i JOB -o ${output_dir} \
- || exit 1;
+ python utils/compute_cmvn.py --dim ${feats_dim} --wav_path $split_dir --idx JOB
-python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
-
-echo "$0: Succeeded compute global cmvn"
+#python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
+#
+#echo "$0: Succeeded compute global cmvn"
--
Gitblit v1.9.1