From e1ba6bc138b4e73875c64f35f98f3b15a0560e92 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 17 五月 2023 15:16:06 +0800
Subject: [PATCH] Merge branch 'dev_infer' of https://github.com/alibaba/FunASR into dev_infer

---
 egs/aishell2/transformer/utils/combine_cmvn_file.py |   72 ++++++++++++++++++++++++++++++++++++
 1 files changed, 72 insertions(+), 0 deletions(-)

diff --git a/egs/aishell2/transformer/utils/combine_cmvn_file.py b/egs/aishell2/transformer/utils/combine_cmvn_file.py
new file mode 100755
index 0000000..c525973
--- /dev/null
+++ b/egs/aishell2/transformer/utils/combine_cmvn_file.py
@@ -0,0 +1,72 @@
+import argparse
+import json
+import os
+
+import numpy as np
+
+
+def get_parser():
+    """Build the argument parser for the CMVN statistics combiner."""
+    parser = argparse.ArgumentParser(
+        description="combine cmvn file",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--dim",
+        default=80,
+        type=int,
+        help="feature dim",
+    )
+    # Required options have no usable default; SUPPRESS stops the help
+    # formatter from printing a bogus "(default: ...)" for them.
+    parser.add_argument(
+        "--cmvn_dir",
+        default=argparse.SUPPRESS,
+        required=True,
+        help="cmvn dir",
+    )
+    parser.add_argument(
+        "--nj",
+        default=argparse.SUPPRESS,
+        required=True,
+        type=int,
+        help="num of cmvn files",
+    )
+    parser.add_argument(
+        "--output_dir",
+        default=argparse.SUPPRESS,
+        required=True,
+        help="output dir",
+    )
+    return parser
+
+
+def main():
+    """Sum cmvn.1.json .. cmvn.<nj>.json from cmvn_dir into output_dir/cmvn.json."""
+    args = get_parser().parse_args()
+
+    total_means = np.zeros(args.dim)
+    total_vars = np.zeros(args.dim)
+    total_frames = 0
+
+    for i in range(1, args.nj + 1):
+        with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(i))) as fin:
+            cmvn_stats = json.load(fin)
+
+        total_means += np.array(cmvn_stats["mean_stats"])
+        total_vars += np.array(cmvn_stats["var_stats"])
+        total_frames += cmvn_stats["total_frames"]
+
+    cmvn_info = {
+        'mean_stats': total_means.tolist(),
+        'var_stats': total_vars.tolist(),
+        'total_frames': total_frames,
+    }
+    # Create the output directory if needed ("" means the current directory).
+    os.makedirs(args.output_dir or ".", exist_ok=True)
+    with open(os.path.join(args.output_dir, "cmvn.json"), 'w') as fout:
+        json.dump(cmvn_info, fout)
+
+
+if __name__ == '__main__':
+    main()  # run the merge only when executed as a script, not on import

--
Gitblit v1.9.1