import logging
import os
from typing import Optional

import pynini
from pynini.lib import pynutil

from fun_text_processing.text_normalization.en.graph_utils import (
    MIN_NEG_WEIGHT,
    DAMO_ALPHA,
    DAMO_CHAR,
    DAMO_SIGMA,
    DAMO_SPACE,
    generator_main,
)
from fun_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst


class PostProcessingFst:
    """
    Finite state transducer that post-processes an entire sentence after verbalization is complete, e.g.
    removes extra spaces around punctuation marks " ( one hundred and twenty three ) " -> "(one hundred and twenty three)"

    Args:
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
    """

    def __init__(self, cache_dir: Optional[str] = None, overwrite_cache: bool = False):
        """
        Restore the post-processing graph from the cache, or build it (and cache it when a cache dir is given).

        The resulting graph is stored in ``self.fst``.
        """
        # Fill the apostrophe table first so that `punct_marks` is available on every
        # instance, including those whose graph is restored from the .far cache.
        self.set_punct_dict()

        far_file = None
        # the literal string "None" is handled the same way as a missing cache dir
        if cache_dir is not None and cache_dir != "None":
            os.makedirs(cache_dir, exist_ok=True)
            far_file = os.path.join(cache_dir, "en_tn_post_processing.far")

        if not overwrite_cache and far_file and os.path.exists(far_file):
            self.fst = pynini.Far(far_file, mode="r")["post_process_graph"]
            logging.info("Post processing graph was restored from %s.", far_file)
        else:
            self.fst = self.get_punct_postprocess_graph()

            if far_file:
                # NOTE(review): generator_main presumably exports the graph to far_file -- confirm in graph_utils
                generator_main(far_file, {"post_process_graph": self.fst})

    def set_punct_dict(self):
        """
        Sets ``self.punct_marks``: a dict with the single key {'} mapped to the list of characters
        that are treated as a single quote / apostrophe by the post-processing graph.
        """
        self.punct_marks = {
            "'": [
                "'",
                "´",
                "ʹ",
                "ʻ",
                "ʼ",
                "ʽ",
                "ʾ",
                "ˈ",
                "ˊ",
                "ˋ",
                "˴",
                "ʹ",
                "΄",
                "՚",
                "՝",
                "י",
                "׳",
                "ߴ",
                "ߵ",
                "ᑊ",
                "ᛌ",
                "᾽",
                "᾿",
                "`",
                "´",
                "῾",
                "‘",
                "’",
                "‛",
                "′",
                "‵",
                "ꞌ",
                "'",
                "`",
                "𖽑",
                "𖽒",
            ],
        }

    def get_punct_postprocess_graph(self):
        """
        Returns graph to post process punctuation marks.

        The graph is a composition of the following steps:
            1. the space before a punctuation mark is deleted, except for "&", quotes, dashes and
               opening brackets, which may keep a preceding space
            2. {``} quotes are converted to {"}
            3. the space after an opening bracket is deleted (even without a matching closing bracket)
            4. spaces just inside a pair of {`} or {"} quotes are deleted
            5. spaces around a single quote that follows a letter and precedes a word-final "s" are deleted

        Note, spaces around single quote {'} are kept in all cases not covered by step 5.
        """
        punct_marks_all = PunctuationFst().punct_marks

        # marks listed in allow_space_before_punct may keep a preceding space,
        # all remaining marks (no_space_before_punct) assume no space before them
        quotes = ["'", '"', "``", "«"]
        dashes = ["-", "—"]
        brackets = ["<", "{", "("]
        open_close_single_quotes = [
            ("`", "`"),
        ]

        open_close_double_quotes = [('"', '"'), ("``", "``"), ("“", "”")]
        open_close_symbols = open_close_single_quotes + open_close_double_quotes
        allow_space_before_punct = (
            ["&"] + quotes + dashes + brackets + [k[0] for k in open_close_symbols]
        )

        no_space_before_punct = [m for m in punct_marks_all if m not in allow_space_before_punct]
        no_space_before_punct = pynini.union(*no_space_before_punct)
        no_space_after_punct = pynini.union(*brackets)
        delete_space = pynutil.delete(" ")
        delete_space_optional = pynini.closure(delete_space, 0, 1)

        # non_punct allows space
        # delete space before no_space_before_punct marks, if present
        # NOTE(review): MIN_NEG_WEIGHT is presumably a small negative weight that makes the
        # space-deleting path the preferred one -- confirm in graph_utils
        non_punct = pynini.difference(DAMO_CHAR, no_space_before_punct).optimize()
        graph = (
            pynini.closure(non_punct)
            + pynini.closure(
                no_space_before_punct
                | pynutil.add_weight(delete_space + no_space_before_punct, MIN_NEG_WEIGHT)
            )
            + pynini.closure(non_punct)
        )
        graph = pynini.closure(graph).optimize()
        graph = pynini.compose(
            graph, pynini.cdrewrite(pynini.cross("``", '"'), "", "", DAMO_SIGMA).optimize()
        ).optimize()

        # remove space after no_space_after_punct (even if there are no matching closing brackets)
        remove_space_after_punct = pynini.cdrewrite(
            delete_space, no_space_after_punct, DAMO_SIGMA, DAMO_SIGMA
        ).optimize()
        graph = pynini.compose(graph, remove_space_after_punct).optimize()

        # remove space around text in quotes
        single_quote = pynutil.add_weight(pynini.accep("`"), MIN_NEG_WEIGHT)
        double_quotes = pynutil.add_weight(pynini.accep('"'), MIN_NEG_WEIGHT)
        quotes_graph = (
            single_quote
            + delete_space_optional
            + DAMO_ALPHA
            + DAMO_SIGMA
            + delete_space_optional
            + single_quote
        ).optimize()

        # this is to make sure multiple quotes are tagged from right to left without skipping any quotes in the left
        not_alpha = pynini.difference(DAMO_CHAR, DAMO_ALPHA).optimize() | pynutil.add_weight(
            DAMO_SPACE, MIN_NEG_WEIGHT
        )
        end = pynini.closure(pynutil.add_weight(not_alpha, MIN_NEG_WEIGHT))
        quotes_graph |= (
            double_quotes
            + delete_space_optional
            + DAMO_ALPHA
            + DAMO_SIGMA
            + delete_space_optional
            + double_quotes
            + end
        )

        quotes_graph = pynutil.add_weight(quotes_graph, MIN_NEG_WEIGHT)
        quotes_graph = DAMO_SIGMA + pynini.closure(DAMO_SIGMA + quotes_graph + DAMO_SIGMA)

        graph = pynini.compose(graph, quotes_graph).optimize()

        # remove space between a word and a single quote followed by s
        # (the rule fires only after a letter and before "s " or "s" at the end of the string)
        remove_space_around_single_quote = pynini.cdrewrite(
            delete_space_optional + pynini.union(*self.punct_marks["'"]) + delete_space,
            DAMO_ALPHA,
            pynini.union("s ", "s[EOS]"),
            DAMO_SIGMA,
        )

        graph = pynini.compose(graph, remove_space_around_single_quote).optimize()
        return graph