BienBoy
2025-02-01 c1e365fea09aafda387cac12fdff43d28c598979
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import pynini
from fun_text_processing.text_normalization.en.graph_utils import (
    DAMO_ALPHA,
    DAMO_DIGIT,
    DAMO_NOT_SPACE,
    DAMO_SIGMA,
    GraphFst,
    convert_space,
)
from fun_text_processing.text_normalization.en.utils import get_abs_path, load_labels
from pynini.examples import plurals
from pynini.lib import pynutil
 
 
class SerialFst(GraphFst):
    """
    Finite state transducer for classifying serial numbers, i.e. tokens that mix
    letters, digit runs and whitelisted symbols, optionally separated by
    "-", "/" or " " (for example an input such as ``c325b``).

    The accepted input is verbalized and wrapped as ``name: "<verbalization>"``;
    the exact wording of digit runs depends on the supplied cardinal grammar.

    Args:
        cardinal: cardinal tagger; its ``single_digits_graph``, ``graph`` and
            ``graph_hundred_component_at_least_one_none_zero_digit`` are used
        ordinal: ordinal tagger; every input accepted by ``ordinal.graph`` is
            excluded from the serial options
        deterministic: if True will provide a single transduction option,
            for False multiple transduction are generated (used for audio-based normalization)
        lm: whether to use for hybrid LM
    """

    def __init__(
        self, cardinal: GraphFst, ordinal: GraphFst, deterministic: bool = True, lm: bool = False
    ):
        super().__init__(name="integer", kind="classify", deterministic=deterministic)

        symbol_file = "data/whitelist/symbol.tsv"
        spelled_digits = cardinal.single_digits_graph

        # --- numeric / symbol pieces of a serial ---------------------------------
        # 6+ digits are read digit by digit, 1-5 digits go through the cardinal graph
        long_number = pynini.compose(DAMO_DIGIT ** (6, ...), spelled_digits).optimize()
        short_number = pynini.compose(DAMO_DIGIT ** (1, 5), cardinal.graph).optimize()
        # numbers starting with zero are always read digit by digit
        zero_led_number = pynini.compose(
            pynini.accep("0") + pynini.closure(DAMO_DIGIT), spelled_digits
        ).optimize()
        # TODO: "#" doesn't work from the file
        spoken_symbol = pynini.string_file(get_abs_path(symbol_file)).optimize() | pynini.cross(
            "#", "hash"
        )
        num_graph = long_number | short_number | zero_led_number | spoken_symbol

        # NOTE(review): self.deterministic is presumably set by GraphFst.__init__ - confirm
        if not (self.deterministic or lm):
            # also allow double digits to be pronounced as integer in serial number
            two_digit_integer = pynini.compose(
                DAMO_DIGIT**2, cardinal.graph_hundred_component_at_least_one_none_zero_digit
            )
            num_graph = (
                num_graph
                | spelled_digits
                | pynutil.add_weight(two_digit_integer, weight=0.0001)
            )

        # --- rewrite that adds a space between letter and digit/symbol ----------
        symbol_keys = [row[0] for row in load_labels(get_abs_path(symbol_file))]
        symbols = pynini.union(*symbol_keys)
        digit_or_symbol = DAMO_DIGIT | symbols
        letter_or_symbol = DAMO_ALPHA | symbols
        add_space = pynutil.insert(" ")

        graph_with_space = pynini.compose(
            pynini.cdrewrite(add_space, letter_or_symbol, digit_or_symbol, DAMO_SIGMA),
            pynini.cdrewrite(add_space, digit_or_symbol, letter_or_symbol, DAMO_SIGMA),
        )

        # --- serial graph with delimiter ------------------------------------------
        delimiter = pynini.accep("-") | pynini.accep("/") | pynini.accep(" ")
        if not deterministic:
            spoken_delimiter = pynini.cross("-", " dash ") | pynini.cross("/", " slash ")
            delimiter = delimiter | spoken_delimiter

        alphas = pynini.closure(DAMO_ALPHA, 1)
        delimited_num = delimiter + num_graph

        # keep an existing space between a number and the following letters,
        # otherwise insert one
        space_or_inserted_space = plurals._priority_union(
            pynini.accep(" "), pynutil.insert(" "), DAMO_SIGMA
        ).optimize()
        continuation = pynini.closure(delimiter + (alphas | num_graph)) | pynini.closure(
            delimited_num + space_or_inserted_space + alphas
        )

        letters_first = alphas + delimited_num
        numbers_first = pynini.closure(num_graph + delimiter, 1) + alphas
        serial_graph = (letters_first + continuation) | (numbers_first + continuation)

        # numbers only with 2+ delimiters
        serial_graph |= (
            num_graph + delimited_num + delimited_num + pynini.closure(delimited_num)
        )

        # 2+ symbols: two delimited pieces whose input contains a whitelisted symbol
        has_symbol = DAMO_SIGMA + symbols + DAMO_SIGMA
        serial_graph |= pynini.compose(has_symbol, num_graph + delimited_num)

        # exclude ordinal numbers from serial options
        non_ordinal_input = pynini.difference(DAMO_SIGMA, pynini.project(ordinal.graph, "input"))
        serial_graph = pynini.compose(non_ordinal_input, serial_graph).optimize()
        serial_graph = pynutil.add_weight(serial_graph, 0.0001)

        # "<token>^2" / "<token>^3" are read as "squared" / "cubed"
        power_suffix = (pynini.cross("^2", " squared") | pynini.cross("^3", " cubed")).optimize()
        serial_graph |= pynini.closure(DAMO_NOT_SPACE, 1) + power_suffix

        # at least one serial graph with alpha numeric value and optional additional
        # serial/num/alpha values on either side
        any_piece = serial_graph | num_graph | alphas
        serial_graph = (
            pynini.closure(any_piece + delimiter)
            + serial_graph
            + pynini.closure(delimiter + any_piece)
        )

        # also accept inputs written without spaces between letters and digits/symbols
        serial_graph = serial_graph.optimize()
        spaced_variant = pynini.compose(graph_with_space, serial_graph).optimize()
        serial_graph = serial_graph | spaced_variant

        # only inputs of at least two non-space characters are accepted
        serial_graph = pynini.compose(pynini.closure(DAMO_NOT_SPACE, 2), serial_graph).optimize()

        # this is not to verbalize "/" as "slash" in cases like "import/export"
        word_slash_word = alphas + pynini.accep("/") + alphas
        serial_graph = pynini.compose(
            pynini.difference(DAMO_SIGMA, word_slash_word),
            serial_graph,
        )
        self.graph = serial_graph.optimize()

        tagged = (
            pynutil.insert('name: "') + convert_space(self.graph).optimize() + pynutil.insert('"')
        )
        self.fst = tagged.optimize()