From 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 16 五月 2023 15:07:20 +0800
Subject: [PATCH] Merge pull request #518 from alibaba-damo-academy/dev_wjm2

---
 egs/aishell/transformer/utils/combine_cmvn_file.py |   27 +++++++++++++--------------
 1 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/egs/aishell/transformer/utils/combine_cmvn_file.py b/egs/aishell/transformer/utils/combine_cmvn_file.py
index b2974a4..c525973 100755
--- a/egs/aishell/transformer/utils/combine_cmvn_file.py
+++ b/egs/aishell/transformer/utils/combine_cmvn_file.py
@@ -1,6 +1,9 @@
 import argparse
 import json
+import os
+
 import numpy as np
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -8,15 +11,13 @@
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
-        "--dims",
-        "-d",
+        "--dim",
         default=80,
         type=int,
-        help="feature dims",
+        help="feature dim",
     )
     parser.add_argument(
-        "--cmvn-dir",
-        "-c",
+        "--cmvn_dir",
         default=False,
         required=True,
         type=str,
@@ -25,15 +26,13 @@
 
     parser.add_argument(
         "--nj",
-        "-n",
         default=1,
         required=True,
         type=int,
-        help="num of cmvn file",
+        help="num of cmvn files",
     )
     parser.add_argument(
-        "--output-dir",
-        "-o",
+        "--output_dir",
         default=False,
         required=True,
         type=str,
@@ -46,14 +45,14 @@
     parser = get_parser()
     args = parser.parse_args()
 
-    total_means = np.zeros(args.dims)
-    total_vars = np.zeros(args.dims)
+    total_means = np.zeros(args.dim)
+    total_vars = np.zeros(args.dim)
     total_frames = 0
 
-    cmvn_file = args.output_dir + "/cmvn.json"
+    cmvn_file = os.path.join(args.output_dir, "cmvn.json")
 
-    for i in range(1, args.nj+1):
-        with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
+    for i in range(1, args.nj + 1):
+        with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(str(i)))) as fin:
             cmvn_stats = json.load(fin)
 
         total_means += np.array(cmvn_stats["mean_stats"])

--
Gitblit v1.9.1