zhifu gao
2023-05-18 97a689d65da434345a641a909f13b78e5690c86b
egs/aishell2/transformer/utils/combine_cmvn_file.py
@@ -1,6 +1,9 @@
import argparse
import json
import os
import numpy as np
def get_parser():
    parser = argparse.ArgumentParser(
@@ -8,15 +11,13 @@
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dims",
        "-d",
        "--dim",
        default=80,
        type=int,
        help="feature dims",
        help="feature dim",
    )
    parser.add_argument(
        "--cmvn-dir",
        "-c",
        "--cmvn_dir",
        default=False,
        required=True,
        type=str,
@@ -25,15 +26,13 @@
    parser.add_argument(
        "--nj",
        "-n",
        default=1,
        required=True,
        type=int,
        help="num of cmvn file",
        help="num of cmvn files",
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        "--output_dir",
        default=False,
        required=True,
        type=str,
@@ -46,14 +45,14 @@
    parser = get_parser()
    args = parser.parse_args()
    total_means = np.zeros(args.dims)
    total_vars = np.zeros(args.dims)
    total_means = np.zeros(args.dim)
    total_vars = np.zeros(args.dim)
    total_frames = 0
    cmvn_file = args.output_dir + "/cmvn.json"
    cmvn_file = os.path.join(args.output_dir, "cmvn.json")
    for i in range(1, args.nj+1):
        with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
    for i in range(1, args.nj + 1):
        with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(str(i)))) as fin:
            cmvn_stats = json.load(fin)
        total_means += np.array(cmvn_stats["mean_stats"])