kongdeqiang
6 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import pynini
from fun_text_processing.text_normalization.en.graph_utils import (
    DAMO_DIGIT,
    GraphFst,
    convert_space,
)
from pynini.lib import pynutil
 
 
class RangeFst(GraphFst):
    """
    Classifier for ranges and small arithmetic expressions, built by composing the
    grammars of three other class instances (time, date and cardinal).

    Patterns that are always enabled (separator rewrites shown on the right):
        <time>-<time>            "-" -> " to "
        ~<time> / ~<number>      "~" -> "approximately"
        <year4>-<year>           "-" -> " to "; the right side is a 4- or 2-digit year
                                 or a two-digit number read by the cardinal grammar
        mid-<year>               "-" -> " "
        <number>+<number>...     "+" or " + " -> " plus " (one or more repetitions)
        <number>...<number>      "..." -> " ... " (" ... " is accepted unchanged)
    Years are digits with an optional trailing "s", passed through the date grammar.
    Around "-" (and ":" below) at most one space per side is accepted and removed.

    Extra patterns, enabled when ``deterministic`` is False or ``lm`` is True:
        <number>-<number>        "-" -> " to " or " minus "
        <number>:<number>        ":" -> " to "
        <number>x<number>...     "x" or " x " -> " by " or " times "
        <number>*<number>...     "*" or " * " -> " times "
        No. <number>             "NO"/"No" -> "Number", "no" -> "number"
        <number>/<number>...     "/" or " / " -> " divided by "

    Args:
        time: composed tagger and verbalizer
        date: composed tagger and verbalizer
        cardinal: tagger; only its ``graph_with_and`` attribute is used
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: whether to use for hybrid LM

    After construction:
        graph: optimized union of all the patterns above
        fst: ``graph`` passed through ``convert_space`` and wrapped as ``name: "..."``
    """

    def __init__(
        self,
        time: GraphFst,
        date: GraphFst,
        cardinal: GraphFst,
        deterministic: bool = True,
        lm: bool = False,
    ):
        super().__init__(name="range", kind="classify", deterministic=deterministic)

        # At most one space is accepted (and deleted) on each side of a separator.
        optional_space = pynini.closure(pynutil.delete(" "), 0, 1)
        tilde = pynini.cross("~", "approximately")
        number = cardinal.graph_with_and

        def spaced(left, separator, right):
            """Join left/separator/right, allowing one deleted space on each side."""
            return left + optional_space + separator + optional_space + right

        def repeated(separator):
            """A number followed by one or more (separator, number) pairs."""
            return number + pynini.closure(separator + number, 1)

        def year(digits):
            """Exactly `digits` digits plus an optional "s", read by the date grammar."""
            return (DAMO_DIGIT**digits + pynini.closure(pynini.accep("s"), 0, 1)) @ date

        year4, year2 = year(4), year(2)

        # TIME
        result = spaced(time, pynini.cross("-", " to "), time) | (tilde + time)

        # YEAR
        result |= spaced(
            year4,
            pynini.cross("-", " to "),
            year4 | year2 | (DAMO_DIGIT**2 @ number),
        )
        result |= pynini.accep("mid") + pynini.cross("-", " ") + (year4 | year2)

        # ADDITION and the other always-on numeric patterns
        numeric = repeated(pynini.cross("+", " plus "))
        numeric |= repeated(pynini.cross(" + ", " plus "))
        numeric |= tilde + number
        numeric |= number + (pynini.cross("...", " ... ") | pynini.accep(" ... ")) + number

        if lm or not deterministic:
            # number-number ("to"/"minus") together with number:number ("to")
            numeric |= spaced(
                number, pynini.cross("-", pynini.union(" to ", " minus ")), number
            ) | spaced(number, pynini.cross(":", " to "), number)

            # MULTIPLY
            for sign in (" x ", "x"):
                numeric |= repeated(pynini.cross(sign, pynini.union(" by ", " times ")))
            for sign in ("*", " * "):
                numeric |= repeated(pynini.cross(sign, " times "))

            # supports "No. 12" -> "Number 12"
            number_word = pynini.cross(pynini.union("NO", "No"), "Number") | pynini.cross(
                "no", "number"
            )
            numeric |= number_word + pynini.closure(pynini.union(". ", " "), 0, 1) + number

            # DIVIDE
            for sign in ("/", " / "):
                numeric |= repeated(pynini.cross(sign, " divided by "))

        result |= numeric

        # Fst.optimize() works in place and returns the same object.
        self.graph = result.optimize()
        self.fst = (
            pynutil.insert('name: "') + convert_space(self.graph).optimize() + pynutil.insert('"')
        ).optimize()