From 559cc2c6e296bc80917a7408911f671dfcc2b68b Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期五, 12 五月 2023 17:25:54 +0800
Subject: [PATCH] update repo
---
egs/aishell2/transformer/utils/apply_lfr_and_cmvn.sh | 38
egs/aishell2/transformer/utils/subset_data_dir_tr_cv.sh | 30
egs/aishell2/transformer/utils/fix_data.sh | 35
egs/aishell2/transformer/utils/compute_cmvn.sh | 34
egs/aishell2/transformer/utils/compute_cmvn.py | 104 +
egs/aishell2/transformer/utils/error_rate_zh | 370 ++++++
egs/aishell2/transformer/utils/gen_ark_list.sh | 22
egs/aishell2/transformer/utils/apply_cmvn.sh | 29
egs/aishell2/transformer/utils/filter_scp.pl | 87 +
egs/aishell2/transformer/utils/apply_cmvn.py | 79 +
egs/aishell2/transformer/utils/compute_fbank.sh | 54
egs/aishell2/transformer/utils/compute_wer.py | 157 ++
egs/aishell2/transformer/utils/fix_data_feat.sh | 52
egs/aishell2/transformer/utils/text2token.py | 135 ++
egs/aishell2/transformer/utils/download_model.py | 20
egs/aishell2/transformer/utils/textnorm_zh.py | 834 +++++++++++++
egs/aishell2/transformer/utils/shuffle_list.pl | 44
egs/aishell2/transformer/utils/print_args.py | 45
egs/aishell2/transformer/utils/parse_options.sh | 97 +
egs/aishell2/transformer/utils/__init__.py | 0
egs/aishell2/transformer/utils/combine_cmvn_file.py | 72 +
egs/aishell2/transformer/utils/proce_text.py | 31
egs/aishell2/transformer/utils/extract_embeds.py | 47
egs/aishell2/transformer/utils/cmvn_converter.py | 51
egs/aishell2/transformer/utils/split_scp.pl | 246 ++++
egs/aishell2/transformer/utils/compute_fbank.py | 171 ++
egs/aishell2/transformer/utils/text_tokenize.py | 106 +
/dev/null | 1
egs/aishell2/transformer/utils/text_tokenize.sh | 35
egs/aishell2/transformer/utils/run.pl | 356 +++++
egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py | 143 ++
egs/aishell2/transformer/utils/proc_conf_oss.py | 35
egs/aishell2/transformer/utils/split_data.py | 60
33 files changed, 3,619 insertions(+), 1 deletions(-)
diff --git a/egs/aishell2/transformer/utils b/egs/aishell2/transformer/utils
deleted file mode 120000
index fe070dd..0000000
--- a/egs/aishell2/transformer/utils
+++ /dev/null
@@ -1 +0,0 @@
-../../aishell/transformer/utils
\ No newline at end of file
diff --git a/egs/aishell2/transformer/utils/__init__.py b/egs/aishell2/transformer/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/egs/aishell2/transformer/utils/__init__.py
diff --git a/egs/aishell2/transformer/utils/apply_cmvn.py b/egs/aishell2/transformer/utils/apply_cmvn.py
new file mode 100755
index 0000000..b5c5086
--- /dev/null
+++ b/egs/aishell2/transformer/utils/apply_cmvn.py
@@ -0,0 +1,79 @@
+from kaldiio import ReadHelper
+from kaldiio import WriteHelper
+
+import argparse
+import json
+import math
+import numpy as np
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="apply cmvn",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--ark-file",
+ "-a",
+ default=False,
+ required=True,
+ type=str,
+ help="fbank ark file",
+ )
+ parser.add_argument(
+ "--cmvn-file",
+ "-c",
+ default=False,
+ required=True,
+ type=str,
+ help="cmvn file",
+ )
+ parser.add_argument(
+ "--ark-index",
+ "-i",
+ default=1,
+ required=True,
+ type=int,
+ help="ark index",
+ )
+ parser.add_argument(
+ "--output-dir",
+ "-o",
+ default=False,
+ required=True,
+ type=str,
+ help="output dir",
+ )
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
+ scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
+ ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
+
+ with open(args.cmvn_file) as f:
+ cmvn_stats = json.load(f)
+
+ means = cmvn_stats['mean_stats']
+ vars = cmvn_stats['var_stats']
+ total_frames = cmvn_stats['total_frames']
+
+ for i in range(len(means)):
+ means[i] /= total_frames
+ vars[i] = vars[i] / total_frames - means[i] * means[i]
+ if vars[i] < 1.0e-20:
+ vars[i] = 1.0e-20
+ vars[i] = 1.0 / math.sqrt(vars[i])
+
+ with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
+ for key, mat in ark_reader:
+ mat = (mat - means) * vars
+ ark_writer(key, mat)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/egs/aishell2/transformer/utils/apply_cmvn.sh b/egs/aishell2/transformer/utils/apply_cmvn.sh
new file mode 100755
index 0000000..f8fd1d1
--- /dev/null
+++ b/egs/aishell2/transformer/utils/apply_cmvn.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+# Begin configuration section.
+nj=32
+cmd=./utils/run.pl
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+fbankdir=$1
+cmvn_file=$2
+logdir=$3
+output_dir=$4
+
+dump_dir=${output_dir}/ark; mkdir -p ${dump_dir}
+mkdir -p ${logdir}
+
+$cmd JOB=1:$nj $logdir/apply_cmvn.JOB.log \
+ python utils/apply_cmvn.py -a $fbankdir/ark/feats.JOB.ark \
+ -c $cmvn_file -i JOB -o ${dump_dir} \
+ || exit 1;
+
+for n in $(seq $nj); do
+ cat ${dump_dir}/feats.$n.scp || exit 1
+done > ${output_dir}/feats.scp || exit 1
+
+echo "$0: Succeeded apply cmvn"
diff --git a/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py b/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py
new file mode 100755
index 0000000..50d18d1
--- /dev/null
+++ b/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.py
@@ -0,0 +1,143 @@
+from kaldiio import ReadHelper, WriteHelper
+
+import argparse
+import numpy as np
+
+
+def build_LFR_features(inputs, m=7, n=6):
+ LFR_inputs = []
+ T = inputs.shape[0]
+ T_lfr = int(np.ceil(T / n))
+ left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
+ inputs = np.vstack((left_padding, inputs))
+ T = T + (m - 1) // 2
+ for i in range(T_lfr):
+ if m <= T - i * n:
+ LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
+ else:
+ num_padding = m - (T - i * n)
+ frame = np.hstack(inputs[i * n:])
+ for _ in range(num_padding):
+ frame = np.hstack((frame, inputs[-1]))
+ LFR_inputs.append(frame)
+ return np.vstack(LFR_inputs)
+
+
+def build_CMVN_features(inputs, mvn_file): # noqa
+ with open(mvn_file, 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+
+ add_shift_list = []
+ rescale_list = []
+ for i in range(len(lines)):
+ line_item = lines[i].split()
+ if line_item[0] == '<AddShift>':
+ line_item = lines[i + 1].split()
+ if line_item[0] == '<LearnRateCoef>':
+ add_shift_line = line_item[3:(len(line_item) - 1)]
+ add_shift_list = list(add_shift_line)
+ continue
+ elif line_item[0] == '<Rescale>':
+ line_item = lines[i + 1].split()
+ if line_item[0] == '<LearnRateCoef>':
+ rescale_line = line_item[3:(len(line_item) - 1)]
+ rescale_list = list(rescale_line)
+ continue
+
+ for j in range(inputs.shape[0]):
+ for k in range(inputs.shape[1]):
+ add_shift_value = add_shift_list[k]
+ rescale_value = rescale_list[k]
+ inputs[j, k] = float(inputs[j, k]) + float(add_shift_value)
+ inputs[j, k] = float(inputs[j, k]) * float(rescale_value)
+
+ return inputs
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="apply low_frame_rate and cmvn",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--ark-file",
+ "-a",
+ default=False,
+ required=True,
+ type=str,
+ help="fbank ark file",
+ )
+ parser.add_argument(
+ "--lfr",
+ "-f",
+ default=True,
+ type=str,
+ help="low frame rate",
+ )
+ parser.add_argument(
+ "--lfr-m",
+ "-m",
+ default=7,
+ type=int,
+ help="number of frames to stack",
+ )
+ parser.add_argument(
+ "--lfr-n",
+ "-n",
+ default=6,
+ type=int,
+ help="number of frames to skip",
+ )
+ parser.add_argument(
+ "--cmvn-file",
+ "-c",
+ default=False,
+ required=True,
+ type=str,
+ help="global cmvn file",
+ )
+ parser.add_argument(
+ "--ark-index",
+ "-i",
+ default=1,
+ required=True,
+ type=int,
+ help="ark index",
+ )
+ parser.add_argument(
+ "--output-dir",
+ "-o",
+ default=False,
+ required=True,
+ type=str,
+ help="output dir",
+ )
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ dump_ark_file = args.output_dir + "/feats." + str(args.ark_index) + ".ark"
+ dump_scp_file = args.output_dir + "/feats." + str(args.ark_index) + ".scp"
+ shape_file = args.output_dir + "/len." + str(args.ark_index)
+ ark_writer = WriteHelper('ark,scp:{},{}'.format(dump_ark_file, dump_scp_file))
+
+ shape_writer = open(shape_file, 'w')
+ with ReadHelper('ark:{}'.format(args.ark_file)) as ark_reader:
+ for key, mat in ark_reader:
+ if args.lfr:
+ lfr = build_LFR_features(mat, args.lfr_m, args.lfr_n)
+ else:
+ lfr = mat
+ cmvn = build_CMVN_features(lfr, args.cmvn_file)
+ dims = cmvn.shape[1]
+ lens = cmvn.shape[0]
+ shape_writer.write(key + " " + str(lens) + "," + str(dims) + '\n')
+ ark_writer(key, cmvn)
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.sh b/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.sh
new file mode 100755
index 0000000..3119fdb
--- /dev/null
+++ b/egs/aishell2/transformer/utils/apply_lfr_and_cmvn.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+
+
+# Begin configuration section.
+nj=32
+cmd=utils/run.pl
+
+# feature configuration
+lfr=True
+lfr_m=7
+lfr_n=6
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+fbankdir=$1
+cmvn_file=$2
+logdir=$3
+output_dir=$4
+
+dump_dir=${output_dir}/ark; mkdir -p ${dump_dir}
+mkdir -p ${logdir}
+
+$cmd JOB=1:$nj $logdir/apply_lfr_and_cmvn.JOB.log \
+ python utils/apply_lfr_and_cmvn.py -a $fbankdir/ark/feats.JOB.ark \
+ -f $lfr -m $lfr_m -n $lfr_n -c $cmvn_file -i JOB -o ${dump_dir} \
+ || exit 1;
+
+for n in $(seq $nj); do
+ cat ${dump_dir}/feats.$n.scp || exit 1
+done > ${output_dir}/feats.scp || exit 1
+
+for n in $(seq $nj); do
+ cat ${dump_dir}/len.$n || exit 1
+done > ${output_dir}/speech_shape || exit 1
+
+echo "$0: Succeeded apply low frame rate and cmvn"
diff --git a/egs/aishell2/transformer/utils/cmvn_converter.py b/egs/aishell2/transformer/utils/cmvn_converter.py
new file mode 100644
index 0000000..d405d12
--- /dev/null
+++ b/egs/aishell2/transformer/utils/cmvn_converter.py
@@ -0,0 +1,51 @@
+import argparse
+import json
+import numpy as np
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="cmvn converter",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--cmvn_json",
+ default=False,
+ required=True,
+ type=str,
+ help="cmvn json file",
+ )
+ parser.add_argument(
+ "--am_mvn",
+ default=False,
+ required=True,
+ type=str,
+ help="am mvn file",
+ )
+ return parser
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ with open(args.cmvn_json, "r") as fin:
+ cmvn_dict = json.load(fin)
+
+ mean_stats = np.array(cmvn_dict["mean_stats"])
+ var_stats = np.array(cmvn_dict["var_stats"])
+ total_frame = np.array(cmvn_dict["total_frames"])
+
+ mean = -1.0 * mean_stats / total_frame
+ var = 1.0 / np.sqrt(var_stats / total_frame - mean * mean)
+ dims = mean.shape[0]
+ with open(args.am_mvn, 'w') as fout:
+ fout.write("<Nnet>" + "\n" + "<Splice> " + str(dims) + " " + str(dims) + '\n' + "[ 0 ]" + "\n" + "<AddShift> " + str(dims) + " " + str(dims) + "\n")
+ mean_str = str(list(mean)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+ fout.write("<LearnRateCoef> 0 " + mean_str + '\n')
+ fout.write("<Rescale> " + str(dims) + " " + str(dims) + '\n')
+ var_str = str(list(var)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+ fout.write("<LearnRateCoef> 0 " + var_str + '\n')
+ fout.write("</Nnet>" + '\n')
+
+if __name__ == '__main__':
+ main()
diff --git a/egs/aishell2/transformer/utils/combine_cmvn_file.py b/egs/aishell2/transformer/utils/combine_cmvn_file.py
new file mode 100755
index 0000000..c525973
--- /dev/null
+++ b/egs/aishell2/transformer/utils/combine_cmvn_file.py
@@ -0,0 +1,72 @@
+import argparse
+import json
+import os
+
+import numpy as np
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="combine cmvn file",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--dim",
+ default=80,
+ type=int,
+ help="feature dim",
+ )
+ parser.add_argument(
+ "--cmvn_dir",
+ default=False,
+ required=True,
+ type=str,
+ help="cmvn dir",
+ )
+
+ parser.add_argument(
+ "--nj",
+ default=1,
+ required=True,
+ type=int,
+ help="num of cmvn files",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=False,
+ required=True,
+ type=str,
+ help="output dir",
+ )
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ total_means = np.zeros(args.dim)
+ total_vars = np.zeros(args.dim)
+ total_frames = 0
+
+ cmvn_file = os.path.join(args.output_dir, "cmvn.json")
+
+ for i in range(1, args.nj + 1):
+ with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(str(i)))) as fin:
+ cmvn_stats = json.load(fin)
+
+ total_means += np.array(cmvn_stats["mean_stats"])
+ total_vars += np.array(cmvn_stats["var_stats"])
+ total_frames += cmvn_stats["total_frames"]
+
+ cmvn_info = {
+ 'mean_stats': list(total_means.tolist()),
+ 'var_stats': list(total_vars.tolist()),
+ 'total_frames': total_frames
+ }
+ with open(cmvn_file, 'w') as fout:
+ fout.write(json.dumps(cmvn_info))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/egs/aishell2/transformer/utils/compute_cmvn.py b/egs/aishell2/transformer/utils/compute_cmvn.py
new file mode 100755
index 0000000..949cc08
--- /dev/null
+++ b/egs/aishell2/transformer/utils/compute_cmvn.py
@@ -0,0 +1,104 @@
+import argparse
+import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="computer global cmvn",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--dim",
+ default=80,
+ type=int,
+ help="feature dimension",
+ )
+ parser.add_argument(
+ "--wav_path",
+ default=False,
+ required=True,
+ type=str,
+ help="the path of wav scps",
+ )
+ parser.add_argument(
+ "--idx",
+ default=1,
+ required=True,
+ type=int,
+ help="index",
+ )
+ return parser
+
+
+def compute_fbank(wav_file,
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ resample_rate=16000,
+ speed=1.0,
+ window_type="hamming"):
+ waveform, sample_rate = torchaudio.load(wav_file)
+ if resample_rate != sample_rate:
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+ new_freq=resample_rate)(waveform)
+ if speed != 1.0:
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, resample_rate,
+ [['speed', str(speed)], ['rate', str(resample_rate)]]
+ )
+
+ waveform = waveform * (1 << 15)
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type=window_type,
+ sample_frequency=resample_rate)
+
+ return mat.numpy()
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+ cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
+
+ mean_stats = np.zeros(args.dim)
+ var_stats = np.zeros(args.dim)
+ total_frames = 0
+
+ # with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
+ # for key, mat in ark_reader:
+ # mean_stats += np.sum(mat, axis=0)
+ # var_stats += np.sum(np.square(mat), axis=0)
+ # total_frames += mat.shape[0]
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ _, wav_file = line.strip().split()
+ fbank = compute_fbank(wav_file, num_mel_bins=args.dim)
+ mean_stats += np.sum(fbank, axis=0)
+ var_stats += np.sum(np.square(fbank), axis=0)
+ total_frames += fbank.shape[0]
+
+ cmvn_info = {
+ 'mean_stats': list(mean_stats.tolist()),
+ 'var_stats': list(var_stats.tolist()),
+ 'total_frames': total_frames
+ }
+ with open(cmvn_file, 'w') as fout:
+ fout.write(json.dumps(cmvn_info))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/egs/aishell2/transformer/utils/compute_cmvn.sh b/egs/aishell2/transformer/utils/compute_cmvn.sh
new file mode 100755
index 0000000..7663df9
--- /dev/null
+++ b/egs/aishell2/transformer/utils/compute_cmvn.sh
@@ -0,0 +1,34 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+# Begin configuration section.
+nj=32
+cmd=./utils/run.pl
+feats_dim=80
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+fbankdir=$1
+
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+ split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp $split_scps || exit 1;
+
+logdir=${fbankdir}/cmvn/log
+$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
+ python utils/compute_cmvn.py \
+ --dim ${feats_dim} \
+ --wav_path $split_dir \
+ --idx JOB
+
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/cmvn.mvn
+
+echo "$0: Succeeded compute global cmvn"
diff --git a/egs/aishell2/transformer/utils/compute_fbank.py b/egs/aishell2/transformer/utils/compute_fbank.py
new file mode 100755
index 0000000..9c3904f
--- /dev/null
+++ b/egs/aishell2/transformer/utils/compute_fbank.py
@@ -0,0 +1,171 @@
+from kaldiio import WriteHelper
+
+import argparse
+import numpy as np
+import json
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+
+
+def compute_fbank(wav_file,
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ resample_rate=16000,
+ speed=1.0,
+ window_type="hamming"):
+
+ waveform, sample_rate = torchaudio.load(wav_file)
+ if resample_rate != sample_rate:
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+ new_freq=resample_rate)(waveform)
+ if speed != 1.0:
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, resample_rate,
+ [['speed', str(speed)], ['rate', str(resample_rate)]]
+ )
+
+ waveform = waveform * (1 << 15)
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type=window_type,
+ sample_frequency=resample_rate)
+
+ return mat.numpy()
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="computer features",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--wav-lists",
+ "-w",
+ default=False,
+ required=True,
+ type=str,
+ help="input wav lists",
+ )
+ parser.add_argument(
+ "--text-files",
+ "-t",
+ default=False,
+ required=True,
+ type=str,
+ help="input text files",
+ )
+ parser.add_argument(
+ "--dims",
+ "-d",
+ default=80,
+ type=int,
+ help="feature dims",
+ )
+ parser.add_argument(
+ "--max-lengths",
+ "-m",
+ default=1500,
+ type=int,
+ help="max frame numbers",
+ )
+ parser.add_argument(
+ "--sample-frequency",
+ "-s",
+ default=16000,
+ type=int,
+ help="sample frequency",
+ )
+ parser.add_argument(
+ "--speed-perturb",
+ "-p",
+ default="1.0",
+ type=str,
+ help="speed perturb",
+ )
+ parser.add_argument(
+ "--ark-index",
+ "-a",
+ default=1,
+ required=True,
+ type=int,
+ help="ark index",
+ )
+ parser.add_argument(
+ "--output-dir",
+ "-o",
+ default=False,
+ required=True,
+ type=str,
+ help="output dir",
+ )
+ parser.add_argument(
+ "--window-type",
+ default="hamming",
+ required=False,
+ type=str,
+ help="window type"
+ )
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ ark_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".ark"
+ scp_file = args.output_dir + "/ark/feats." + str(args.ark_index) + ".scp"
+ text_file = args.output_dir + "/txt/text." + str(args.ark_index) + ".txt"
+ feats_shape_file = args.output_dir + "/ark/len." + str(args.ark_index)
+ text_shape_file = args.output_dir + "/txt/len." + str(args.ark_index)
+
+ ark_writer = WriteHelper('ark,scp:{},{}'.format(ark_file, scp_file))
+ text_writer = open(text_file, 'w')
+ feats_shape_writer = open(feats_shape_file, 'w')
+ text_shape_writer = open(text_shape_file, 'w')
+
+ speed_perturb_list = args.speed_perturb.split(',')
+
+ for speed in speed_perturb_list:
+ with open(args.wav_lists, 'r', encoding='utf-8') as wavfile:
+ with open(args.text_files, 'r', encoding='utf-8') as textfile:
+ for wav, text in zip(wavfile, textfile):
+ s_w = wav.strip().split()
+ wav_id = s_w[0]
+ wav_file = s_w[1]
+
+ s_t = text.strip().split()
+ text_id = s_t[0]
+ txt = s_t[1:]
+ fbank = compute_fbank(wav_file,
+ num_mel_bins=args.dims,
+ resample_rate=args.sample_frequency,
+ speed=float(speed),
+ window_type=args.window_type
+ )
+ feats_dims = fbank.shape[1]
+ feats_lens = fbank.shape[0]
+ if feats_lens >= args.max_lengths:
+ continue
+ txt_lens = len(txt)
+ if speed == "1.0":
+ wav_id_sp = wav_id
+ else:
+ wav_id_sp = wav_id + "_sp" + speed
+
+ feats_shape_writer.write(wav_id_sp + " " + str(feats_lens) + "," + str(feats_dims) + '\n')
+ text_shape_writer.write(wav_id_sp + " " + str(txt_lens) + '\n')
+
+ text_writer.write(wav_id_sp + " " + " ".join(txt) + '\n')
+ ark_writer(wav_id_sp, fbank)
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/egs/aishell2/transformer/utils/compute_fbank.sh b/egs/aishell2/transformer/utils/compute_fbank.sh
new file mode 100755
index 0000000..8704b31
--- /dev/null
+++ b/egs/aishell2/transformer/utils/compute_fbank.sh
@@ -0,0 +1,54 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+# Begin configuration section.
+nj=32
+cmd=./utils/run.pl
+
+# feature configuration
+feats_dim=80
+sample_frequency=16000
+speed_perturb="1.0"
+window_type="hamming"
+max_lengths=1500
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+data=$1
+logdir=$2
+fbankdir=$3
+
+[ ! -f $data/wav.scp ] && echo "$0: no such file $data/wav.scp" && exit 1;
+[ ! -f $data/text ] && echo "$0: no such file $data/text" && exit 1;
+
+python utils/split_data.py $data $data $nj
+
+ark_dir=${fbankdir}/ark; mkdir -p ${ark_dir}
+text_dir=${fbankdir}/txt; mkdir -p ${text_dir}
+mkdir -p ${logdir}
+
+$cmd JOB=1:$nj $logdir/make_fbank.JOB.log \
+ python utils/compute_fbank.py -w $data/split${nj}/JOB/wav.scp -t $data/split${nj}/JOB/text \
+ -d $feats_dim -s $sample_frequency -m ${max_lengths} -p ${speed_perturb} -a JOB -o ${fbankdir} \
+ --window-type ${window_type} \
+ || exit 1;
+
+for n in $(seq $nj); do
+ cat ${ark_dir}/feats.$n.scp || exit 1
+done > $fbankdir/feats.scp || exit 1
+
+for n in $(seq $nj); do
+ cat ${text_dir}/text.$n.txt || exit 1
+done > $fbankdir/text || exit 1
+
+for n in $(seq $nj); do
+ cat ${ark_dir}/len.$n || exit 1
+done > $fbankdir/speech_shape || exit 1
+
+for n in $(seq $nj); do
+ cat ${text_dir}/len.$n || exit 1
+done > $fbankdir/text_shape || exit 1
+
+echo "$0: Succeeded compute FBANK features"
diff --git a/egs/aishell2/transformer/utils/compute_wer.py b/egs/aishell2/transformer/utils/compute_wer.py
new file mode 100755
index 0000000..26a9f49
--- /dev/null
+++ b/egs/aishell2/transformer/utils/compute_wer.py
@@ -0,0 +1,157 @@
+import os
+import numpy as np
+import sys
+
+def compute_wer(ref_file,
+ hyp_file,
+ cer_detail_file):
+ rst = {
+ 'Wrd': 0,
+ 'Corr': 0,
+ 'Ins': 0,
+ 'Del': 0,
+ 'Sub': 0,
+ 'Snt': 0,
+ 'Err': 0.0,
+ 'S.Err': 0.0,
+ 'wrong_words': 0,
+ 'wrong_sentences': 0
+ }
+
+ hyp_dict = {}
+ ref_dict = {}
+ with open(hyp_file, 'r') as hyp_reader:
+ for line in hyp_reader:
+ key = line.strip().split()[0]
+ value = line.strip().split()[1:]
+ hyp_dict[key] = value
+ with open(ref_file, 'r') as ref_reader:
+ for line in ref_reader:
+ key = line.strip().split()[0]
+ value = line.strip().split()[1:]
+ ref_dict[key] = value
+
+ cer_detail_writer = open(cer_detail_file, 'w')
+ for hyp_key in hyp_dict:
+ if hyp_key in ref_dict:
+ out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
+ rst['Wrd'] += out_item['nwords']
+ rst['Corr'] += out_item['cor']
+ rst['wrong_words'] += out_item['wrong']
+ rst['Ins'] += out_item['ins']
+ rst['Del'] += out_item['del']
+ rst['Sub'] += out_item['sub']
+ rst['Snt'] += 1
+ if out_item['wrong'] > 0:
+ rst['wrong_sentences'] += 1
+ cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
+ cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
+ cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
+
+ if rst['Wrd'] > 0:
+ rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
+ if rst['Snt'] > 0:
+ rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
+
+ cer_detail_writer.write('\n')
+ cer_detail_writer.write("%WER " + str(rst['Err']) + " [ " + str(rst['wrong_words'])+ " / " + str(rst['Wrd']) +
+ ", " + str(rst['Ins']) + " ins, " + str(rst['Del']) + " del, " + str(rst['Sub']) + " sub ]" + '\n')
+ cer_detail_writer.write("%SER " + str(rst['S.Err']) + " [ " + str(rst['wrong_sentences']) + " / " + str(rst['Snt']) + " ]" + '\n')
+ cer_detail_writer.write("Scored " + str(len(hyp_dict)) + " sentences, " + str(len(hyp_dict) - rst['Snt']) + " not present in hyp." + '\n')
+
+
+def compute_wer_by_line(hyp,
+ ref):
+ hyp = list(map(lambda x: x.lower(), hyp))
+ ref = list(map(lambda x: x.lower(), ref))
+
+ len_hyp = len(hyp)
+ len_ref = len(ref)
+
+ cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
+
+ ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
+
+ for i in range(len_hyp + 1):
+ cost_matrix[i][0] = i
+ for j in range(len_ref + 1):
+ cost_matrix[0][j] = j
+
+ for i in range(1, len_hyp + 1):
+ for j in range(1, len_ref + 1):
+ if hyp[i - 1] == ref[j - 1]:
+ cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
+ else:
+ substitution = cost_matrix[i - 1][j - 1] + 1
+ insertion = cost_matrix[i - 1][j] + 1
+ deletion = cost_matrix[i][j - 1] + 1
+
+ compare_val = [substitution, insertion, deletion]
+
+ min_val = min(compare_val)
+ operation_idx = compare_val.index(min_val) + 1
+ cost_matrix[i][j] = min_val
+ ops_matrix[i][j] = operation_idx
+
+ match_idx = []
+ i = len_hyp
+ j = len_ref
+ rst = {
+ 'nwords': len_ref,
+ 'cor': 0,
+ 'wrong': 0,
+ 'ins': 0,
+ 'del': 0,
+ 'sub': 0
+ }
+ while i >= 0 or j >= 0:
+ i_idx = max(0, i)
+ j_idx = max(0, j)
+
+ if ops_matrix[i_idx][j_idx] == 0: # correct
+ if i - 1 >= 0 and j - 1 >= 0:
+ match_idx.append((j - 1, i - 1))
+ rst['cor'] += 1
+
+ i -= 1
+ j -= 1
+
+ elif ops_matrix[i_idx][j_idx] == 2: # insert
+ i -= 1
+ rst['ins'] += 1
+
+ elif ops_matrix[i_idx][j_idx] == 3: # delete
+ j -= 1
+ rst['del'] += 1
+
+ elif ops_matrix[i_idx][j_idx] == 1: # substitute
+ i -= 1
+ j -= 1
+ rst['sub'] += 1
+
+ if i < 0 and j >= 0:
+ rst['del'] += 1
+ elif j < 0 and i >= 0:
+ rst['ins'] += 1
+
+ match_idx.reverse()
+ wrong_cnt = cost_matrix[len_hyp][len_ref]
+ rst['wrong'] = wrong_cnt
+
+ return rst
+
+def print_cer_detail(rst):
+ return ("(" + "nwords=" + str(rst['nwords']) + ",cor=" + str(rst['cor'])
+ + ",ins=" + str(rst['ins']) + ",del=" + str(rst['del']) + ",sub="
+ + str(rst['sub']) + ") corr:" + '{:.2%}'.format(rst['cor']/rst['nwords'])
+ + ",cer:" + '{:.2%}'.format(rst['wrong']/rst['nwords']))
+
+if __name__ == '__main__':
+ if len(sys.argv) != 4:
+ print("usage : python compute-wer.py test.ref test.hyp test.wer")
+ sys.exit(0)
+
+ ref_file = sys.argv[1]
+ hyp_file = sys.argv[2]
+ cer_detail_file = sys.argv[3]
+ compute_wer(ref_file, hyp_file, cer_detail_file)
diff --git a/egs/aishell2/transformer/utils/download_model.py b/egs/aishell2/transformer/utils/download_model.py
new file mode 100755
index 0000000..70ea179
--- /dev/null
+++ b/egs/aishell2/transformer/utils/download_model.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+import argparse
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description="download model configs",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument("--model_name",
+ type=str,
+ default="damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch",
+ help="model name in ModelScope")
+ args = parser.parse_args()
+
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=args.model_name)
diff --git a/egs/aishell2/transformer/utils/error_rate_zh b/egs/aishell2/transformer/utils/error_rate_zh
new file mode 100755
index 0000000..6871a07
--- /dev/null
+++ b/egs/aishell2/transformer/utils/error_rate_zh
@@ -0,0 +1,370 @@
+#!/usr/bin/env python3
+# coding=utf8
+
+# Copyright 2021 Jiayu DU
+
+import sys
+import argparse
+import json
+import logging
+logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')
+
+DEBUG = None
+
+def GetEditType(ref_token, hyp_token):
+ if ref_token == None and hyp_token != None:
+ return 'I'
+ elif ref_token != None and hyp_token == None:
+ return 'D'
+ elif ref_token == hyp_token:
+ return 'C'
+ elif ref_token != hyp_token:
+ return 'S'
+ else:
+ raise RuntimeError
+
+class AlignmentArc:
+ def __init__(self, src, dst, ref, hyp):
+ self.src = src
+ self.dst = dst
+ self.ref = ref
+ self.hyp = hyp
+ self.edit_type = GetEditType(ref, hyp)
+
+def similarity_score_function(ref_token, hyp_token):
+ return 0 if (ref_token == hyp_token) else -1.0
+
+def insertion_score_function(token):
+ return -1.0
+
+def deletion_score_function(token):
+ return -1.0
+
+def EditDistance(
+ ref,
+ hyp,
+ similarity_score_function = similarity_score_function,
+ insertion_score_function = insertion_score_function,
+ deletion_score_function = deletion_score_function):
+ assert(len(ref) != 0)
+ class DPState:
+ def __init__(self):
+ self.score = -float('inf')
+ # backpointer
+ self.prev_r = None
+ self.prev_h = None
+
+ def print_search_grid(S, R, H, fstream):
+ print(file=fstream)
+ for r in range(R):
+ for h in range(H):
+ print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
+ print(file=fstream)
+
+ R = len(ref) + 1
+ H = len(hyp) + 1
+
+ # Construct DP search space, a (R x H) grid
+ S = [ [] for r in range(R) ]
+ for r in range(R):
+ S[r] = [ DPState() for x in range(H) ]
+
+ # initialize DP search grid origin, S(r = 0, h = 0)
+ S[0][0].score = 0.0
+ S[0][0].prev_r = None
+ S[0][0].prev_h = None
+
+ # initialize REF axis
+ for r in range(1, R):
+ S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
+ S[r][0].prev_r = r-1
+ S[r][0].prev_h = 0
+
+ # initialize HYP axis
+ for h in range(1, H):
+ S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
+ S[0][h].prev_r = 0
+ S[0][h].prev_h = h-1
+
+ best_score = S[0][0].score
+ best_state = (0, 0)
+
+ for r in range(1, R):
+ for h in range(1, H):
+ sub_or_cor_score = similarity_score_function(ref[r-1], hyp[h-1])
+ new_score = S[r-1][h-1].score + sub_or_cor_score
+ if new_score >= S[r][h].score:
+ S[r][h].score = new_score
+ S[r][h].prev_r = r-1
+ S[r][h].prev_h = h-1
+
+ del_score = deletion_score_function(ref[r-1])
+ new_score = S[r-1][h].score + del_score
+ if new_score >= S[r][h].score:
+ S[r][h].score = new_score
+ S[r][h].prev_r = r - 1
+ S[r][h].prev_h = h
+
+ ins_score = insertion_score_function(hyp[h-1])
+ new_score = S[r][h-1].score + ins_score
+ if new_score >= S[r][h].score:
+ S[r][h].score = new_score
+ S[r][h].prev_r = r
+ S[r][h].prev_h = h-1
+
+ best_score = S[R-1][H-1].score
+ best_state = (R-1, H-1)
+
+ if DEBUG:
+ print_search_grid(S, R, H, sys.stderr)
+
+ # Backtracing best alignment path, i.e. a list of arcs
+ # arc = (src, dst, ref, hyp, edit_type)
+ # src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
+ best_path = []
+ r, h = best_state[0], best_state[1]
+ prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
+ score = S[r][h].score
+ # loop invariant:
+ # 1. (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path
+ # 2. score is the value of point(r, h) on DP search grid
+ while prev_r != None or prev_h != None:
+ src = (prev_r, prev_h)
+ dst = (r, h)
+ if (r == prev_r + 1 and h == prev_h + 1): # Substitution or correct
+ arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
+ elif (r == prev_r + 1 and h == prev_h): # Deletion
+ arc = AlignmentArc(src, dst, ref[prev_r], None)
+ elif (r == prev_r and h == prev_h + 1): # Insertion
+ arc = AlignmentArc(src, dst, None, hyp[prev_h])
+ else:
+ raise RuntimeError
+ best_path.append(arc)
+ r, h = prev_r, prev_h
+ prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
+ score = S[r][h].score
+
+ best_path.reverse()
+ return (best_path, best_score)
+
+def PrettyPrintAlignment(alignment, stream = sys.stderr):
+ def get_token_str(token):
+ if token == None:
+ return "*"
+ return token
+
+ def is_double_width_char(ch):
+ if (ch >= '\u4e00') and (ch <= '\u9fa5'): # codepoint ranges for Chinese chars
+ return True
+ # TODO: support other double-width-char language such as Japanese, Korean
+ else:
+ return False
+
+ def display_width(token_str):
+ m = 0
+ for c in token_str:
+ if is_double_width_char(c):
+ m += 2
+ else:
+ m += 1
+ return m
+
+ R = ' REF : '
+ H = ' HYP : '
+ E = ' EDIT : '
+ for arc in alignment:
+ r = get_token_str(arc.ref)
+ h = get_token_str(arc.hyp)
+ e = arc.edit_type if arc.edit_type != 'C' else ''
+
+ nr, nh, ne = display_width(r), display_width(h), display_width(e)
+ n = max(nr, nh, ne) + 1
+
+ R += r + ' ' * (n-nr)
+ H += h + ' ' * (n-nh)
+ E += e + ' ' * (n-ne)
+
+ print(R, file=stream)
+ print(H, file=stream)
+ print(E, file=stream)
+
+def CountEdits(alignment):
+ c, s, i, d = 0, 0, 0, 0
+ for arc in alignment:
+ if arc.edit_type == 'C':
+ c += 1
+ elif arc.edit_type == 'S':
+ s += 1
+ elif arc.edit_type == 'I':
+ i += 1
+ elif arc.edit_type == 'D':
+ d += 1
+ else:
+ raise RuntimeError
+ return (c, s, i, d)
+
+def ComputeTokenErrorRate(c, s, i, d):
+ return 100.0 * (s + d + i) / (s + d + c)
+
+def ComputeSentenceErrorRate(num_err_utts, num_utts):
+ assert(num_utts != 0)
+ return 100.0 * num_err_utts / num_utts
+
+
+class EvaluationResult:
+ def __init__(self):
+ self.num_ref_utts = 0
+ self.num_hyp_utts = 0
+ self.num_eval_utts = 0 # seen in both ref & hyp
+ self.num_hyp_without_ref = 0
+
+ self.C = 0
+ self.S = 0
+ self.I = 0
+ self.D = 0
+ self.token_error_rate = 0.0
+
+ self.num_utts_with_error = 0
+ self.sentence_error_rate = 0.0
+
+ def to_json(self):
+ return json.dumps(self.__dict__)
+
+ def to_kaldi(self):
+ info = (
+ F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
+ F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
+ )
+ return info
+
+ def to_sclite(self):
+ return "TODO"
+
+ def to_espnet(self):
+ return "TODO"
+
+ def to_summary(self):
+ #return json.dumps(self.__dict__, indent=4)
+ summary = (
+ '==================== Overall Statistics ====================\n'
+ F'num_ref_utts: {self.num_ref_utts}\n'
+ F'num_hyp_utts: {self.num_hyp_utts}\n'
+ F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
+ F'num_eval_utts: {self.num_eval_utts}\n'
+ F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
+ F'token_error_rate: {self.token_error_rate:.2f}%\n'
+ F'token_stats:\n'
+ F' - tokens:{self.C + self.S + self.D:>7}\n'
+ F' - edits: {self.S + self.I + self.D:>7}\n'
+ F' - cor: {self.C:>7}\n'
+ F' - sub: {self.S:>7}\n'
+ F' - ins: {self.I:>7}\n'
+ F' - del: {self.D:>7}\n'
+ '============================================================\n'
+ )
+ return summary
+
+
+class Utterance:
+ def __init__(self, uid, text):
+ self.uid = uid
+ self.text = text
+
+
+def LoadUtterances(filepath, format):
+ utts = {}
+ if format == 'text': # utt_id word1 word2 ...
+ with open(filepath, 'r', encoding='utf8') as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ cols = line.split(maxsplit=1)
+ assert(len(cols) == 2 or len(cols) == 1)
+ uid = cols[0]
+ text = cols[1] if len(cols) == 2 else ''
+ if utts.get(uid) != None:
+ raise RuntimeError(F'Found duplicated utterence id {uid}')
+ utts[uid] = Utterance(uid, text)
+ else:
+ raise RuntimeError(F'Unsupported text format {format}')
+ return utts
+
+
+def tokenize_text(text, tokenizer):
+ if tokenizer == 'whitespace':
+ return text.split()
+ elif tokenizer == 'char':
+ return [ ch for ch in ''.join(text.split()) ]
+ else:
+ raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ # optional
+ parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
+ parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
+ parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
+ # required
+ parser.add_argument('--ref', type=str, required=True, help='input reference file')
+ parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')
+
+ parser.add_argument('result_file', type=str)
+ args = parser.parse_args()
+ logging.info(args)
+
+ ref_utts = LoadUtterances(args.ref, args.ref_format)
+ hyp_utts = LoadUtterances(args.hyp, args.hyp_format)
+
+ r = EvaluationResult()
+
+ # check valid utterances in hyp that have matched non-empty reference
+ eval_utts = []
+ r.num_hyp_without_ref = 0
+ for uid in sorted(hyp_utts.keys()):
+ if uid in ref_utts.keys(): # TODO: efficiency
+ if ref_utts[uid].text.strip(): # non-empty reference
+ eval_utts.append(uid)
+ else:
+ logging.warn(F'Found {uid} with empty reference, skipping...')
+ else:
+ logging.warn(F'Found {uid} without reference, skipping...')
+ r.num_hyp_without_ref += 1
+
+ r.num_hyp_utts = len(hyp_utts)
+ r.num_ref_utts = len(ref_utts)
+ r.num_eval_utts = len(eval_utts)
+
+ with open(args.result_file, 'w+', encoding='utf8') as fo:
+ for uid in eval_utts:
+ ref = ref_utts[uid]
+ hyp = hyp_utts[uid]
+
+ alignment, score = EditDistance(
+ tokenize_text(ref.text, args.tokenizer),
+ tokenize_text(hyp.text, args.tokenizer)
+ )
+
+ c, s, i, d = CountEdits(alignment)
+ utt_ter = ComputeTokenErrorRate(c, s, i, d)
+
+ # utt-level evaluation result
+ print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
+ PrettyPrintAlignment(alignment, fo)
+
+ r.C += c
+ r.S += s
+ r.I += i
+ r.D += d
+
+ if utt_ter > 0:
+ r.num_utts_with_error += 1
+
+ # corpus level evaluation result
+ r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
+ r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)
+
+ print(r.to_summary(), file=fo)
+
+ print(r.to_json())
+ print(r.to_kaldi())
diff --git a/egs/aishell2/transformer/utils/extract_embeds.py b/egs/aishell2/transformer/utils/extract_embeds.py
new file mode 100755
index 0000000..7b817d8
--- /dev/null
+++ b/egs/aishell2/transformer/utils/extract_embeds.py
@@ -0,0 +1,47 @@
+from transformers import AutoTokenizer, AutoModel, pipeline
+import numpy as np
+import sys
+import os
+import torch
+from kaldiio import WriteHelper
+import re
+text_file_json = sys.argv[1]
+out_ark = sys.argv[2]
+out_scp = sys.argv[3]
+out_shape = sys.argv[4]
+device = int(sys.argv[5])
+model_path = sys.argv[6]
+
+model = AutoModel.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
+
+with open(text_file_json, 'r') as f:
+ js = f.readlines()
+
+
+f_shape = open(out_shape, "w")
+with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
+ with torch.no_grad():
+ for idx, line in enumerate(js):
+ id, tokens = line.strip().split(" ", 1)
+ tokens = re.sub(" ", "", tokens.strip())
+ tokens = ' '.join([j for j in tokens])
+ token_num = len(tokens.split(" "))
+ outputs = extractor(tokens)
+ outputs = np.array(outputs)
+ embeds = outputs[0, 1:-1, :]
+
+ token_num_embeds, dim = embeds.shape
+ if token_num == token_num_embeds:
+ writer(id, embeds)
+ shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
+ f_shape.write(shape_line)
+ else:
+ print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
+
+
+
+f_shape.close()
+
+
diff --git a/egs/aishell2/transformer/utils/filter_scp.pl b/egs/aishell2/transformer/utils/filter_scp.pl
new file mode 100755
index 0000000..003530d
--- /dev/null
+++ b/egs/aishell2/transformer/utils/filter_scp.pl
@@ -0,0 +1,87 @@
+#!/usr/bin/env perl
+# Copyright 2010-2012 Microsoft Corporation
+# Johns Hopkins University (author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This script takes a list of utterance-ids or any file whose first field
+# of each line is an utterance-id, and filters an scp
+# file (or any file whose "n-th" field is an utterance id), printing
+# out only those lines whose "n-th" field is in id_list. The index of
+# the "n-th" field is 1, by default, but can be changed by using
+# the -f <n> switch
+
+$exclude = 0;
+$field = 1;
+$shifted = 0;
+
+do {
+ $shifted=0;
+ if ($ARGV[0] eq "--exclude") {
+ $exclude = 1;
+ shift @ARGV;
+ $shifted=1;
+ }
+ if ($ARGV[0] eq "-f") {
+ $field = $ARGV[1];
+ shift @ARGV; shift @ARGV;
+ $shifted=1
+ }
+} while ($shifted);
+
+if(@ARGV < 1 || @ARGV > 2) {
+ die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
+ "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
+ "Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
+ "only the lines that were *not* in id_list.\n" .
+ "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
+ "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
+ "-f option, add 1 to the argument.\n" .
+ "See also: scripts/filter_scp.pl .\n";
+}
+
+
+$idlist = shift @ARGV;
+open(F, "<$idlist") || die "Could not open id-list file $idlist";
+while(<F>) {
+ @A = split;
+ @A>=1 || die "Invalid id-list file line $_";
+ $seen{$A[0]} = 1;
+}
+
+if ($field == 1) { # Treat this as special case, since it is common.
+ while(<>) {
+ $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
+ # $1 is what we filter on.
+ if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
+ print $_;
+ }
+ }
+} else {
+ while(<>) {
+ @A = split;
+ @A > 0 || die "Invalid scp file line $_";
+ @A >= $field || die "Invalid scp file line $_";
+ if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
+ print $_;
+ }
+ }
+}
+
+# tests:
+# the following should print "foo 1"
+# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo)
+# the following should print "bar 2".
+# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2)
diff --git a/egs/aishell2/transformer/utils/fix_data.sh b/egs/aishell2/transformer/utils/fix_data.sh
new file mode 100755
index 0000000..b1a2bb8
--- /dev/null
+++ b/egs/aishell2/transformer/utils/fix_data.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+
+echo "$0 $@"
+data_dir=$1
+
+if [ ! -f ${data_dir}/wav.scp ]; then
+ echo "$0: wav.scp is not found"
+ exit 1;
+fi
+
+if [ ! -f ${data_dir}/text ]; then
+ echo "$0: text is not found"
+ exit 1;
+fi
+
+
+
+mkdir -p ${data_dir}/.backup
+
+awk '{print $1}' ${data_dir}/wav.scp > ${data_dir}/.backup/wav_id
+awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
+
+sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
+
+cp ${data_dir}/wav.scp ${data_dir}/.backup/wav.scp
+cp ${data_dir}/text ${data_dir}/.backup/text
+
+mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak
+mv ${data_dir}/text ${data_dir}/text.bak
+
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak | sort -k1,1 -u > ${data_dir}/wav.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
+
+rm ${data_dir}/wav.scp.bak
+rm ${data_dir}/text.bak
diff --git a/egs/aishell2/transformer/utils/fix_data_feat.sh b/egs/aishell2/transformer/utils/fix_data_feat.sh
new file mode 100755
index 0000000..84eea36
--- /dev/null
+++ b/egs/aishell2/transformer/utils/fix_data_feat.sh
@@ -0,0 +1,52 @@
+#!/usr/bin/env bash
+
+echo "$0 $@"
+data_dir=$1
+
+if [ ! -f ${data_dir}/feats.scp ]; then
+ echo "$0: feats.scp is not found"
+ exit 1;
+fi
+
+if [ ! -f ${data_dir}/text ]; then
+ echo "$0: text is not found"
+ exit 1;
+fi
+
+if [ ! -f ${data_dir}/speech_shape ]; then
+ echo "$0: feature lengths is not found"
+ exit 1;
+fi
+
+if [ ! -f ${data_dir}/text_shape ]; then
+ echo "$0: text lengths is not found"
+ exit 1;
+fi
+
+mkdir -p ${data_dir}/.backup
+
+awk '{print $1}' ${data_dir}/feats.scp > ${data_dir}/.backup/wav_id
+awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
+
+sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
+
+cp ${data_dir}/feats.scp ${data_dir}/.backup/feats.scp
+cp ${data_dir}/text ${data_dir}/.backup/text
+cp ${data_dir}/speech_shape ${data_dir}/.backup/speech_shape
+cp ${data_dir}/text_shape ${data_dir}/.backup/text_shape
+
+mv ${data_dir}/feats.scp ${data_dir}/feats.scp.bak
+mv ${data_dir}/text ${data_dir}/text.bak
+mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak
+mv ${data_dir}/text_shape ${data_dir}/text_shape.bak
+
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak | sort -k1,1 -u > ${data_dir}/feats.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak | sort -k1,1 -u > ${data_dir}/speech_shape
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak | sort -k1,1 -u > ${data_dir}/text_shape
+
+rm ${data_dir}/feats.scp.bak
+rm ${data_dir}/text.bak
+rm ${data_dir}/speech_shape.bak
+rm ${data_dir}/text_shape.bak
+
diff --git a/egs/aishell2/transformer/utils/gen_ark_list.sh b/egs/aishell2/transformer/utils/gen_ark_list.sh
new file mode 100755
index 0000000..aebf356
--- /dev/null
+++ b/egs/aishell2/transformer/utils/gen_ark_list.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+
+
+# Begin configuration section.
+nj=32
+cmd=./utils/run.pl
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+ark_dir=$1
+txt_dir=$2
+output_dir=$3
+
+[ ! -d ${ark_dir}/ark ] && echo "$0: ark data is required" && exit 1;
+[ ! -d ${txt_dir}/txt ] && echo "$0: txt data is required" && exit 1;
+
+for n in $(seq $nj); do
+ echo "${ark_dir}/ark/feats.$n.ark ${txt_dir}/txt/text.$n.txt" || exit 1
+done > ${output_dir}/ark_txt.scp || exit 1
+
diff --git a/egs/aishell2/transformer/utils/parse_options.sh b/egs/aishell2/transformer/utils/parse_options.sh
new file mode 100755
index 0000000..71fb9e5
--- /dev/null
+++ b/egs/aishell2/transformer/utils/parse_options.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
+# Arnab Ghoshal, Karel Vesely
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parse command-line options.
+# To be sourced by another script (as in ". parse_options.sh").
+# Option format is: --option-name arg
+# and shell variable "option_name" gets set to value "arg."
+# The exception is --help, which takes no arguments, but prints the
+# $help_message variable (if defined).
+
+
+###
+### The --config file options have lower priority to command line
+### options, so we need to import them first...
+###
+
+# Now import all the configs specified by command-line, in left-to-right order
+for ((argpos=1; argpos<$#; argpos++)); do
+ if [ "${!argpos}" == "--config" ]; then
+ argpos_plus1=$((argpos+1))
+ config=${!argpos_plus1}
+ [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
+ . $config # source the config file.
+ fi
+done
+
+
+###
+### Now we process the command line options
+###
+while true; do
+ [ -z "${1:-}" ] && break; # break if there are no arguments
+ case "$1" in
+ # If the enclosing script is called with --help option, print the help
+ # message and exit. Scripts should put help messages in $help_message
+ --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
+ else printf "$help_message\n" 1>&2 ; fi;
+ exit 0 ;;
+ --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
+ exit 1 ;;
+ # If the first command-line argument begins with "--" (e.g. --foo-bar),
+ # then work out the variable name as $name, which will equal "foo_bar".
+ --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
+ # Next we test whether the variable in question is undefned-- if so it's
+ # an invalid option and we die. Note: $0 evaluates to the name of the
+ # enclosing script.
+ # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
+ # is undefined. We then have to wrap this test inside "eval" because
+ # foo_bar is itself inside a variable ($name).
+ eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+
+ oldval="`eval echo \\$$name`";
+ # Work out whether we seem to be expecting a Boolean argument.
+ if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
+ was_bool=true;
+ else
+ was_bool=false;
+ fi
+
+ # Set the variable to the right value-- the escaped quotes make it work if
+ # the option had spaces, like --cmd "queue.pl -sync y"
+ eval $name=\"$2\";
+
+ # Check that Boolean-valued arguments are really Boolean.
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+ exit 1;
+ fi
+ shift 2;
+ ;;
+ *) break;
+ esac
+done
+
+
+# Check for an empty argument to the --cmd option, which can easily occur as a
+# result of scripting errors.
+[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
+
+
+true; # so this script returns exit code 0.
diff --git a/egs/aishell2/transformer/utils/print_args.py b/egs/aishell2/transformer/utils/print_args.py
new file mode 100755
index 0000000..b0c61e5
--- /dev/null
+++ b/egs/aishell2/transformer/utils/print_args.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+import sys
+
+
+def get_commandline_args(no_executable=True):
+ extra_chars = [
+ " ",
+ ";",
+ "&",
+ "|",
+ "<",
+ ">",
+ "?",
+ "*",
+ "~",
+ "`",
+ '"',
+ "'",
+ "\\",
+ "{",
+ "}",
+ "(",
+ ")",
+ ]
+
+ # Escape the extra characters for shell
+ argv = [
+ arg.replace("'", "'\\''")
+ if all(char not in arg for char in extra_chars)
+ else "'" + arg.replace("'", "'\\''") + "'"
+ for arg in sys.argv
+ ]
+
+ if no_executable:
+ return " ".join(argv[1:])
+ else:
+ return sys.executable + " " + " ".join(argv)
+
+
+def main():
+ print(get_commandline_args())
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell2/transformer/utils/proc_conf_oss.py b/egs/aishell2/transformer/utils/proc_conf_oss.py
new file mode 100755
index 0000000..c4a90c5
--- /dev/null
+++ b/egs/aishell2/transformer/utils/proc_conf_oss.py
@@ -0,0 +1,35 @@
+from pathlib import Path
+
+import torch
+import yaml
+
+
+class NoAliasSafeDumper(yaml.SafeDumper):
+ # Disable anchor/alias in yaml because looks ugly
+ def ignore_aliases(self, data):
+ return True
+
+
+def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
+ """Safe-dump in yaml with no anchor/alias"""
+ return yaml.dump(
+ data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
+ )
+
+
+def gen_conf(file, out_dir):
+ conf = torch.load(file)["config"]
+ conf["oss_bucket"] = "null"
+ print(conf)
+ output_dir = Path(out_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
+ yaml_no_alias_safe_dump(conf, f, indent=4, sort_keys=False)
+
+
+if __name__ == "__main__":
+ import sys
+
+ in_f = sys.argv[1]
+ out_f = sys.argv[2]
+ gen_conf(in_f, out_f)
diff --git a/egs/aishell2/transformer/utils/proce_text.py b/egs/aishell2/transformer/utils/proce_text.py
new file mode 100755
index 0000000..9e517a4
--- /dev/null
+++ b/egs/aishell2/transformer/utils/proce_text.py
@@ -0,0 +1,31 @@
+
+import sys
+import re
+
+in_f = sys.argv[1]
+out_f = sys.argv[2]
+
+
+with open(in_f, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+
+with open(out_f, "w", encoding="utf-8") as f:
+ for line in lines:
+ outs = line.strip().split(" ", 1)
+ if len(outs) == 2:
+ idx, text = outs
+ text = re.sub("</s>", "", text)
+ text = re.sub("<s>", "", text)
+ text = re.sub("@@", "", text)
+ text = re.sub("@", "", text)
+ text = re.sub("<unk>", "", text)
+ text = re.sub(" ", "", text)
+ text = text.lower()
+ else:
+ idx = outs[0]
+ text = " "
+
+ text = [x for x in text]
+ text = " ".join(text)
+ out = "{} {}\n".format(idx, text)
+ f.write(out)
diff --git a/egs/aishell2/transformer/utils/run.pl b/egs/aishell2/transformer/utils/run.pl
new file mode 100755
index 0000000..483f95b
--- /dev/null
+++ b/egs/aishell2/transformer/utils/run.pl
@@ -0,0 +1,356 @@
+#!/usr/bin/env perl
+use warnings; #sed replacement for -w perl parameter
+# In general, doing
+# run.pl some.log a b c is like running the command a b c in
+# the bash shell, and putting the standard error and output into some.log.
+# To run parallel jobs (backgrounded on the host machine), you can do (e.g.)
+# run.pl JOB=1:4 some.JOB.log a b c JOB is like running the command a b c JOB
+# and putting it in some.JOB.log, for each one. [Note: JOB can be any identifier].
+# If any of the jobs fails, this script will fail.
+
+# A typical example is:
+# run.pl some.log my-prog "--opt=foo bar" foo \| other-prog baz
+# and run.pl will run something like:
+# ( my-prog '--opt=foo bar' foo | other-prog baz ) >& some.log
+#
+# Basically it takes the command-line arguments, quotes them
+# as necessary to preserve spaces, and evaluates them with bash.
+# In addition it puts the command line at the top of the log, and
+# the start and end times of the command at the beginning and end.
+# The reason why this is useful is so that we can create a different
+# version of this program that uses a queueing system instead.
+
+#use Data::Dumper;
+
+@ARGV < 2 && die "usage: run.pl log-file command-line arguments...";
+
+#print STDERR "COMMAND-LINE: " . Dumper(\@ARGV) . "\n";
+$job_pick = 'all';
+$max_jobs_run = -1;
+$jobstart = 1;
+$jobend = 1;
+$ignored_opts = ""; # These will be ignored.
+
+# First parse an option like JOB=1:4, and any
+# options that would normally be given to
+# queue.pl, which we will just discard.
+
+for (my $x = 1; $x <= 2; $x++) { # This for-loop is to
+ # allow the JOB=1:n option to be interleaved with the
+ # options to qsub.
+ while (@ARGV >= 2 && $ARGV[0] =~ m:^-:) {
+ # parse any options that would normally go to qsub, but which will be ignored here.
+ my $switch = shift @ARGV;
+ if ($switch eq "-V") {
+ $ignored_opts .= "-V ";
+ } elsif ($switch eq "--max-jobs-run" || $switch eq "-tc") {
+ # we do support the option --max-jobs-run n, and its GridEngine form -tc n.
+ # if the command appears multiple times uses the smallest option.
+ if ( $max_jobs_run <= 0 ) {
+ $max_jobs_run = shift @ARGV;
+ } else {
+ my $new_constraint = shift @ARGV;
+ if ( ($new_constraint < $max_jobs_run) ) {
+ $max_jobs_run = $new_constraint;
+ }
+ }
+
+ if (! ($max_jobs_run > 0)) {
+ die "run.pl: invalid option --max-jobs-run $max_jobs_run";
+ }
+ } else {
+ my $argument = shift @ARGV;
+ if ($argument =~ m/^--/) {
+ print STDERR "run.pl: WARNING: suspicious argument '$argument' to $switch; starts with '-'\n";
+ }
+ if ($switch eq "-sync" && $argument =~ m/^[yY]/) {
+ $ignored_opts .= "-sync "; # Note: in the
+ # corresponding code in queue.pl it says instead, just "$sync = 1;".
+ } elsif ($switch eq "-pe") { # e.g. -pe smp 5
+ my $argument2 = shift @ARGV;
+ $ignored_opts .= "$switch $argument $argument2 ";
+ } elsif ($switch eq "--gpu") {
+ $using_gpu = $argument;
+ } elsif ($switch eq "--pick") {
+ if($argument =~ m/^(all|failed|incomplete)$/) {
+ $job_pick = $argument;
+ } else {
+ print STDERR "run.pl: ERROR: --pick argument must be one of 'all', 'failed' or 'incomplete'"
+ }
+ } else {
+ # Ignore option.
+ $ignored_opts .= "$switch $argument ";
+ }
+ }
+ }
+ if ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+):(\d+)$/) { # e.g. JOB=1:20
+ $jobname = $1;
+ $jobstart = $2;
+ $jobend = $3;
+ if ($jobstart > $jobend) {
+ die "run.pl: invalid job range $ARGV[0]";
+ }
+ if ($jobstart <= 0) {
+ die "run.pl: invalid job range $ARGV[0], start must be strictly positive (this is required for GridEngine compatibility).";
+ }
+ shift;
+ } elsif ($ARGV[0] =~ m/^([\w_][\w\d_]*)+=(\d+)$/) { # e.g. JOB=1.
+ $jobname = $1;
+ $jobstart = $2;
+ $jobend = $2;
+ shift;
+ } elsif ($ARGV[0] =~ m/.+\=.*\:.*$/) {
+ print STDERR "run.pl: Warning: suspicious first argument to run.pl: $ARGV[0]\n";
+ }
+}
+
+# Users found this message confusing so we are removing it.
+# if ($ignored_opts ne "") {
+# print STDERR "run.pl: Warning: ignoring options \"$ignored_opts\"\n";
+# }
+
+if ($max_jobs_run == -1) { # If --max-jobs-run option not set,
+ # then work out the number of processors if possible,
+ # and set it based on that.
+ $max_jobs_run = 0;
+ if ($using_gpu) {
+ if (open(P, "nvidia-smi -L |")) {
+ $max_jobs_run++ while (<P>);
+ close(P);
+ }
+ if ($max_jobs_run == 0) {
+ $max_jobs_run = 1;
+ print STDERR "run.pl: Warning: failed to detect number of GPUs from nvidia-smi, using ${max_jobs_run}\n";
+ }
+ } elsif (open(P, "</proc/cpuinfo")) { # Linux
+ while (<P>) { if (m/^processor/) { $max_jobs_run++; } }
+ if ($max_jobs_run == 0) {
+ print STDERR "run.pl: Warning: failed to detect any processors from /proc/cpuinfo\n";
+ $max_jobs_run = 10; # reasonable default.
+ }
+ close(P);
+ } elsif (open(P, "sysctl -a |")) { # BSD/Darwin
+ while (<P>) {
+ if (m/hw\.ncpu\s*[:=]\s*(\d+)/) { # hw.ncpu = 4, or hw.ncpu: 4
+ $max_jobs_run = $1;
+ last;
+ }
+ }
+ close(P);
+ if ($max_jobs_run == 0) {
+ print STDERR "run.pl: Warning: failed to detect any processors from sysctl -a\n";
+ $max_jobs_run = 10; # reasonable default.
+ }
+ } else {
+ # allow at most 32 jobs at once, on non-UNIX systems; change this code
+ # if you need to change this default.
+ $max_jobs_run = 32;
+ }
+ # The just-computed value of $max_jobs_run is just the number of processors
+ # (or our best guess); and if it happens that the number of jobs we need to
+ # run is just slightly above $max_jobs_run, it will make sense to increase
+ # $max_jobs_run to equal the number of jobs, so we don't have a small number
+ # of leftover jobs.
+ $num_jobs = $jobend - $jobstart + 1;
+ if (!$using_gpu &&
+ $num_jobs > $max_jobs_run && $num_jobs < 1.4 * $max_jobs_run) {
+ $max_jobs_run = $num_jobs;
+ }
+}
+
+sub pick_or_exit {
+ # pick_or_exit ( $logfile )
+ # Invoked before each job is started helps to run jobs selectively.
+ #
+ # Given the name of the output logfile decides whether the job must be
+ # executed (by returning from the subroutine) or not (by terminating the
+ # process calling exit)
+ #
+ # PRE: $job_pick is a global variable set by command line switch --pick
+ # and indicates which class of jobs must be executed.
+ #
+ # 1) If a failed job is not executed the process exit code will indicate
+ # failure, just as if the task was just executed and failed.
+ #
+ # 2) If a task is incomplete it will be executed. Incomplete may be either
+ # a job whose log file does not contain the accounting notes in the end,
+ # or a job whose log file does not exist.
+ #
+ # 3) If the $job_pick is set to 'all' (default behavior) a task will be
+ # executed regardless of the result of previous attempts.
+ #
+ # This logic could have been implemented in the main execution loop
+ # but a subroutine to preserve the current level of readability of
+ # that part of the code.
+ #
+ # Alexandre Felipe, (o.alexandre.felipe@gmail.com) 14th of August of 2020
+ #
+ if($job_pick eq 'all'){
+ return; # no need to bother with the previous log
+ }
+ open my $fh, "<", $_[0] or return; # job not executed yet
+ my $log_line;
+ my $cur_line;
+ while ($cur_line = <$fh>) {
+ if( $cur_line =~ m/# Ended \(code .*/ ) {
+ $log_line = $cur_line;
+ }
+ }
+ close $fh;
+ if (! defined($log_line)){
+ return; # incomplete
+ }
+ if ( $log_line =~ m/# Ended \(code 0\).*/ ) {
+ exit(0); # complete
+ } elsif ( $log_line =~ m/# Ended \(code \d+(; signal \d+)?\).*/ ){
+ if ($job_pick !~ m/^(failed|all)$/) {
+ exit(1); # failed but not going to run
+ } else {
+ return; # failed
+ }
+ } elsif ( $log_line =~ m/.*\S.*/ ) {
+ return; # incomplete jobs are always run
+ }
+}
+
+
+$logfile = shift @ARGV;
+
+if (defined $jobname && $logfile !~ m/$jobname/ &&
+ $jobend > $jobstart) {
+ print STDERR "run.pl: you are trying to run a parallel job but "
+ . "you are putting the output into just one log file ($logfile)\n";
+ exit(1);
+}
+
+$cmd = "";
+
+foreach $x (@ARGV) {
+ if ($x =~ m/^\S+$/) { $cmd .= $x . " "; }
+ elsif ($x =~ m:\":) { $cmd .= "'$x' "; }
+ else { $cmd .= "\"$x\" "; }
+}
+
+#$Data::Dumper::Indent=0;
+$ret = 0;
+$numfail = 0;
+%active_pids=();
+
+use POSIX ":sys_wait_h";
+for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
+ if (scalar(keys %active_pids) >= $max_jobs_run) {
+
+ # Lets wait for a change in any child's status
+ # Then we have to work out which child finished
+ $r = waitpid(-1, 0);
+ $code = $?;
+ if ($r < 0 ) { die "run.pl: Error waiting for child process"; } # should never happen.
+ if ( defined $active_pids{$r} ) {
+ $jid=$active_pids{$r};
+ $fail[$jid]=$code;
+ if ($code !=0) { $numfail++;}
+ delete $active_pids{$r};
+ # print STDERR "Finished: $r/$jid " . Dumper(\%active_pids) . "\n";
+ } else {
+ die "run.pl: Cannot find the PID of the child process that just finished.";
+ }
+
+ # In theory we could do a non-blocking waitpid over all jobs running just
+ # to find out if only one or more jobs finished during the previous waitpid()
+ # However, we just omit this and will reap the next one in the next pass
+ # through the for(;;) cycle
+ }
+ $childpid = fork();
+ if (!defined $childpid) { die "run.pl: Error forking in run.pl (writing to $logfile)"; }
+ if ($childpid == 0) { # We're in the child... this branch
+ # executes the job and returns (possibly with an error status).
+ if (defined $jobname) {
+ $cmd =~ s/$jobname/$jobid/g;
+ $logfile =~ s/$jobname/$jobid/g;
+ }
+ # exit if the job does not need to be executed
+ pick_or_exit( $logfile );
+
+ system("mkdir -p `dirname $logfile` 2>/dev/null");
+ open(F, ">$logfile") || die "run.pl: Error opening log file $logfile";
+ print F "# " . $cmd . "\n";
+ print F "# Started at " . `date`;
+ $starttime = `date +'%s'`;
+ print F "#\n";
+ close(F);
+
+ # Pipe into bash.. make sure we're not using any other shell.
+ open(B, "|bash") || die "run.pl: Error opening shell command";
+ print B "( " . $cmd . ") 2>>$logfile >> $logfile";
+ close(B); # If there was an error, exit status is in $?
+ $ret = $?;
+
+ $lowbits = $ret & 127;
+ $highbits = $ret >> 8;
+ if ($lowbits != 0) { $return_str = "code $highbits; signal $lowbits" }
+ else { $return_str = "code $highbits"; }
+
+ $endtime = `date +'%s'`;
+ open(F, ">>$logfile") || die "run.pl: Error opening log file $logfile (again)";
+ $enddate = `date`;
+ chop $enddate;
+ print F "# Accounting: time=" . ($endtime - $starttime) . " threads=1\n";
+ print F "# Ended ($return_str) at " . $enddate . ", elapsed time " . ($endtime-$starttime) . " seconds\n";
+ close(F);
+ exit($ret == 0 ? 0 : 1);
+ } else {
+ $pid[$jobid] = $childpid;
+ $active_pids{$childpid} = $jobid;
+ # print STDERR "Queued: " . Dumper(\%active_pids) . "\n";
+ }
+}
+
+# Now we have submitted all the jobs, lets wait until all the jobs finish
+foreach $child (keys %active_pids) {
+ $jobid=$active_pids{$child};
+ $r = waitpid($pid[$jobid], 0);
+ $code = $?;
+ if ($r == -1) { die "run.pl: Error waiting for child process"; } # should never happen.
+ if ($r != 0) { $fail[$jobid]=$code; $numfail++ if $code!=0; } # Completed successfully
+}
+
+# Some sanity checks:
+# The $fail array should not contain undefined codes
+# The number of non-zeros in that array should be equal to $numfail
+# We cannot do foreach() here, as the JOB ids do not start at zero
+$failed_jids=0;
+for ($jobid = $jobstart; $jobid <= $jobend; $jobid++) {
+ $job_return = $fail[$jobid];
+ if (not defined $job_return ) {
+ # print Dumper(\@fail);
+
+ die "run.pl: Sanity check failed: we have indication that some jobs are running " .
+ "even after we waited for all jobs to finish" ;
+ }
+ if ($job_return != 0 ){ $failed_jids++;}
+}
+if ($failed_jids != $numfail) {
+ die "run.pl: Sanity check failed: cannot find out how many jobs failed ($failed_jids x $numfail)."
+}
+if ($numfail > 0) { $ret = 1; }
+
+if ($ret != 0) {
+ $njobs = $jobend - $jobstart + 1;
+ if ($njobs == 1) {
+ if (defined $jobname) {
+ $logfile =~ s/$jobname/$jobstart/; # only one numbered job, so replace name with
+ # that job.
+ }
+ print STDERR "run.pl: job failed, log is in $logfile\n";
+ if ($logfile =~ m/JOB/) {
+ print STDERR "run.pl: probably you forgot to put JOB=1:\$nj in your script.";
+ }
+ }
+ else {
+ $logfile =~ s/$jobname/*/g;
+ print STDERR "run.pl: $numfail / $njobs failed, log is in $logfile\n";
+ }
+}
+
+
+exit ($ret);
diff --git a/egs/aishell2/transformer/utils/shuffle_list.pl b/egs/aishell2/transformer/utils/shuffle_list.pl
new file mode 100755
index 0000000..a116200
--- /dev/null
+++ b/egs/aishell2/transformer/utils/shuffle_list.pl
@@ -0,0 +1,44 @@
+#!/usr/bin/env perl
+
+# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+if ($ARGV[0] eq "--srand") {
+ $n = $ARGV[1];
+ $n =~ m/\d+/ || die "Bad argument to --srand option: \"$n\"";
+ srand($ARGV[1]);
+ shift;
+ shift;
+} else {
+ srand(0); # Gives inconsistent behavior if we don't seed.
+}
+
+if (@ARGV > 1 || $ARGV[0] =~ m/^-.+/) { # >1 args, or an option we
+ # don't understand.
+ print "Usage: shuffle_list.pl [--srand N] [input file] > output\n";
+ print "randomizes the order of lines of input.\n";
+ exit(1);
+}
+
+@lines;
+while (<>) {
+ push @lines, [ (rand(), $_)] ;
+}
+
+@lines = sort { $a->[0] cmp $b->[0] } @lines;
+foreach $l (@lines) {
+ print $l->[1];
+}
\ No newline at end of file
diff --git a/egs/aishell2/transformer/utils/split_data.py b/egs/aishell2/transformer/utils/split_data.py
new file mode 100755
index 0000000..060eae6
--- /dev/null
+++ b/egs/aishell2/transformer/utils/split_data.py
@@ -0,0 +1,60 @@
+import os
+import sys
+import random
+
+
+in_dir = sys.argv[1]
+out_dir = sys.argv[2]
+num_split = sys.argv[3]
+
+
+def split_scp(scp, num):
+ assert len(scp) >= num
+ avg = len(scp) // num
+ out = []
+ begin = 0
+
+ for i in range(num):
+ if i == num - 1:
+ out.append(scp[begin:])
+ else:
+ out.append(scp[begin:begin+avg])
+ begin += avg
+
+ return out
+
+
+os.path.exists("{}/wav.scp".format(in_dir))
+os.path.exists("{}/text".format(in_dir))
+
+with open("{}/wav.scp".format(in_dir), 'r') as infile:
+ wav_list = infile.readlines()
+
+with open("{}/text".format(in_dir), 'r') as infile:
+ text_list = infile.readlines()
+
+assert len(wav_list) == len(text_list)
+
+x = list(zip(wav_list, text_list))
+random.shuffle(x)
+wav_shuffle_list, text_shuffle_list = zip(*x)
+
+num_split = int(num_split)
+wav_split_list = split_scp(wav_shuffle_list, num_split)
+text_split_list = split_scp(text_shuffle_list, num_split)
+
+for idx, wav_list in enumerate(wav_split_list, 1):
+ path = out_dir + "/split" + str(num_split) + "/" + str(idx)
+ if not os.path.exists(path):
+ os.makedirs(path)
+ with open("{}/wav.scp".format(path), 'w') as wav_writer:
+ for line in wav_list:
+ wav_writer.write(line)
+
+for idx, text_list in enumerate(text_split_list, 1):
+ path = out_dir + "/split" + str(num_split) + "/" + str(idx)
+ if not os.path.exists(path):
+ os.makedirs(path)
+ with open("{}/text".format(path), 'w') as text_writer:
+ for line in text_list:
+ text_writer.write(line)
diff --git a/egs/aishell2/transformer/utils/split_scp.pl b/egs/aishell2/transformer/utils/split_scp.pl
new file mode 100755
index 0000000..0876dcb
--- /dev/null
+++ b/egs/aishell2/transformer/utils/split_scp.pl
@@ -0,0 +1,246 @@
+#!/usr/bin/env perl
+
+# Copyright 2010-2011 Microsoft Corporation
+
+# See ../../COPYING for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This program splits up any kind of .scp or archive-type file.
+# If there is no utt2spk option it will work on any text file and
+# will split it up with an approximately equal number of lines in
+# each but.
+# With the --utt2spk option it will work on anything that has the
+# utterance-id as the first entry on each line; the utt2spk file is
+# of the form "utterance speaker" (on each line).
+# It splits it into equal size chunks as far as it can. If you use the utt2spk
+# option it will make sure these chunks coincide with speaker boundaries. In
+# this case, if there are more chunks than speakers (and in some other
+# circumstances), some of the resulting chunks will be empty and it will print
+# an error message and exit with nonzero status.
+# You will normally call this like:
+# split_scp.pl scp scp.1 scp.2 scp.3 ...
+# or
+# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
+# Note that you can use this script to split the utt2spk file itself,
+# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
+
+# You can also call the scripts like:
+# split_scp.pl -j 3 0 scp scp.0
+# [note: with this option, it assumes zero-based indexing of the split parts,
+# i.e. the second number must be 0 <= n < num-jobs.]
+
+use warnings;
+
+$num_jobs = 0;
+$job_id = 0;
+$utt2spk_file = "";
+$one_based = 0;
+
+for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
+ if ($ARGV[0] eq "-j") {
+ shift @ARGV;
+ $num_jobs = shift @ARGV;
+ $job_id = shift @ARGV;
+ }
+ if ($ARGV[0] =~ /--utt2spk=(.+)/) {
+ $utt2spk_file=$1;
+ shift;
+ }
+ if ($ARGV[0] eq '--one-based') {
+ $one_based = 1;
+ shift @ARGV;
+ }
+}
+
+if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
+ $job_id - $one_based >= $num_jobs)) {
+ die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
+ ($one_based ? " --one-based" : "") . "'\n"
+}
+
+$one_based
+ and $job_id--;
+
+if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
+ die
+"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
+ or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
+ ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
+}
+
+$error = 0;
+$inscp = shift @ARGV;
+if ($num_jobs == 0) { # without -j option
+ @OUTPUTS = @ARGV;
+} else {
+ for ($j = 0; $j < $num_jobs; $j++) {
+ if ($j == $job_id) {
+ if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
+ else { push @OUTPUTS, "-"; }
+ } else {
+ push @OUTPUTS, "/dev/null";
+ }
+ }
+}
+
+if ($utt2spk_file ne "") { # We have the --utt2spk option...
+ open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
+ while(<$u_fh>) {
+ @A = split;
+ @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
+ ($u,$s) = @A;
+ $utt2spk{$u} = $s;
+ }
+ close $u_fh;
+ open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
+ @spkrs = ();
+ while(<$i_fh>) {
+ @A = split;
+ if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
+ $u = $A[0];
+ $s = $utt2spk{$u};
+ defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
+ if(!defined $spk_count{$s}) {
+ push @spkrs, $s;
+ $spk_count{$s} = 0;
+ $spk_data{$s} = []; # ref to new empty array.
+ }
+ $spk_count{$s}++;
+ push @{$spk_data{$s}}, $_;
+ }
+ # Now split as equally as possible ..
+ # First allocate spks to files by allocating an approximately
+ # equal number of speakers.
+ $numspks = @spkrs; # number of speakers.
+ $numscps = @OUTPUTS; # number of output files.
+ if ($numspks < $numscps) {
+ die "$0: Refusing to split data because number of speakers $numspks " .
+ "is less than the number of output .scp files $numscps\n";
+ }
+ for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+ $scparray[$scpidx] = []; # [] is array reference.
+ }
+ for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
+ $scpidx = int(($spkidx*$numscps) / $numspks);
+ $spk = $spkrs[$spkidx];
+ push @{$scparray[$scpidx]}, $spk;
+ $scpcount[$scpidx] += $spk_count{$spk};
+ }
+
+ # Now will try to reassign beginning + ending speakers
+ # to different scp's and see if it gets more balanced.
+ # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
+ # We can show that if considering changing just 2 scp's, we minimize
+ # this by minimizing the squared difference in sizes. This is
+ # equivalent to minimizing the absolute difference in sizes. This
+ # shows this method is bound to converge.
+
+ $changed = 1;
+ while($changed) {
+ $changed = 0;
+ for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+ # First try to reassign ending spk of this scp.
+ if($scpidx < $numscps-1) {
+ $sz = @{$scparray[$scpidx]};
+ if($sz > 0) {
+ $spk = $scparray[$scpidx]->[$sz-1];
+ $count = $spk_count{$spk};
+ $nutt1 = $scpcount[$scpidx];
+ $nutt2 = $scpcount[$scpidx+1];
+ if( abs( ($nutt2+$count) - ($nutt1-$count))
+ < abs($nutt2 - $nutt1)) { # Would decrease
+ # size-diff by reassigning spk...
+ $scpcount[$scpidx+1] += $count;
+ $scpcount[$scpidx] -= $count;
+ pop @{$scparray[$scpidx]};
+ unshift @{$scparray[$scpidx+1]}, $spk;
+ $changed = 1;
+ }
+ }
+ }
+ if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
+ $spk = $scparray[$scpidx]->[0];
+ $count = $spk_count{$spk};
+ $nutt1 = $scpcount[$scpidx-1];
+ $nutt2 = $scpcount[$scpidx];
+ if( abs( ($nutt2-$count) - ($nutt1+$count))
+ < abs($nutt2 - $nutt1)) { # Would decrease
+ # size-diff by reassigning spk...
+ $scpcount[$scpidx-1] += $count;
+ $scpcount[$scpidx] -= $count;
+ shift @{$scparray[$scpidx]};
+ push @{$scparray[$scpidx-1]}, $spk;
+ $changed = 1;
+ }
+ }
+ }
+ }
+ # Now print out the files...
+ for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+ $scpfile = $OUTPUTS[$scpidx];
+ ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
+ : open($f_fh, '>&', \*STDOUT)) ||
+ die "$0: Could not open scp file $scpfile for writing: $!\n";
+ $count = 0;
+ if(@{$scparray[$scpidx]} == 0) {
+ print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
+ "$scpfile (too many splits and too few speakers?)\n";
+ $error = 1;
+ } else {
+ foreach $spk ( @{$scparray[$scpidx]} ) {
+ print $f_fh @{$spk_data{$spk}};
+ $count += $spk_count{$spk};
+ }
+ $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
+ }
+ close($f_fh);
+ }
+} else {
+ # This block is the "normal" case where there is no --utt2spk
+ # option and we just break into equal size chunks.
+
+ open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
+
+ $numscps = @OUTPUTS; # size of array.
+ @F = ();
+ while(<$i_fh>) {
+ push @F, $_;
+ }
+ $numlines = @F;
+ if($numlines == 0) {
+ print STDERR "$0: error: empty input scp file $inscp\n";
+ $error = 1;
+ }
+ $linesperscp = int( $numlines / $numscps); # the "whole part"..
+ $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
+ $remainder = $numlines - ($linesperscp * $numscps);
+ ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
+ # [just doing int() rounds down].
+ $n = 0;
+ for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
+ $scpfile = $OUTPUTS[$scpidx];
+ ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
+ : open($o_fh, '>&', \*STDOUT)) ||
+ die "$0: Could not open scp file $scpfile for writing: $!\n";
+ for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
+ print $o_fh $F[$n++];
+ }
+ close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n";
+ }
+ $n == $numlines || die "$n != $numlines [code error]";
+}
+
+exit ($error);
diff --git a/egs/aishell2/transformer/utils/subset_data_dir_tr_cv.sh b/egs/aishell2/transformer/utils/subset_data_dir_tr_cv.sh
new file mode 100755
index 0000000..e16cebd
--- /dev/null
+++ b/egs/aishell2/transformer/utils/subset_data_dir_tr_cv.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+
+dev_num_utt=1000
+
+echo "$0 $@"
+. utils/parse_options.sh || exit 1;
+
+train_data=$1
+out_dir=$2
+
+[ ! -f ${train_data}/wav.scp ] && echo "$0: no such file ${train_data}/wav.scp" && exit 1;
+[ ! -f ${train_data}/text ] && echo "$0: no such file ${train_data}/text" && exit 1;
+
+mkdir -p ${out_dir}/train && mkdir -p ${out_dir}/dev
+
+cp ${train_data}/wav.scp ${out_dir}/train/wav.scp.bak
+cp ${train_data}/text ${out_dir}/train/text.bak
+
+num_utt=$(wc -l <${out_dir}/train/wav.scp.bak)
+
+utils/shuffle_list.pl --srand 1 ${out_dir}/train/wav.scp.bak > ${out_dir}/train/wav.scp.shuf
+head -n ${dev_num_utt} ${out_dir}/train/wav.scp.shuf > ${out_dir}/dev/wav.scp
+tail -n $((${num_utt}-${dev_num_utt})) ${out_dir}/train/wav.scp.shuf > ${out_dir}/train/wav.scp
+
+utils/shuffle_list.pl --srand 1 ${out_dir}/train/text.bak > ${out_dir}/train/text.shuf
+head -n ${dev_num_utt} ${out_dir}/train/text.shuf > ${out_dir}/dev/text
+tail -n $((${num_utt}-${dev_num_utt})) ${out_dir}/train/text.shuf > ${out_dir}/train/text
+
+rm ${out_dir}/train/wav.scp.bak ${out_dir}/train/text.bak
+rm ${out_dir}/train/wav.scp.shuf ${out_dir}/train/text.shuf
diff --git a/egs/aishell2/transformer/utils/text2token.py b/egs/aishell2/transformer/utils/text2token.py
new file mode 100755
index 0000000..56c3913
--- /dev/null
+++ b/egs/aishell2/transformer/utils/text2token.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+
+# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+
+import argparse
+import codecs
+import re
+import sys
+
+is_python2 = sys.version_info[0] == 2
+
+
+def exist_or_not(i, match_pos):
+ start_pos = None
+ end_pos = None
+ for pos in match_pos:
+ if pos[0] <= i < pos[1]:
+ start_pos = pos[0]
+ end_pos = pos[1]
+ break
+
+ return start_pos, end_pos
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="convert raw text to tokenized text",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--nchar",
+ "-n",
+ default=1,
+ type=int,
+ help="number of characters to split, i.e., \
+ aabb -> a a b b with -n 1 and aa bb with -n 2",
+ )
+ parser.add_argument(
+ "--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
+ )
+ parser.add_argument("--space", default="<space>", type=str, help="space symbol")
+ parser.add_argument(
+ "--non-lang-syms",
+ "-l",
+ default=None,
+ type=str,
+ help="list of non-linguistic symobles, e.g., <NOISE> etc.",
+ )
+ parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+ parser.add_argument(
+ "--trans_type",
+ "-t",
+ type=str,
+ default="char",
+ choices=["char", "phn"],
+ help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
+ If trans_type is char,
+ read from SI1279.WRD file -> "bricks are an alternative"
+ Else if trans_type is phn,
+ read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
+ sil t er n ih sil t ih v sil" """,
+ )
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ rs = []
+ if args.non_lang_syms is not None:
+ with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
+ nls = [x.rstrip() for x in f.readlines()]
+ rs = [re.compile(re.escape(x)) for x in nls]
+
+ if args.text:
+ f = codecs.open(args.text, encoding="utf-8")
+ else:
+ f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+
+ sys.stdout = codecs.getwriter("utf-8")(
+ sys.stdout if is_python2 else sys.stdout.buffer
+ )
+ line = f.readline()
+ n = args.nchar
+ while line:
+ x = line.split()
+ print(" ".join(x[: args.skip_ncols]), end=" ")
+ a = " ".join(x[args.skip_ncols :])
+
+ # get all matched positions
+ match_pos = []
+ for r in rs:
+ i = 0
+ while i >= 0:
+ m = r.search(a, i)
+ if m:
+ match_pos.append([m.start(), m.end()])
+ i = m.end()
+ else:
+ break
+
+ if args.trans_type == "phn":
+ a = a.split(" ")
+ else:
+ if len(match_pos) > 0:
+ chars = []
+ i = 0
+ while i < len(a):
+ start_pos, end_pos = exist_or_not(i, match_pos)
+ if start_pos is not None:
+ chars.append(a[start_pos:end_pos])
+ i = end_pos
+ else:
+ chars.append(a[i])
+ i += 1
+ a = chars
+
+ a = [a[j : j + n] for j in range(0, len(a), n)]
+
+ a_flat = []
+ for z in a:
+ a_flat.append("".join(z))
+
+ a_chars = [z.replace(" ", args.space) for z in a_flat]
+ if args.trans_type == "phn":
+ a_chars = [z.replace("sil", args.space) for z in a_chars]
+ print(" ".join(a_chars))
+ line = f.readline()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell2/transformer/utils/text_tokenize.py b/egs/aishell2/transformer/utils/text_tokenize.py
new file mode 100755
index 0000000..962ea11
--- /dev/null
+++ b/egs/aishell2/transformer/utils/text_tokenize.py
@@ -0,0 +1,106 @@
+import re
+import argparse
+
+
+def load_dict(seg_file):
+ seg_dict = {}
+ with open(seg_file, 'r') as infile:
+ for line in infile:
+ s = line.strip().split()
+ key = s[0]
+ value = s[1:]
+ seg_dict[key] = " ".join(value)
+ return seg_dict
+
+
+def forward_segment(text, dic):
+ word_list = []
+ i = 0
+ while i < len(text):
+ longest_word = text[i]
+ for j in range(i + 1, len(text) + 1):
+ word = text[i:j]
+ if word in dic:
+ if len(word) > len(longest_word):
+ longest_word = word
+ word_list.append(longest_word)
+ i += len(longest_word)
+ return word_list
+
+
+def tokenize(txt,
+ seg_dict):
+ out_txt = ""
+ pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
+ for word in txt:
+ if pattern.match(word):
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ out_txt += "<unk>" + " "
+ else:
+ continue
+ return out_txt.strip()
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="text tokenize",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--text-file",
+ "-t",
+ default=False,
+ required=True,
+ type=str,
+ help="input text",
+ )
+ parser.add_argument(
+ "--seg-file",
+ "-s",
+ default=False,
+ required=True,
+ type=str,
+ help="seg file",
+ )
+ parser.add_argument(
+ "--txt-index",
+ "-i",
+ default=1,
+ required=True,
+ type=int,
+ help="txt index",
+ )
+ parser.add_argument(
+ "--output-dir",
+ "-o",
+ default=False,
+ required=True,
+ type=str,
+ help="output dir",
+ )
+ return parser
+
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), 'w')
+ shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), 'w')
+ seg_dict = load_dict(args.seg_file)
+ with open(args.text_file, 'r') as infile:
+ for line in infile:
+ s = line.strip().split()
+ text_id = s[0]
+ text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
+ text = tokenize(text_list, seg_dict)
+ lens = len(text.strip().split())
+ txt_writer.write(text_id + " " + text + '\n')
+ shape_writer.write(text_id + " " + str(lens) + '\n')
+
+
+if __name__ == '__main__':
+ main()
+
diff --git a/egs/aishell2/transformer/utils/text_tokenize.sh b/egs/aishell2/transformer/utils/text_tokenize.sh
new file mode 100755
index 0000000..6b74fef
--- /dev/null
+++ b/egs/aishell2/transformer/utils/text_tokenize.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+
+
+# Begin configuration section.
+nj=32
+cmd=utils/run.pl
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+# tokenize configuration
+text_dir=$1
+seg_file=$2
+logdir=$3
+output_dir=$4
+
+txt_dir=${output_dir}/txt; mkdir -p ${output_dir}/txt
+mkdir -p ${logdir}
+
+$cmd JOB=1:$nj $logdir/text_tokenize.JOB.log \
+ python utils/text_tokenize.py -t ${text_dir}/txt/text.JOB.txt \
+ -s ${seg_file} -i JOB -o ${txt_dir} \
+ || exit 1;
+
+# concatenate the text files together.
+for n in $(seq $nj); do
+ cat ${txt_dir}/text.$n.txt || exit 1
+done > ${output_dir}/text || exit 1
+
+for n in $(seq $nj); do
+ cat ${txt_dir}/len.$n || exit 1
+done > ${output_dir}/text_shape || exit 1
+
+echo "$0: Succeeded text tokenize"
diff --git a/egs/aishell2/transformer/utils/textnorm_zh.py b/egs/aishell2/transformer/utils/textnorm_zh.py
new file mode 100755
index 0000000..79feb83
--- /dev/null
+++ b/egs/aishell2/transformer/utils/textnorm_zh.py
@@ -0,0 +1,834 @@
+#!/usr/bin/env python3
+# coding=utf-8
+
+# Authors:
+# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
+# 2019.9 Jiayu DU
+#
+# requirements:
+# - python 3.X
+# notes: python 2.X WILL fail or produce misleading results
+
+import sys, os, argparse, codecs, string, re
+
+# ================================================================================ #
+# basic constant
+# ================================================================================ #
+CHINESE_DIGIS = u'闆朵竴浜屼笁鍥涗簲鍏竷鍏節'
+BIG_CHINESE_DIGIS_SIMPLIFIED = u'闆跺9璐板弫鑲嗕紞闄嗘煉鎹岀帠'
+BIG_CHINESE_DIGIS_TRADITIONAL = u'闆跺9璨冲弮鑲嗕紞闄告煉鎹岀帠'
+SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'鍗佺櫨鍗冧竾'
+SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'鎷句桨浠熻惉'
+LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'浜垮厗浜灀绉┌娌熸锭姝h浇'
+LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'鍎勫厗浜灀绉┌婧濇緱姝h級'
+SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'鍗佺櫨鍗冧竾'
+SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'鎷句桨浠熻惉'
+
+ZERO_ALT = u'銆�'
+ONE_ALT = u'骞�'
+TWO_ALTS = [u'涓�', u'鍏�']
+
+POSITIVE = [u'姝�', u'姝�']
+NEGATIVE = [u'璐�', u'璨�']
+POINT = [u'鐐�', u'榛�']
+# PLUS = [u'鍔�', u'鍔�']
+# SIL = [u'鏉�', u'妲�']
+
+FILLER_CHARS = ['鍛�', '鍟�']
+ER_WHITELIST = '(鍎垮コ|鍎垮瓙|鍎垮瓩|濂冲効|鍎垮|濡诲効|' \
+ '鑳庡効|濠村効|鏂扮敓鍎縷濠村辜鍎縷骞煎効|灏戝効|灏忓効|鍎挎瓕|鍎跨|鍎跨|鎵樺効鎵�|瀛ゅ効|' \
+ '鍎挎垙|鍎垮寲|鍙板効搴剕楣垮効宀泑姝e効鍏粡|鍚婂効閮庡綋|鐢熷効鑲插コ|鎵樺効甯﹀コ|鍏诲効闃茶�亅鐥村効鍛嗗コ|' \
+ '浣冲効浣冲|鍎挎�滃吔鎵皘鍎挎棤甯哥埗|鍎夸笉瀚屾瘝涓憒鍎胯鍗冮噷姣嶆媴蹇鍎垮ぇ涓嶇敱鐖穦鑻忎篂鍎�)'
+
+# 涓枃鏁板瓧绯荤粺绫诲瀷
+NUMBERING_TYPES = ['low', 'mid', 'high']
+
+CURRENCY_NAMES = '(浜烘皯甯亅缇庡厓|鏃ュ厓|鑻遍晳|娆у厓|椹厠|娉曢儙|鍔犳嬁澶у厓|婢冲厓|娓竵|鍏堜护|鑺叞椹厠|鐖卞皵鍏伴晳|' \
+ '閲屾媺|鑽峰叞鐩緗鍩冩柉搴撳|姣斿濉攟鍗板凹鐩緗鏋楀悏鐗箌鏂拌タ鍏板厓|姣旂储|鍗㈠竷|鏂板姞鍧″厓|闊╁厓|娉伴摙)'
+CURRENCY_UNITS = '((浜縷鍗冧竾|鐧句竾|涓噟鍗億鐧�)|(浜縷鍗冧竾|鐧句竾|涓噟鍗億鐧緗)鍏億(浜縷鍗冧竾|鐧句竾|涓噟鍗億鐧緗)鍧梶瑙抾姣泑鍒�)'
+COM_QUANTIFIERS = '(鍖箌寮爘搴鍥瀨鍦簗灏緗鏉涓獆棣東闃檤闃祙缃憒鐐畖椤秥涓榺妫祙鍙獆鏀瘄琚瓅杈唡鎸憒鎷厊棰梶澹硘绐爘鏇瞸澧檤缇鑵攟' \
+ '鐮搴瀹璐瘄鎵巪鎹唡鍒�|浠鎵搢鎵媩缃梶鍧灞眧宀瓅姹焲婧獆閽焲闃焲鍗晐鍙寍瀵箌鍑簗鍙澶磡鑴殀鏉縷璺硘鏋潀浠秥璐磡' \
+ '閽坾绾縷绠鍚峾浣峾韬珅鍫倈璇緗鏈瑋椤祙瀹秥鎴穦灞倈涓潀姣珅鍘榺鍒唡閽眧涓鏂鎷厊閾鐭硘閽閿眧蹇絴(鍗億姣珅寰�)鍏媩' \
+ '姣珅鍘榺鍒唡瀵竱灏簗涓坾閲寍瀵粅甯竱閾簗绋媩(鍗億鍒唡鍘榺姣珅寰�)绫硘鎾畖鍕簗鍚坾鍗噟鏂梶鐭硘鐩榺纰梶纰焲鍙爘妗秥绗紎鐩唡' \
+ '鐩抾鏉瘄閽焲鏂泑閿厊绨媩绡畖鐩榺妗秥缃恷鐡秥澹秥鍗畖鐩弢绠﹟绠眧鐓瞸鍟東琚媩閽祙骞磡鏈坾鏃瀛鍒粅鏃秥鍛▅澶﹟绉抾鍒唡鏃瑋' \
+ '绾獆宀亅涓東鏇磡澶渱鏄澶弢绉媩鍐瑋浠浼弢杈坾涓竱娉绮抾棰梶骞鍫唡鏉鏍箌鏀瘄閬搢闈鐗噟寮爘棰梶鍧�)'
+
+# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
+CHINESE_PUNC_STOP = '锛侊紵锝°��'
+CHINESE_PUNC_NON_STOP = '锛傦純锛勶紖锛嗭紘锛堬級锛婏紜锛岋紞锛忥細锛涳紲锛濓紴锛狅蓟锛硷冀锛撅伎锝�锝涳綔锝濓綖锝燂綘锝剑锝ゃ�併�冦�嬨�屻�嶃�庛�忋�愩�戙�斻�曘�栥�椼�樸�欍�氥�涖�溿�濄�炪�熴�般�俱�库�撯�斺�樷�欌�涒�溾�濃�炩�熲�︹�э箯'
+CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
+
+# ================================================================================ #
+# basic class
+# ================================================================================ #
+class ChineseChar(object):
+ """
+ 涓枃瀛楃
+ 姣忎釜瀛楃瀵瑰簲绠�浣撳拰绻佷綋,
+ e.g. 绠�浣� = '璐�', 绻佷綋 = '璨�'
+ 杞崲鏃跺彲杞崲涓虹畝浣撴垨绻佷綋
+ """
+
+ def __init__(self, simplified, traditional):
+ self.simplified = simplified
+ self.traditional = traditional
+ #self.__repr__ = self.__str__
+
+ def __str__(self):
+ return self.simplified or self.traditional or None
+
+ def __repr__(self):
+ return self.__str__()
+
+
+class ChineseNumberUnit(ChineseChar):
+ """
+ 涓枃鏁板瓧/鏁颁綅瀛楃
+ 姣忎釜瀛楃闄ょ箒绠�浣撳杩樻湁涓�涓澶栫殑澶у啓瀛楃
+ e.g. '闄�' 鍜� '闄�'
+ """
+
+ def __init__(self, power, simplified, traditional, big_s, big_t):
+ super(ChineseNumberUnit, self).__init__(simplified, traditional)
+ self.power = power
+ self.big_s = big_s
+ self.big_t = big_t
+
+ def __str__(self):
+ return '10^{}'.format(self.power)
+
+ @classmethod
+ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
+
+ if small_unit:
+ return ChineseNumberUnit(power=index + 1,
+ simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
+ elif numbering_type == NUMBERING_TYPES[0]:
+ return ChineseNumberUnit(power=index + 8,
+ simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+ elif numbering_type == NUMBERING_TYPES[1]:
+ return ChineseNumberUnit(power=(index + 2) * 4,
+ simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+ elif numbering_type == NUMBERING_TYPES[2]:
+ return ChineseNumberUnit(power=pow(2, index + 3),
+ simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+ else:
+ raise ValueError(
+ 'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
+
+
+class ChineseNumberDigit(ChineseChar):
+ """
+ 涓枃鏁板瓧瀛楃
+ """
+
+ def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
+ super(ChineseNumberDigit, self).__init__(simplified, traditional)
+ self.value = value
+ self.big_s = big_s
+ self.big_t = big_t
+ self.alt_s = alt_s
+ self.alt_t = alt_t
+
+ def __str__(self):
+ return str(self.value)
+
+ @classmethod
+ def create(cls, i, v):
+ return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
+
+
+class ChineseMath(ChineseChar):
+ """
+ 涓枃鏁颁綅瀛楃
+ """
+
+ def __init__(self, simplified, traditional, symbol, expression=None):
+ super(ChineseMath, self).__init__(simplified, traditional)
+ self.symbol = symbol
+ self.expression = expression
+ self.big_s = simplified
+ self.big_t = traditional
+
+
+CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
+class NumberSystem(object):
+ """
+ 涓枃鏁板瓧绯荤粺
+ """
+ pass
+
+
+class MathSymbol(object):
+ """
+ 鐢ㄤ簬涓枃鏁板瓧绯荤粺鐨勬暟瀛︾鍙� (绻�/绠�浣�), e.g.
+ positive = ['姝�', '姝�']
+ negative = ['璐�', '璨�']
+ point = ['鐐�', '榛�']
+ """
+
+ def __init__(self, positive, negative, point):
+ self.positive = positive
+ self.negative = negative
+ self.point = point
+
+ def __iter__(self):
+ for v in self.__dict__.values():
+ yield v
+
+
+# class OtherSymbol(object):
+# """
+# 鍏朵粬绗﹀彿
+# """
+#
+# def __init__(self, sil):
+# self.sil = sil
+#
+# def __iter__(self):
+# for v in self.__dict__.values():
+# yield v
+
+
+# ================================================================================ #
+# basic utils
+# ================================================================================ #
+def create_system(numbering_type=NUMBERING_TYPES[1]):
+ """
+ 鏍规嵁鏁板瓧绯荤粺绫诲瀷杩斿洖鍒涘缓鐩稿簲鐨勬暟瀛楃郴缁燂紝榛樿涓� mid
+ NUMBERING_TYPES = ['low', 'mid', 'high']: 涓枃鏁板瓧绯荤粺绫诲瀷
+ low: '鍏�' = '浜�' * '鍗�' = $10^{9}$, '浜�' = '鍏�' * '鍗�', etc.
+ mid: '鍏�' = '浜�' * '涓�' = $10^{12}$, '浜�' = '鍏�' * '涓�', etc.
+ high: '鍏�' = '浜�' * '浜�' = $10^{16}$, '浜�' = '鍏�' * '鍏�', etc.
+ 杩斿洖瀵瑰簲鐨勬暟瀛楃郴缁�
+ """
+
+ # chinese number units of '浜�' and larger
+ all_larger_units = zip(
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
+ larger_units = [CNU.create(i, v, numbering_type, False)
+ for i, v in enumerate(all_larger_units)]
+ # chinese number units of '鍗�, 鐧�, 鍗�, 涓�'
+ all_smaller_units = zip(
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
+ smaller_units = [CNU.create(i, v, small_unit=True)
+ for i, v in enumerate(all_smaller_units)]
+ # digis
+ chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
+ BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
+ digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
+ digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
+ digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
+ digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
+
+ # symbols
+ positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
+ negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
+ point_cn = CM(POINT[0], POINT[1], '.', lambda x,
+ y: float(str(x) + '.' + str(y)))
+ # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
+ system = NumberSystem()
+ system.units = smaller_units + larger_units
+ system.digits = digits
+ system.math = MathSymbol(positive_cn, negative_cn, point_cn)
+ # system.symbols = OtherSymbol(sil_cn)
+ return system
+
+
+def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
+
+ def get_symbol(char, system):
+ for u in system.units:
+ if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
+ return u
+ for d in system.digits:
+ if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
+ return d
+ for m in system.math:
+ if char in [m.traditional, m.simplified]:
+ return m
+
+ def string2symbols(chinese_string, system):
+ int_string, dec_string = chinese_string, ''
+ for p in [system.math.point.simplified, system.math.point.traditional]:
+ if p in chinese_string:
+ int_string, dec_string = chinese_string.split(p)
+ break
+ return [get_symbol(c, system) for c in int_string], \
+ [get_symbol(c, system) for c in dec_string]
+
+ def correct_symbols(integer_symbols, system):
+ """
+ 涓�鐧惧叓 to 涓�鐧惧叓鍗�
+ 涓�浜夸竴鍗冧笁鐧句竾 to 涓�浜� 涓�鍗冧竾 涓夌櫨涓�
+ """
+
+ if integer_symbols and isinstance(integer_symbols[0], CNU):
+ if integer_symbols[0].power == 1:
+ integer_symbols = [system.digits[1]] + integer_symbols
+
+ if len(integer_symbols) > 1:
+ if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
+ integer_symbols.append(
+ CNU(integer_symbols[-2].power - 1, None, None, None, None))
+
+ result = []
+ unit_count = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ result.append(s)
+ unit_count = 0
+ elif isinstance(s, CNU):
+ current_unit = CNU(s.power, None, None, None, None)
+ unit_count += 1
+
+ if unit_count == 1:
+ result.append(current_unit)
+ elif unit_count > 1:
+ for i in range(len(result)):
+ if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
+ result[-i - 1] = CNU(result[-i - 1].power +
+ current_unit.power, None, None, None, None)
+ return result
+
+ def compute_value(integer_symbols):
+ """
+ Compute the value.
+ When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
+ e.g. '涓ゅ崈涓�' = 2000 * 10000 not 2000 + 10000
+ """
+ value = [0]
+ last_power = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ value[-1] = s.value
+ elif isinstance(s, CNU):
+ value[-1] *= pow(10, s.power)
+ if s.power > last_power:
+ value[:-1] = list(map(lambda v: v *
+ pow(10, s.power), value[:-1]))
+ last_power = s.power
+ value.append(0)
+ return sum(value)
+
+ system = create_system(numbering_type)
+ int_part, dec_part = string2symbols(chinese_string, system)
+ int_part = correct_symbols(int_part, system)
+ int_str = str(compute_value(int_part))
+ dec_str = ''.join([str(d.value) for d in dec_part])
+ if dec_part:
+ return '{0}.{1}'.format(int_str, dec_str)
+ else:
+ return int_str
+
+
+def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
+ traditional=False, alt_zero=False, alt_one=False, alt_two=True,
+ use_zeros=True, use_units=True):
+
+ def get_value(value_string, use_zeros=True):
+
+ striped_string = value_string.lstrip('0')
+
+ # record nothing if all zeros
+ if not striped_string:
+ return []
+
+ # record one digits
+ elif len(striped_string) == 1:
+ if use_zeros and len(value_string) != len(striped_string):
+ return [system.digits[0], system.digits[int(striped_string)]]
+ else:
+ return [system.digits[int(striped_string)]]
+
+ # recursively record multiple digits
+ else:
+ result_unit = next(u for u in reversed(
+ system.units) if u.power < len(striped_string))
+ result_string = value_string[:-result_unit.power]
+ return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
+
+ system = create_system(numbering_type)
+
+ int_dec = number_string.split('.')
+ if len(int_dec) == 1:
+ int_string = int_dec[0]
+ dec_string = ""
+ elif len(int_dec) == 2:
+ int_string = int_dec[0]
+ dec_string = int_dec[1]
+ else:
+ raise ValueError(
+ "invalid input num string with more than one dot: {}".format(number_string))
+
+ if use_units and len(int_string) > 1:
+ result_symbols = get_value(int_string)
+ else:
+ result_symbols = [system.digits[int(c)] for c in int_string]
+ dec_symbols = [system.digits[int(c)] for c in dec_string]
+ if dec_string:
+ result_symbols += [system.math.point] + dec_symbols
+
+ if alt_two:
+ liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
+ system.digits[2].big_s, system.digits[2].big_t)
+ for i, v in enumerate(result_symbols):
+ if isinstance(v, CND) and v.value == 2:
+ next_symbol = result_symbols[i +
+ 1] if i < len(result_symbols) - 1 else None
+ previous_symbol = result_symbols[i - 1] if i > 0 else None
+ if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
+ if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
+ result_symbols[i] = liang
+
+ # if big is True, '涓�' will not be used and `alt_two` has no impact on output
+ if big:
+ attr_name = 'big_'
+ if traditional:
+ attr_name += 't'
+ else:
+ attr_name += 's'
+ else:
+ if traditional:
+ attr_name = 'traditional'
+ else:
+ attr_name = 'simplified'
+
+ result = ''.join([getattr(s, attr_name) for s in result_symbols])
+
+ # if not use_zeros:
+ # result = result.strip(getattr(system.digits[0], attr_name))
+
+ if alt_zero:
+ result = result.replace(
+ getattr(system.digits[0], attr_name), system.digits[0].alt_s)
+
+ if alt_one:
+ result = result.replace(
+ getattr(system.digits[1], attr_name), system.digits[1].alt_s)
+
+ for i, p in enumerate(POINT):
+ if result.startswith(p):
+ return CHINESE_DIGIS[0] + result
+
+ # ^10, 11, .., 19
+ if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
+ result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
+ result = result[1:]
+
+ return result
+
+
+# ================================================================================ #
+# different types of rewriters
+# ================================================================================ #
+class Cardinal:
+ """
+ CARDINAL绫�
+ """
+
+ def __init__(self, cardinal=None, chntext=None):
+ self.cardinal = cardinal
+ self.chntext = chntext
+
+ def chntext2cardinal(self):
+ return chn2num(self.chntext)
+
+ def cardinal2chntext(self):
+ return num2chn(self.cardinal)
+
+class Digit:
+ """
+ DIGIT绫�
+ """
+
+ def __init__(self, digit=None, chntext=None):
+ self.digit = digit
+ self.chntext = chntext
+
+ # def chntext2digit(self):
+ # return chn2num(self.chntext)
+
+ def digit2chntext(self):
+ return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+class TelePhone:
+ """
+ TELEPHONE绫�
+ """
+
+ def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+ self.telephone = telephone
+ self.raw_chntext = raw_chntext
+ self.chntext = chntext
+
+ # def chntext2telephone(self):
+ # sil_parts = self.raw_chntext.split('<SIL>')
+ # self.telephone = '-'.join([
+ # str(chn2num(p)) for p in sil_parts
+ # ])
+ # return self.telephone
+
+ def telephone2chntext(self, fixed=False):
+
+ if fixed:
+ sil_parts = self.telephone.split('-')
+ self.raw_chntext = '<SIL>'.join([
+ num2chn(part, alt_two=False, use_units=False) for part in sil_parts
+ ])
+ self.chntext = self.raw_chntext.replace('<SIL>', '')
+ else:
+ sp_parts = self.telephone.strip('+').split()
+ self.raw_chntext = '<SP>'.join([
+ num2chn(part, alt_two=False, use_units=False) for part in sp_parts
+ ])
+ self.chntext = self.raw_chntext.replace('<SP>', '')
+ return self.chntext
+
+
+class Fraction:
+ """
+ FRACTION绫�
+ """
+
+ def __init__(self, fraction=None, chntext=None):
+ self.fraction = fraction
+ self.chntext = chntext
+
+ def chntext2fraction(self):
+ denominator, numerator = self.chntext.split('鍒嗕箣')
+ return chn2num(numerator) + '/' + chn2num(denominator)
+
+ def fraction2chntext(self):
+ numerator, denominator = self.fraction.split('/')
+ return num2chn(denominator) + '鍒嗕箣' + num2chn(numerator)
+
+
+class Date:
+ """
+ DATE绫�
+ """
+
+ def __init__(self, date=None, chntext=None):
+ self.date = date
+ self.chntext = chntext
+
+ # def chntext2date(self):
+ # chntext = self.chntext
+ # try:
+ # year, other = chntext.strip().split('骞�', maxsplit=1)
+ # year = Digit(chntext=year).digit2chntext() + '骞�'
+ # except ValueError:
+ # other = chntext
+ # year = ''
+ # if other:
+ # try:
+ # month, day = other.strip().split('鏈�', maxsplit=1)
+ # month = Cardinal(chntext=month).chntext2cardinal() + '鏈�'
+ # except ValueError:
+ # day = chntext
+ # month = ''
+ # if day:
+ # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+ # else:
+ # month = ''
+ # day = ''
+ # date = year + month + day
+ # self.date = date
+ # return self.date
+
+ def date2chntext(self):
+ date = self.date
+ try:
+ year, other = date.strip().split('骞�', 1)
+ year = Digit(digit=year).digit2chntext() + '骞�'
+ except ValueError:
+ other = date
+ year = ''
+ if other:
+ try:
+ month, day = other.strip().split('鏈�', 1)
+ month = Cardinal(cardinal=month).cardinal2chntext() + '鏈�'
+ except ValueError:
+ day = date
+ month = ''
+ if day:
+ day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+ else:
+ month = ''
+ day = ''
+ chntext = year + month + day
+ self.chntext = chntext
+ return self.chntext
+
+
+class Money:
+ """
+ MONEY绫�
+ """
+
+ def __init__(self, money=None, chntext=None):
+ self.money = money
+ self.chntext = chntext
+
+ # def chntext2money(self):
+ # return self.money
+
+ def money2chntext(self):
+ money = self.money
+ pattern = re.compile(r'(\d+(\.\d+)?)')
+ matchers = pattern.findall(money)
+ if matchers:
+ for matcher in matchers:
+ money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
+ self.chntext = money
+ return self.chntext
+
+
+class Percentage:
+ """
+ PERCENTAGE绫�
+ """
+
+ def __init__(self, percentage=None, chntext=None):
+ self.percentage = percentage
+ self.chntext = chntext
+
+ def chntext2percentage(self):
+ return chn2num(self.chntext.strip().strip('鐧惧垎涔�')) + '%'
+
+ def percentage2chntext(self):
+ return '鐧惧垎涔�' + num2chn(self.percentage.strip().strip('%'))
+
+
+def remove_erhua(text, er_whitelist):
+ """
+ 鍘婚櫎鍎垮寲闊宠瘝涓殑鍎�:
+ 浠栧コ鍎垮湪閭h竟鍎� -> 浠栧コ鍎垮湪閭h竟
+ """
+
+ er_pattern = re.compile(er_whitelist)
+ new_str=''
+ while re.search('鍎�',text):
+ a = re.search('鍎�',text).span()
+ remove_er_flag = 0
+
+ if er_pattern.search(text):
+ b = er_pattern.search(text).span()
+ if b[0] <= a[0]:
+ remove_er_flag = 1
+
+ if remove_er_flag == 0 :
+ new_str = new_str + text[0:a[0]]
+ text = text[a[1]:]
+ else:
+ new_str = new_str + text[0:b[1]]
+ text = text[b[1]:]
+
+ text = new_str + text
+ return text
+
+# ================================================================================ #
+# NSW Normalizer
+# ================================================================================ #
+class NSWNormalizer:
+ def __init__(self, raw_text):
+ self.raw_text = '^' + raw_text + '$'
+ self.norm_text = ''
+
+ def _particular(self):
+ text = self.norm_text
+ pattern = re.compile(r"(([a-zA-Z]+)浜�([a-zA-Z]+))")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('particular')
+ for matcher in matchers:
+ text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
+ self.norm_text = text
+ return self.norm_text
+
+ def normalize(self):
+ text = self.raw_text
+
+ # 瑙勮寖鍖栨棩鏈�
+ pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})骞�)?(\d{1,2}鏈�(\d{1,2}[鏃ュ彿])?)?)")
+ matchers = pattern.findall(text)
+ if matchers:
+ #print('date')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+ # 瑙勮寖鍖栭噾閽�
+ pattern = re.compile(r"\D+((\d+(\.\d+)?)[澶氫綑鍑燷?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
+ matchers = pattern.findall(text)
+ if matchers:
+ #print('money')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
+
+ # 瑙勮寖鍖栧浐璇�/鎵嬫満鍙风爜
+ # 鎵嬫満
+ # http://www.jihaoba.com/news/show/13680
+ # 绉诲姩锛�139銆�138銆�137銆�136銆�135銆�134銆�159銆�158銆�157銆�150銆�151銆�152銆�188銆�187銆�182銆�183銆�184銆�178銆�198
+ # 鑱旈�氾細130銆�131銆�132銆�156銆�155銆�186銆�185銆�176
+ # 鐢典俊锛�133銆�153銆�189銆�180銆�181銆�177
+ pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ #print('telephone')
+ for matcher in matchers:
+ text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
+ # 鍥鸿瘽
+ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+ matchers = pattern.findall(text)
+ if matchers:
+ # print('fixed telephone')
+ for matcher in matchers:
+ text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
+
+ # 瑙勮寖鍖栧垎鏁�
+ pattern = re.compile(r"(\d+/\d+)")
+ matchers = pattern.findall(text)
+ if matchers:
+ #print('fraction')
+ for matcher in matchers:
+ text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
+
+ # 瑙勮寖鍖栫櫨鍒嗘暟
+ text = text.replace('锛�', '%')
+ pattern = re.compile(r"(\d+(\.\d+)?%)")
+ matchers = pattern.findall(text)
+ if matchers:
+ #print('percentage')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
+
+ # 瑙勮寖鍖栫函鏁�+閲忚瘝
+ pattern = re.compile(r"(\d+(\.\d+)?)[澶氫綑鍑燷?" + COM_QUANTIFIERS)
+ matchers = pattern.findall(text)
+ if matchers:
+ #print('cardinal+quantifier')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+ # 瑙勮寖鍖栨暟瀛楃紪鍙�
+ pattern = re.compile(r"(\d{4,32})")
+ matchers = pattern.findall(text)
+ if matchers:
+ #print('digit')
+ for matcher in matchers:
+ text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+ # 瑙勮寖鍖栫函鏁�
+ pattern = re.compile(r"(\d+(\.\d+)?)")
+ matchers = pattern.findall(text)
+ if matchers:
+ #print('cardinal')
+ for matcher in matchers:
+ text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+ self.norm_text = text
+ self._particular()
+
+ return self.norm_text.lstrip('^').rstrip('$')
+
+
+def nsw_test_case(raw_text):
+ print('I:' + raw_text)
+ print('O:' + NSWNormalizer(raw_text).normalize())
+ print('')
+
+
+def nsw_test():
+ nsw_test_case('鍥鸿瘽锛�0595-23865596鎴�23880880銆�')
+ nsw_test_case('鍥鸿瘽锛�0595-23865596鎴�23880880銆�')
+ nsw_test_case('鎵嬫満锛�+86 19859213959鎴�15659451527銆�')
+ nsw_test_case('鍒嗘暟锛�32477/76391銆�')
+ nsw_test_case('鐧惧垎鏁帮細80.03%銆�')
+ nsw_test_case('缂栧彿锛�31520181154418銆�')
+ nsw_test_case('绾暟锛�2983.07鍏嬫垨12345.60绫炽��')
+ nsw_test_case('鏃ユ湡锛�1999骞�2鏈�20鏃ユ垨09骞�3鏈�15鍙枫��')
+ nsw_test_case('閲戦挶锛�12鍧�5锛�34.5鍏冿紝20.1涓�')
+ nsw_test_case('鐗规畩锛歄2O鎴朆2C銆�')
+ nsw_test_case('3456涓囧惃')
+ nsw_test_case('2938涓�')
+ nsw_test_case('938')
+ nsw_test_case('浠婂ぉ鍚冧簡115涓皬绗煎寘231涓澶�')
+ nsw_test_case('鏈�62锛呯殑姒傜巼')
+
+
+if __name__ == '__main__':
+ #nsw_test()
+
+ p = argparse.ArgumentParser()
+ p.add_argument('ifile', help='input filename, assume utf-8 encoding')
+ p.add_argument('ofile', help='output filename')
+ p.add_argument('--to_upper', action='store_true', help='convert to upper case')
+ p.add_argument('--to_lower', action='store_true', help='convert to lower case')
+ p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
+ p.add_argument('--remove_fillers', type=bool, default=True, help='remove filler chars such as "鍛�, 鍟�"')
+ p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "杩欏効"')
+ p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
+ args = p.parse_args()
+
+ ifile = codecs.open(args.ifile, 'r', 'utf8')
+ ofile = codecs.open(args.ofile, 'w+', 'utf8')
+
+ n = 0
+ for l in ifile:
+ key = ''
+ text = ''
+ if args.has_key:
+ cols = l.split(maxsplit=1)
+ key = cols[0]
+ if len(cols) == 2:
+ text = cols[1].strip()
+ else:
+ text = ''
+ else:
+ text = l.strip()
+
+ # cases
+ if args.to_upper and args.to_lower:
+ sys.stderr.write('text norm: to_upper OR to_lower?')
+ exit(1)
+ if args.to_upper:
+ text = text.upper()
+ if args.to_lower:
+ text = text.lower()
+
+ # Filler chars removal
+ if args.remove_fillers:
+ for ch in FILLER_CHARS:
+ text = text.replace(ch, '')
+
+ if args.remove_erhua:
+ text = remove_erhua(text, ER_WHITELIST)
+
+ # NSW(Non-Standard-Word) normalization
+ text = NSWNormalizer(text).normalize()
+
+ # Punctuations removal
+ old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations
+ new_chars = ' ' * len(old_chars)
+ del_chars = ''
+ text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
+
+ #
+ if args.has_key:
+ ofile.write(key + '\t' + text + '\n')
+ else:
+ ofile.write(text + '\n')
+
+ n += 1
+ if n % args.log_interval == 0:
+ sys.stderr.write("text norm: {} lines done.\n".format(n))
+
+ sys.stderr.write("text norm: {} lines done in total.\n".format(n))
+
+ ifile.close()
+ ofile.close()
--
Gitblit v1.9.1