slin000111
2025-09-20 88b464edb94531ba98744a5338965af539ab18ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
import json
import re
import string
from collections import defaultdict, namedtuple
from typing import Dict, List, Optional, Set, Tuple
from unicodedata import category
 
import logging
 
 
EOS_TYPE = "EOS"
PUNCT_TYPE = "PUNCT"
PLAIN_TYPE = "PLAIN"
Instance = namedtuple("Instance", "token_type un_normalized normalized")
known_types = [
    "PLAIN",
    "DATE",
    "CARDINAL",
    "LETTERS",
    "VERBATIM",
    "MEASURE",
    "DECIMAL",
    "ORDINAL",
    "DIGIT",
    "MONEY",
    "TELEPHONE",
    "ELECTRONIC",
    "FRACTION",
    "TIME",
    "ADDRESS",
]
 
 
def _load_kaggle_text_norm_file(file_path: str) -> List[Instance]:
    """
    https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
    Loads text file in the Kaggle Google text normalization file format: <semiotic class>\t<unnormalized text>\t<`self` if trivial class or normalized text>
    E.g.
    PLAIN   Brillantaisia   <self>
    PLAIN   is      <self>
    PLAIN   a       <self>
    PLAIN   genus   <self>
    PLAIN   of      <self>
    PLAIN   plant   <self>
    PLAIN   in      <self>
    PLAIN   family  <self>
    PLAIN   Acanthaceae     <self>
    PUNCT   .       sil
    <eos>   <eos>
 
    Args:
        file_path: file path to text file
 
    Returns: flat list of instances
    """
    res = []
    with open(file_path, "r") as fp:
        for line in fp:
            parts = line.strip().split("\t")
            if parts[0] == "<eos>":
                res.append(Instance(token_type=EOS_TYPE, un_normalized="", normalized=""))
            else:
                l_type, l_token, l_normalized = parts
                l_token = l_token.lower()
                l_normalized = l_normalized.lower()
 
                if l_type == PLAIN_TYPE:
                    res.append(
                        Instance(token_type=l_type, un_normalized=l_token, normalized=l_token)
                    )
                elif l_type != PUNCT_TYPE:
                    res.append(
                        Instance(token_type=l_type, un_normalized=l_token, normalized=l_normalized)
                    )
    return res
 
 
def load_files(file_paths: List[str], load_func=_load_kaggle_text_norm_file) -> List[Instance]:
    """
    Load given list of text files using the `load_func` function.
 
    Args:
        file_paths: list of file paths
        load_func: loading function
 
    Returns: flat list of instances
    """
    res = []
    for file_path in file_paths:
        res.extend(load_func(file_path=file_path))
    return res
 
 
def clean_generic(text: str) -> str:
    """
    Cleans text without affecting semiotic classes.
 
    Args:
        text: string
 
    Returns: cleaned string
    """
    text = text.strip()
    text = text.lower()
    return text
 
 
def evaluate(
    preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True
) -> float:
    """
    Evaluates accuracy given predictions and labels.
 
    Args:
        preds: predictions
        labels: labels
        input: optional, only needed for verbosity
        verbose: if true prints [input], golden labels and predictions
 
    Returns accuracy
    """
    acc = 0
    nums = len(preds)
    for i in range(nums):
        pred_norm = clean_generic(preds[i])
        label_norm = clean_generic(labels[i])
        if pred_norm == label_norm:
            acc = acc + 1
        else:
            if input:
                print(f"inpu: {json.dumps(input[i])}")
            print(f"gold: {json.dumps(label_norm)}")
            print(f"pred: {json.dumps(pred_norm)}")
    return acc / nums
 
 
def training_data_to_tokens(
    data: List[Instance], category: Optional[str] = None
) -> Dict[str, Tuple[List[str], List[str]]]:
    """
    Filters the instance list by category if provided and converts it into a map from token type to list of un_normalized and normalized strings
 
    Args:
        data: list of instances
        category: optional semiotic class category name
 
    Returns Dict: token type -> (list of un_normalized strings, list of normalized strings)
    """
    result = defaultdict(lambda: ([], []))
    for instance in data:
        if instance.token_type != EOS_TYPE:
            if category is None or instance.token_type == category:
                result[instance.token_type][0].append(instance.un_normalized)
                result[instance.token_type][1].append(instance.normalized)
    return result
 
 
def training_data_to_sentences(data: List[Instance]) -> Tuple[List[str], List[str], List[Set[str]]]:
    """
    Takes instance list, creates list of sentences split by EOS_Token
    Args:
        data: list of instances
    Returns (list of unnormalized sentences, list of normalized sentences, list of sets of categories in a sentence)
    """
    # split data at EOS boundaries
    sentences = []
    sentence = []
    categories = []
    sentence_categories = set()
 
    for instance in data:
        if instance.token_type == EOS_TYPE:
            sentences.append(sentence)
            sentence = []
            categories.append(sentence_categories)
            sentence_categories = set()
        else:
            sentence.append(instance)
            sentence_categories.update([instance.token_type])
    un_normalized = [
        " ".join([instance.un_normalized for instance in sentence]) for sentence in sentences
    ]
    normalized = [
        " ".join([instance.normalized for instance in sentence]) for sentence in sentences
    ]
    return un_normalized, normalized, categories
 
 
