From 559cc2c6e296bc80917a7408911f671dfcc2b68b Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期五, 12 五月 2023 17:25:54 +0800
Subject: [PATCH] update repo

---
 egs/aishell2/transformer/utils/apply_cmvn.py |   79 +++++++++++++++++++++++++++++++++++++++
 1 files changed, 79 insertions(+), 0 deletions(-)

diff --git a/egs/aishell2/transformer/utils/apply_cmvn.py b/egs/aishell2/transformer/utils/apply_cmvn.py
new file mode 100755
index 0000000..b5c5086
--- /dev/null
+++ b/egs/aishell2/transformer/utils/apply_cmvn.py
@@ -0,0 +1,79 @@
+from kaldiio import ReadHelper
+from kaldiio import WriteHelper
+
+import argparse
+import json
+import math
+import numpy as np
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="apply cmvn",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--ark-file",
+        "-a",
+        default=False,
+        required=True,
+        type=str,
+        help="fbank ark file",
+    )
+    parser.add_argument(
+        "--cmvn-file",
+        "-c",
+        default=False,
+        required=True,
+        type=str,
+        help="cmvn file",
+    )
+    parser.add_argument(
+        "--ark-index",
+        "-i",
+        default=1,
+        required=True,
+        type=int,
+        help="ark index",
+    )
+    parser.add_argument(
+        "--output-dir",
+        "-o",
+        default=False,
+        required=True,
+        type=str,
+        help="output dir",
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
+    scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
+    ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
+
+    with open(args.cmvn_file) as f:
+        cmvn_stats = json.load(f)
+
+    means = cmvn_stats['mean_stats']
+    vars = cmvn_stats['var_stats']
+    total_frames = cmvn_stats['total_frames']
+
+    for i in range(len(means)):
+        means[i] /= total_frames
+        vars[i] = vars[i] / total_frames - means[i] * means[i]
+        if vars[i] < 1.0e-20:
+            vars[i] = 1.0e-20
+        vars[i] = 1.0 / math.sqrt(vars[i])
+
+    with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
+        for key, mat in ark_reader:
+            mat = (mat - means) * vars
+            ark_writer(key, mat)
+
+
+if __name__ == '__main__':
+    main()

--
Gitblit v1.9.1