From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
fun_text_processing/text_normalization/zh/verbalizers/postprocessor.py | 42 ++++++++++++++++++++++++++----------------
1 files changed, 26 insertions(+), 16 deletions(-)
diff --git a/fun_text_processing/text_normalization/zh/verbalizers/postprocessor.py b/fun_text_processing/text_normalization/zh/verbalizers/postprocessor.py
index 77dc9ff..e87f289 100644
--- a/fun_text_processing/text_normalization/zh/verbalizers/postprocessor.py
+++ b/fun_text_processing/text_normalization/zh/verbalizers/postprocessor.py
@@ -1,5 +1,3 @@
-
-
import pynini
from fun_text_processing.text_normalization.zh.graph_utils import (
FUN_ALPHA,
@@ -14,45 +12,57 @@
class PostProcessor(GraphFst):
- '''
- Postprocessing of TN, now contains:
- 1. punctuation removal
- 2. letter case conversion
- 3. oov tagger
- '''
+ """
+ Postprocessing of TN, now contains:
+ 1. punctuation removal
+ 2. letter case conversion
+ 3. oov tagger
+ """
def __init__(
- self, remove_puncts: bool = False, to_upper: bool = False, to_lower: bool = False, tag_oov: bool = False,
+ self,
+ remove_puncts: bool = False,
+ to_upper: bool = False,
+ to_lower: bool = False,
+ tag_oov: bool = False,
):
super().__init__(name="PostProcessor", kind="processor")
- graph = pynini.cdrewrite('', '', '', FUN_SIGMA)
+ graph = pynini.cdrewrite("", "", "", FUN_SIGMA)
if remove_puncts:
remove_puncts_graph = pynutil.delete(
- pynini.union(FUN_PUNCT, pynini.string_file(get_abs_path('data/char/punctuations_zh.tsv')))
+ pynini.union(
+ FUN_PUNCT, pynini.string_file(get_abs_path("data/char/punctuations_zh.tsv"))
+ )
)
graph @= pynini.cdrewrite(remove_puncts_graph, "", "", FUN_SIGMA).optimize()
if to_upper or to_lower:
if to_upper:
- conv_cases_graph = pynini.inverse(pynini.string_file(get_abs_path('data/char/upper_to_lower.tsv')))
+ conv_cases_graph = pynini.inverse(
+ pynini.string_file(get_abs_path("data/char/upper_to_lower.tsv"))
+ )
else:
- conv_cases_graph = pynini.string_file(get_abs_path('data/char/upper_to_lower.tsv'))
+ conv_cases_graph = pynini.string_file(get_abs_path("data/char/upper_to_lower.tsv"))
graph @= pynini.cdrewrite(conv_cases_graph, "", "", FUN_SIGMA).optimize()
if tag_oov:
- zh_charset_std = pynini.string_file(get_abs_path("data/char/charset_national_standard_2013_8105.tsv"))
+ zh_charset_std = pynini.string_file(
+ get_abs_path("data/char/charset_national_standard_2013_8105.tsv")
+ )
zh_charset_ext = pynini.string_file(get_abs_path("data/char/charset_extension.tsv"))
zh_charset = (
- zh_charset_std | zh_charset_ext | pynini.string_file(get_abs_path("data/char/punctuations_zh.tsv"))
+ zh_charset_std
+ | zh_charset_ext
+ | pynini.string_file(get_abs_path("data/char/punctuations_zh.tsv"))
)
en_charset = FUN_DIGIT | FUN_ALPHA | FUN_PUNCT | FUN_WHITE_SPACE
charset = zh_charset | en_charset
with open(get_abs_path("data/char/oov_tags.tsv"), "r") as f:
- tags = f.readline().strip().split('\t')
+ tags = f.readline().strip().split("\t")
assert len(tags) == 2
ltag, rtag = tags
--
Gitblit v1.9.1