From e1ba6bc138b4e73875c64f35f98f3b15a0560e92 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 17 五月 2023 15:16:06 +0800
Subject: [PATCH] Merge branch 'dev_infer' of https://github.com/alibaba/FunASR into dev_infer
---
egs/aishell2/transformer/utils/combine_cmvn_file.py | 72 ++++++++++++++++++++++++++++++++++++
1 files changed, 72 insertions(+), 0 deletions(-)
diff --git a/egs/aishell2/transformer/utils/combine_cmvn_file.py b/egs/aishell2/transformer/utils/combine_cmvn_file.py
new file mode 100755
index 0000000..c525973
--- /dev/null
+++ b/egs/aishell2/transformer/utils/combine_cmvn_file.py
@@ -0,0 +1,72 @@
+import argparse
+import json
+import os
+
+import numpy as np
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="combine cmvn file",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--dim",
+ default=80,
+ type=int,
+ help="feature dim",
+ )
+ parser.add_argument(
+ "--cmvn_dir",
+ default=False,
+ required=True,
+ type=str,
+ help="cmvn dir",
+ )
+
+ parser.add_argument(
+ "--nj",
+ default=1,
+ required=True,
+ type=int,
+ help="num of cmvn files",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=False,
+ required=True,
+ type=str,
+ help="output dir",
+ )
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ total_means = np.zeros(args.dim)
+ total_vars = np.zeros(args.dim)
+ total_frames = 0
+
+ cmvn_file = os.path.join(args.output_dir, "cmvn.json")
+
+ for i in range(1, args.nj + 1):
+ with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(str(i)))) as fin:
+ cmvn_stats = json.load(fin)
+
+ total_means += np.array(cmvn_stats["mean_stats"])
+ total_vars += np.array(cmvn_stats["var_stats"])
+ total_frames += cmvn_stats["total_frames"]
+
+ cmvn_info = {
+ 'mean_stats': list(total_means.tolist()),
+ 'var_stats': list(total_vars.tolist()),
+ 'total_frames': total_frames
+ }
+ with open(cmvn_file, 'w') as fout:
+ fout.write(json.dumps(cmvn_info))
+
+
+if __name__ == '__main__':
+ main()
--
Gitblit v1.9.1