From c568628130ac42ebeea8cf48fe926520a31ff511 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 16 五月 2023 10:57:21 +0800
Subject: [PATCH] update repo

---
 egs/aishell2/transformer/utils/text_tokenize.py |  106 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 106 insertions(+), 0 deletions(-)

diff --git a/egs/aishell2/transformer/utils/text_tokenize.py b/egs/aishell2/transformer/utils/text_tokenize.py
new file mode 100755
index 0000000..962ea11
--- /dev/null
+++ b/egs/aishell2/transformer/utils/text_tokenize.py
@@ -0,0 +1,106 @@
+import re
+import argparse
+
+
+def load_dict(seg_file):
+    seg_dict = {}
+    with open(seg_file, 'r') as infile:
+        for line in infile:
+            s = line.strip().split()
+            key = s[0]
+            value = s[1:]
+            seg_dict[key] = " ".join(value)
+    return seg_dict
+
+
+def forward_segment(text, dic):
+    word_list = []
+    i = 0
+    while i < len(text):
+        longest_word = text[i]
+        for j in range(i + 1, len(text) + 1):
+            word = text[i:j]
+            if word in dic:
+                if len(word) > len(longest_word):
+                    longest_word = word
+        word_list.append(longest_word)
+        i += len(longest_word)
+    return word_list
+
+
+def tokenize(txt,
+             seg_dict):
+    out_txt = ""
+    pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
+    for word in txt:
+        if pattern.match(word):
+            if word in seg_dict:
+                out_txt += seg_dict[word] + " "
+            else:
+                out_txt += "<unk>" + " "
+        else:
+            continue
+    return out_txt.strip()
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="text tokenize",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--text-file",
+        "-t",
+        default=False,
+        required=True,
+        type=str,
+        help="input text",
+    )
+    parser.add_argument(
+        "--seg-file",
+        "-s",
+        default=False,
+        required=True,
+        type=str,
+        help="seg file",
+    )
+    parser.add_argument(
+        "--txt-index",
+        "-i",
+        default=1,
+        required=True,
+        type=int,
+        help="txt index",
+    )
+    parser.add_argument(
+        "--output-dir",
+        "-o",
+        default=False,
+        required=True,
+        type=str,
+        help="output dir",
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), 'w')
+    shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), 'w')
+    seg_dict = load_dict(args.seg_file)
+    with open(args.text_file, 'r') as infile:
+        for line in infile:
+            s = line.strip().split()
+            text_id = s[0]
+            text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
+            text = tokenize(text_list, seg_dict)
+            lens = len(text.strip().split())
+            txt_writer.write(text_id + " " + text + '\n')
+            shape_writer.write(text_id + " " + str(lens) + '\n')
+
+
+if __name__ == '__main__':
+    main()
+

--
Gitblit v1.9.1