kongdeqiang
7 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
fun_text_processing/text_normalization/en/verbalizers/post_processing.py
@@ -1,6 +1,3 @@
import os
import pynini
@@ -36,7 +33,7 @@
            far_file = os.path.join(cache_dir, "en_tn_post_processing.far")
        if not overwrite_cache and far_file and os.path.exists(far_file):
            self.fst = pynini.Far(far_file, mode="r")["post_process_graph"]
            logging.info(f'Post processing graph was restored from {far_file}.')
            logging.info(f"Post processing graph was restored from {far_file}.")
        else:
            self.set_punct_dict()
            self.fst = self.get_punct_postprocess_graph()
@@ -48,55 +45,55 @@
        self.punct_marks = {
            "'": [
                "'",
                '´',
                'ʹ',
                'ʻ',
                'ʼ',
                'ʽ',
                'ʾ',
                'ˈ',
                'ˊ',
                'ˋ',
                '˴',
                'ʹ',
                '΄',
                '՚',
                '՝',
                'י',
                '׳',
                'ߴ',
                'ߵ',
                'ᑊ',
                'ᛌ',
                '᾽',
                '᾿',
                '`',
                '´',
                '῾',
                '‘',
                '’',
                '‛',
                '′',
                '‵',
                'ꞌ',
                ''',
                '`',
                '𖽑',
                '𖽒',
                "´",
                "ʹ",
                "ʻ",
                "ʼ",
                "ʽ",
                "ʾ",
                "ˈ",
                "ˊ",
                "ˋ",
                "˴",
                "ʹ",
                "΄",
                "՚",
                "՝",
                "י",
                "׳",
                "ߴ",
                "ߵ",
                "ᑊ",
                "ᛌ",
                "᾽",
                "᾿",
                "`",
                "´",
                "῾",
                "‘",
                "’",
                "‛",
                "′",
                "‵",
                "ꞌ",
                "'",
                "`",
                "𖽑",
                "𖽒",
            ],
        }
    def get_punct_postprocess_graph(self):
        """
            Returns graph to post process punctuation marks.
        Returns graph to post process punctuation marks.
            {``} quotes are converted to {"}. Note, if there are spaces around single quote {'}, they will be kept.
            By default, a space is added after a punctuation mark, and spaces are removed before punctuation marks.
        {``} quotes are converted to {"}. Note, if there are spaces around single quote {'}, they will be kept.
        By default, a space is added after a punctuation mark, and spaces are removed before punctuation marks.
        """
        punct_marks_all = PunctuationFst().punct_marks
        # no_space_before_punct assume no space before them
        quotes = ["'", "\"", "``", "«"]
        quotes = ["'", '"', "``", "«"]
        dashes = ["-", "—"]
        brackets = ["<", "{", "("]
        open_close_single_quotes = [
@@ -105,7 +102,9 @@
        open_close_double_quotes = [('"', '"'), ("``", "``"), ("“", "”")]
        open_close_symbols = open_close_single_quotes + open_close_double_quotes
        allow_space_before_punct = ["&"] + quotes + dashes + brackets + [k[0] for k in open_close_symbols]
        allow_space_before_punct = (
            ["&"] + quotes + dashes + brackets + [k[0] for k in open_close_symbols]
        )
        no_space_before_punct = [m for m in punct_marks_all if m not in allow_space_before_punct]
        no_space_before_punct = pynini.union(*no_space_before_punct)
@@ -119,7 +118,8 @@
        graph = (
            pynini.closure(non_punct)
            + pynini.closure(
                no_space_before_punct | pynutil.add_weight(delete_space + no_space_before_punct, MIN_NEG_WEIGHT)
                no_space_before_punct
                | pynutil.add_weight(delete_space + no_space_before_punct, MIN_NEG_WEIGHT)
            )
            + pynini.closure(non_punct)
        )
@@ -129,14 +129,21 @@
        ).optimize()
        # remove space after no_space_after_punct (even if there are no matching closing brackets)
        no_space_after_punct = pynini.cdrewrite(delete_space, no_space_after_punct, DAMO_SIGMA, DAMO_SIGMA).optimize()
        no_space_after_punct = pynini.cdrewrite(
            delete_space, no_space_after_punct, DAMO_SIGMA, DAMO_SIGMA
        ).optimize()
        graph = pynini.compose(graph, no_space_after_punct).optimize()
        # remove space around text in quotes
        single_quote = pynutil.add_weight(pynini.accep("`"), MIN_NEG_WEIGHT)
        double_quotes = pynutil.add_weight(pynini.accep('"'), MIN_NEG_WEIGHT)
        quotes_graph = (
            single_quote + delete_space_optional + DAMO_ALPHA + DAMO_SIGMA + delete_space_optional + single_quote
            single_quote
            + delete_space_optional
            + DAMO_ALPHA
            + DAMO_SIGMA
            + delete_space_optional
            + single_quote
        ).optimize()
        # this is to make sure multiple quotes are tagged from right to left without skipping any quotes in the left