From d2a64f2137ac23d1951fd2fa25b6053bba6f7873 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 15 五月 2023 11:23:33 +0800
Subject: [PATCH] update repo

---
 egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py |  143 +++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 143 insertions(+), 0 deletions(-)

diff --git a/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py b/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py
new file mode 100755
index 0000000..50d18d1
--- /dev/null
+++ b/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py
@@ -0,0 +1,143 @@
+from kaldiio import ReadHelper, WriteHelper
+
+import argparse
+import numpy as np
+
+
+def build_LFR_features(inputs, m=7, n=6):
+    """Stack m consecutive frames every n frames (low frame rate features).
+
+    The first frame is repeated (m - 1) // 2 times on the left; windows that
+    run past the end are completed with copies of the last frame.
+    """
+    num_frames = inputs.shape[0]
+    num_windows = int(np.ceil(num_frames / n))
+    head = np.tile(inputs[0], ((m - 1) // 2, 1))
+    padded = np.vstack((head, inputs))
+    stacked = []
+    for start in range(0, num_windows * n, n):
+        window = [row for row in padded[start:start + m]]
+        # A short trailing window is filled up to m frames with the last one.
+        window.extend([padded[-1]] * (m - len(window)))
+        stacked.append(np.hstack(window))
+    return np.vstack(stacked)
+
+
+def build_CMVN_features(inputs, mvn_file):  # noqa
+    """Apply global CMVN, (x + shift) * scale, to `inputs` in place.
+
+    The shift and scale vectors are read from the `<LearnRateCoef>` line that
+    follows the `<AddShift>` and `<Rescale>` tags of `mvn_file`. `inputs` is a
+    2-D array whose second axis is normalised; it is modified and returned.
+    Raises ValueError if either vector is missing or shorter than that axis.
+    """
+    with open(mvn_file, 'r', encoding='utf-8') as f:
+        # Blank lines are dropped so that indexing token 0 is always safe.
+        rows = [line.split() for line in f if line.strip()]
+
+    stats = {}
+    for tag_row, value_row in zip(rows, rows[1:]):
+        if tag_row[0] in ('<AddShift>', '<Rescale>'):
+            if value_row[0] == '<LearnRateCoef>':
+                # Layout: <LearnRateCoef> 0 [ v1 ... vN ]
+                stats[tag_row[0]] = np.array([float(v) for v in value_row[3:-1]])
+
+    dims = inputs.shape[1]
+    shift = stats.get('<AddShift>', np.empty(0))[:dims]
+    scale = stats.get('<Rescale>', np.empty(0))[:dims]
+    if len(shift) != dims or len(scale) != dims:
+        raise ValueError('{} lacks {}-dim CMVN stats'.format(mvn_file, dims))
+    # Vectorised over all rows; "unsafe" casting keeps the semantics of the
+    # former element-wise assignment for arrays that are not float64.
+    np.add(inputs, shift, out=inputs, casting='unsafe')
+    np.multiply(inputs, scale, out=inputs, casting='unsafe')
+    return inputs
+
+
+def _str2bool(value):
+    """Parse a command-line boolean such as "true"/"false" or "1"/"0"."""
+    text = value.strip().lower()
+    if text in ('true', 't', 'yes', 'y', '1', 'on'):
+        return True
+    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
+        return False
+    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)
+
+
+def get_parser():
+    """Build the command-line parser for the LFR + CMVN feature script.
+
+    `--lfr` is parsed as a real boolean, so `--lfr false` disables frame
+    stacking instead of being treated as a non-empty (truthy) string.
+    """
+    parser = argparse.ArgumentParser(
+        description="apply low_frame_rate and cmvn",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    # Required options carry no default: argparse never uses it, and the
+    # defaults formatter would print a misleading "(default: False)".
+    parser.add_argument(
+        "--ark-file", "-a",
+        required=True, type=str,
+        help="fbank ark file",
+    )
+    # type=str would make every non-empty value, even "False", truthy.
+    parser.add_argument(
+        "--lfr", "-f",
+        default=True, type=_str2bool,
+        help="low frame rate",
+    )
+    parser.add_argument(
+        "--lfr-m", "-m",
+        default=7, type=int,
+        help="number of frames to stack",
+    )
+    parser.add_argument(
+        "--lfr-n", "-n",
+        default=6, type=int,
+        help="number of frames to skip",
+    )
+    parser.add_argument(
+        "--cmvn-file", "-c",
+        required=True, type=str,
+        help="global cmvn file",
+    )
+    parser.add_argument(
+        "--ark-index", "-i",
+        required=True, type=int,
+        help="ark index",
+    )
+    parser.add_argument(
+        "--output-dir", "-o",
+        required=True, type=str,
+        help="output dir",
+    )
+    return parser
+
+
+def main():
+    """Apply optional LFR stacking and global CMVN to every matrix of an ark.
+
+    Writes feats.<index>.ark/.scp and len.<index> ("key rows,dims" per line)
+    under the output directory.
+    """
+    args = get_parser().parse_args()
+    prefix = args.output_dir + "/feats." + str(args.ark_index)
+    shape_file = args.output_dir + "/len." + str(args.ark_index)
+    wspecifier = 'ark,scp:{0}.ark,{0}.scp'.format(prefix)
+    # Context managers make sure every output is flushed and closed on exit.
+    with WriteHelper(wspecifier) as ark_writer, \
+            open(shape_file, 'w') as shape_writer, \
+            ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
+        for key, mat in ark_reader:
+            if args.lfr:
+                mat = build_LFR_features(mat, args.lfr_m, args.lfr_n)
+            cmvn = build_CMVN_features(mat, args.cmvn_file)
+            lens, dims = cmvn.shape[:2]
+            shape_writer.write(key + " " + str(lens) + "," + str(dims) + '\n')
+            ark_writer(key, cmvn)
+
+
+if __name__ == '__main__':  # run main() only when executed as a script
+    main()
+

--
Gitblit v1.9.1