def post_process_punctuation(text: str) -> str:
    """
    Normalized quotes and spaces
 
    Args:
        text: text
 
    Returns: text with normalized spaces and quotes
    """
    text = (
        text.replace("( ", "(")
        .replace(" )", ")")
        .replace("{ ", "{")
        .replace(" }", "}")
        .replace("[ ", "[")
        .replace(" ]", "]")
        .replace("  ", " ")
        .replace("”", '"')
        .replace("’", "'")
        .replace("»", '"')
        .replace("«", '"')
        .replace("\\", "")
        .replace("„", '"')
        .replace("´", "'")
        .replace("’", "'")
        .replace("“", '"')
        .replace("‘", "'")
        .replace("`", "'")
        .replace("- -", "--")
    )
 
    for punct in "!,.:;?":
        text = text.replace(f" {punct}", punct)
    return text.strip()
 
 
def pre_process(text: str) -> str:
    """
    Optional text preprocessing before normalization (part of TTS TN pipeline)
 
    Args:
        text: string that may include semiotic classes
 
    Returns: text with spaces around punctuation marks
    """
    space_both = "[]"
    for punct in space_both:
        text = text.replace(punct, " " + punct + " ")
 
    # remove extra space
    text = re.sub(r" +", " ", text)
    return text
 
 
def load_file(file_path: str) -> List[str]:
    """
    Loads given text file with separate lines into list of string.
 
    Args:
        file_path: file path
 
    Returns: flat list of string
    """
    res = []
    with open(file_path, "r") as fp:
        for line in fp:
            res.append(line)
    return res
 
 
def write_file(file_path: str, data: List[str]):
    """
    Writes out list of string to file.
 
    Args:
        file_path: file path
        data: list of string
 
    """
    with open(file_path, "w") as fp:
        for line in data:
            fp.write(line + "\n")
 
 
def post_process_punct(input: str, normalized_text: str, add_unicode_punct: bool = False):
    """
    Post-processing of the normalized output to match input in terms of spaces around punctuation marks.
    After NN normalization, Moses detokenization puts a space after
    punctuation marks, and attaches an opening quote "'" to the word to the right.
    E.g., input to the TN NN model is "12 test' example",
    after normalization and detokenization -> "twelve test 'example" (the quote is considered to be an opening quote,
    but it doesn't match the input and can cause issues during TTS voice generation.)
    The current function will match the punctuation and spaces of the normalized text with the input sequence.
    "12 test' example" -> "twelve test 'example" -> "twelve test' example" (the quote was shifted to match the input).
 
    Args:
        input: input text (original input to the NN, before normalization or tokenization)
        normalized_text: output text (output of the TN NN model)
        add_unicode_punct: set to True to handle unicode punctuation marks as well as default string.punctuation (increases post processing time)
    """
    # in the post-processing WFST graph "``" are repalced with '"" quotes (otherwise single quotes "`" won't be handled correctly)
    # this function fixes spaces around them based on input sequence, so here we're making the same double quote replacement
    # to make sure these new double quotes work with this function
    if "``" in input and "``" not in normalized_text:
        input = input.replace("``", '"')
    input = [x for x in input]
    normalized_text = [x for x in normalized_text]
    punct_marks = [x for x in string.punctuation if x in input]
 
    if add_unicode_punct:
        punct_unicode = [
            chr(i)
            for i in range(sys.maxunicode)
            if category(chr(i)).startswith("P") and chr(i) not in punct_default and chr(i) in input
        ]
        punct_marks = punct_marks.extend(punct_unicode)
 
    for punct in punct_marks:
        try:
            equal = True
            if input.count(punct) != normalized_text.count(punct):
                equal = False
            idx_in, idx_out = 0, 0
            while punct in input[idx_in:]:
                idx_out = normalized_text.index(punct, idx_out)
                idx_in = input.index(punct, idx_in)
 
                def _is_valid(idx_out, idx_in, normalized_text, input):
                    """Check if previous or next word match (for cases when punctuation marks are part of
                    semiotic token, i.e. some punctuation can be missing in the normalized text)"""
                    return (
                        idx_out > 0
                        and idx_in > 0
                        and normalized_text[idx_out - 1] == input[idx_in - 1]
                    ) or (
                        idx_out < len(normalized_text) - 1
                        and idx_in < len(input) - 1
                        and normalized_text[idx_out + 1] == input[idx_in + 1]
                    )
 
                if not equal and not _is_valid(idx_out, idx_in, normalized_text, input):
                    idx_in += 1
                    continue
                if idx_in > 0 and idx_out > 0:
                    if normalized_text[idx_out - 1] == " " and input[idx_in - 1] != " ":
                        normalized_text[idx_out - 1] = ""
 
                    elif normalized_text[idx_out - 1] != " " and input[idx_in - 1] == " ":
                        normalized_text[idx_out - 1] += " "
 
                if idx_in < len(input) - 1 and idx_out < len(normalized_text) - 1:
                    if normalized_text[idx_out + 1] == " " and input[idx_in + 1] != " ":
                        normalized_text[idx_out + 1] = ""
                    elif normalized_text[idx_out + 1] != " " and input[idx_in + 1] == " ":
                        normalized_text[idx_out] = normalized_text[idx_out] + " "
                idx_out += 1
                idx_in += 1
        except:
            logging.debug(f"Skipping post-processing of {''.join(normalized_text)} for '{punct}'")
 
    normalized_text = "".join(normalized_text)
    return re.sub(r" +", " ", normalized_text)