From 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 16 五月 2023 15:07:20 +0800
Subject: [PATCH] Merge pull request #518 from alibaba-damo-academy/dev_wjm2
---
egs/aishell/transformer/utils/combine_cmvn_file.py | 27 +++++++++++++--------------
1 files changed, 13 insertions(+), 14 deletions(-)
diff --git a/egs/aishell/transformer/utils/combine_cmvn_file.py b/egs/aishell/transformer/utils/combine_cmvn_file.py
index b2974a4..c525973 100755
--- a/egs/aishell/transformer/utils/combine_cmvn_file.py
+++ b/egs/aishell/transformer/utils/combine_cmvn_file.py
@@ -1,6 +1,9 @@
import argparse
import json
+import os
+
import numpy as np
+
def get_parser():
parser = argparse.ArgumentParser(
@@ -8,15 +11,13 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--dims",
- "-d",
+ "--dim",
default=80,
type=int,
- help="feature dims",
+ help="feature dim",
)
parser.add_argument(
- "--cmvn-dir",
- "-c",
+ "--cmvn_dir",
default=False,
required=True,
type=str,
@@ -25,15 +26,13 @@
parser.add_argument(
"--nj",
- "-n",
default=1,
required=True,
type=int,
- help="num of cmvn file",
+ help="num of cmvn files",
)
parser.add_argument(
- "--output-dir",
- "-o",
+ "--output_dir",
default=False,
required=True,
type=str,
@@ -46,14 +45,14 @@
parser = get_parser()
args = parser.parse_args()
- total_means = np.zeros(args.dims)
- total_vars = np.zeros(args.dims)
+ total_means = np.zeros(args.dim)
+ total_vars = np.zeros(args.dim)
total_frames = 0
- cmvn_file = args.output_dir + "/cmvn.json"
+ cmvn_file = os.path.join(args.output_dir, "cmvn.json")
- for i in range(1, args.nj+1):
- with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
+ for i in range(1, args.nj + 1):
+ with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(str(i)))) as fin:
cmvn_stats = json.load(fin)
total_means += np.array(cmvn_stats["mean_stats"])
--
Gitblit v1.9.1