import itertools
import os
import re
from argparse import ArgumentParser
from collections import OrderedDict
from math import factorial
from time import perf_counter
from typing import Dict, List, Union
 
import pynini
import regex
from joblib import Parallel, delayed
from fun_text_processing.text_normalization.data_loader_utils import (
    load_file,
    post_process_punct,
    pre_process,
    write_file,
)
from fun_text_processing.text_normalization.token_parser import PRESERVE_ORDER_KEY, TokenParser
from pynini.lib.rewrite import top_rewrite
from tqdm import tqdm
 
try:
    from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor
 
    NLP_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    NLP_AVAILABLE = False
 
SPACE_DUP = re.compile(" {2,}")
 
 
class Normalizer:
    """
    Normalizer class that converts text from written to spoken form.
    Useful for TTS preprocessing.
 
    Args:
        input_case: expected input capitalization
        lang: language specifying the TN rules, by default: English
        deterministic: if True, uses the deterministic tagger that provides a single normalization option
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
        lm: set to True to use the LM-based English tagger (only applies when deterministic=False)
        post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
            Note: punct_post_process flag in normalize() supports all languages.
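
    Example (an illustrative sketch; assumes the English grammar files can be built or loaded from cache)::

        normalizer = Normalizer(input_case="cased", lang="en")
        normalizer.normalize("It weighs 12 kg.")  # expected to yield roughly "It weighs twelve kilograms."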
    """
 
    def __init__(
        self,
        input_case: str,
        lang: str = "en",
        deterministic: bool = True,
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
        lm: bool = False,
        post_process: bool = True,
    ):
        assert input_case in ["lower_cased", "cased"]
 
        self.post_processor = None
 
        if lang == "en":
            from fun_text_processing.text_normalization.en.verbalizers.verbalize_final import (
                VerbalizeFinalFst,
            )
            from fun_text_processing.text_normalization.en.verbalizers.post_processing import (
                PostProcessingFst,
            )
 
            if post_process:
                self.post_processor = PostProcessingFst(
                    cache_dir=cache_dir, overwrite_cache=overwrite_cache
                )
 
            if deterministic:
                from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify import (
                    ClassifyFst,
                )
            else:
                if lm:
                    from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify_lm import (
                        ClassifyFst,
                    )
                else:
                    from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify_with_audio import (
                        ClassifyFst,
                    )
 
        elif lang == "ru":
            # Ru TN only supports non-deterministic cases and produces multiple normalization options;
            # use normalize_with_audio.py
            from fun_text_processing.text_normalization.ru.taggers.tokenize_and_classify import (
                ClassifyFst,
            )
            from fun_text_processing.text_normalization.ru.verbalizers.verbalize_final import (
                VerbalizeFinalFst,
            )
        elif lang == "de":
            from fun_text_processing.text_normalization.de.taggers.tokenize_and_classify import (
                ClassifyFst,
            )
            from fun_text_processing.text_normalization.de.verbalizers.verbalize_final import (
                VerbalizeFinalFst,
            )
        elif lang == "es":
            from fun_text_processing.text_normalization.es.taggers.tokenize_and_classify import (
                ClassifyFst,
            )
            from fun_text_processing.text_normalization.es.verbalizers.verbalize_final import (
                VerbalizeFinalFst,
            )
        elif lang == "zh":
            from fun_text_processing.text_normalization.zh.taggers.tokenize_and_classify import (
                ClassifyFst,
            )
            from fun_text_processing.text_normalization.zh.verbalizers.verbalize_final import (
                VerbalizeFinalFst,
            )
        else:
            raise NotImplementedError(f"Language '{lang}' is not supported for text normalization")
        self.tagger = ClassifyFst(
            input_case=input_case,
            deterministic=deterministic,
            cache_dir=cache_dir,
            overwrite_cache=overwrite_cache,
            whitelist=whitelist,
        )
 
        self.verbalizer = VerbalizeFinalFst(
            deterministic=deterministic, cache_dir=cache_dir, overwrite_cache=overwrite_cache
        )
 
        self.parser = TokenParser()
        self.lang = lang
 
        if NLP_AVAILABLE:
            self.processor = MosesProcessor(lang_id=lang)
        else:
            self.processor = None
            print("NeMo NLP is not available. Moses de-tokenization will be skipped.")
 
    def normalize_list(
        self,
        texts: List[str],
        verbose: bool = False,
        punct_pre_process: bool = False,
        punct_post_process: bool = False,
        batch_size: int = 1,
        n_jobs: int = 1,
    ):
        """
        NeMo text normalizer
 
        Args:
            texts: list of input strings
            verbose: whether to print intermediate meta information
            punct_pre_process: whether to do punctuation pre processing
            punct_post_process: whether to do punctuation post processing
            n_jobs: the maximum number of concurrently running jobs. If -1 all CPUs are used. If 1 is given,
                no parallel computing code is used at all, which is useful for debugging. For n_jobs below -1,
                (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one are used.
            batch_size: Number of examples for each process
 
        Returns converted list of input strings
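
        Example (illustrative; ``lines`` stands for any list of input strings)::

            normalized = normalizer.normalize_list(lines, n_jobs=4, batch_size=100)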
        """
 
        # cap the per-process batch size at the number of input texts
        batch = min(len(texts), batch_size)
 
        normalized_texts = Parallel(n_jobs=n_jobs)(
            delayed(self.process_batch)(
                texts[i : i + batch], verbose, punct_pre_process, punct_post_process
            )
            for i in range(0, len(texts), batch)
        )
 
        normalized_texts = list(itertools.chain(*normalized_texts))
        return normalized_texts
 
    def process_batch(self, batch, verbose, punct_pre_process, punct_post_process):
        """
        Normalizes batch of text sequences
        Args:
            batch: list of texts
            verbose: whether to print intermediate meta information
            punct_pre_process: whether to do punctuation pre processing
            punct_post_process: whether to do punctuation post processing
        """
        normalized_lines = [
            self.normalize(
                text,
                verbose=verbose,
                punct_pre_process=punct_pre_process,
                punct_post_process=punct_post_process,
            )
            for text in tqdm(batch)
        ]
        return normalized_lines
 
    def _estimate_number_of_permutations_in_nested_dict(
        self, token_group: Dict[str, Union[OrderedDict, str, bool]]
    ) -> int:
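        """
        Estimates the number of permutations a token group can generate: the product of the estimates
        of all nested dictionaries, multiplied by the factorial of the number of keys in this group.
        """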
        num_perms = 1
        for k, inner in token_group.items():
            if isinstance(inner, dict):
                num_perms *= self._estimate_number_of_permutations_in_nested_dict(inner)
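        # the keys of this group can themselves be reordered freely, hence factorial(len(token_group))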
        num_perms *= factorial(len(token_group))
        return num_perms
 
    def _split_tokens_to_reduce_number_of_permutations(
        self, tokens: List[dict], max_number_of_permutations_per_split: int = 729
    ) -> List[List[dict]]:
        """
        Splits a sequence of tokens into smaller sequences so that the maximum number of composite
        token permutations in each split does not exceed ``max_number_of_permutations_per_split``.
 
        For example,
 
        .. code-block:: python
            tokens = [
                {"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}},
                {"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}},
            ]
            split = normalizer._split_tokens_to_reduce_number_of_permutations(
                tokens, max_number_of_permutations_per_split=6
            )
            assert split == [
                [{"tokens": {"date": {"year": "twenty eighteen", "month": "december", "day": "thirty one"}}}],
                [{"tokens": {"date": {"year": "twenty eighteen", "month": "january", "day": "eight"}}}],
            ]
 
        Date tokens contain 3 items each which gives 6 permutations for every date. Since there are 2 dates, total
        number of permutations would be ``6 * 6 == 36``. Parameter ``max_number_of_permutations_per_split`` equals 6,
        so input sequence of tokens is split into 2 smaller sequences.
 
        Args:
            tokens (:obj:`List[dict]`): a list of dictionaries, possibly nested.
            max_number_of_permutations_per_split (:obj:`int`, `optional`, defaults to :obj:`729`): the maximum
                number of permutations which can be generated from one split of the input sequence of tokens.
 
        Returns:
            :obj:`List[List[dict]]`: a list of smaller sequences of tokens resulting from ``tokens`` split.
        """
        splits = []
        prev_end_of_split = 0
        current_number_of_permutations = 1
        for i, token_group in enumerate(tokens):
            n = self._estimate_number_of_permutations_in_nested_dict(token_group)
            if n * current_number_of_permutations > max_number_of_permutations_per_split:
                splits.append(tokens[prev_end_of_split:i])
                prev_end_of_split = i
                current_number_of_permutations = 1
            if n > max_number_of_permutations_per_split:
                raise ValueError(
                    f"Could not split token list with respect to condition that every split can generate number of "
                    f"permutations less or equal to "
                    f"`max_number_of_permutations_per_split={max_number_of_permutations_per_split}`. "
                    f"There is an unsplittable token group that generates more than "
                    f"{max_number_of_permutations_per_split} permutations. Try to increase "
                    f"`max_number_of_permutations_per_split` parameter."
                )
            current_number_of_permutations *= n
        splits.append(tokens[prev_end_of_split:])
        assert sum([len(s) for s in splits]) == len(tokens)
        return splits
 
    def normalize(
        self,
        text: str,
        verbose: bool = False,
        punct_pre_process: bool = False,
        punct_post_process: bool = False,
    ) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms
 
        Args:
            text: string that may include semiotic classes
            verbose: whether to print intermediate meta information
            punct_pre_process: whether to perform punctuation pre-processing, for example, [25] -> [ 25 ]
            punct_post_process: whether to normalize punctuation
 
        Returns: spoken form
        """
        if len(text.split()) > 500:
            print(
                "WARNING! Your input is too long and could take a long time to normalize."
                "Use split_text_into_sentences() to make the input shorter and then call normalize_list()."
            )
 
        original_text = text
        if punct_pre_process:
            text = pre_process(text)
        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text
        text = pynini.escape(text)
        tagged_lattice = self.find_tags(text)
        tagged_text = self.select_tag(tagged_lattice)
        if verbose:
            print(tagged_text)
        self.parser(tagged_text)
        tokens = self.parser.parse()
        split_tokens = self._split_tokens_to_reduce_number_of_permutations(tokens)
        output = ""
        for s in split_tokens:
            tags_reordered = self.generate_permutations(s)
            verbalizer_lattice = None
            for tagged_text in tags_reordered:
                tagged_text = pynini.escape(tagged_text)
 
                verbalizer_lattice = self.find_verbalizer(tagged_text)
                if verbalizer_lattice.num_states() != 0:
                    break
            if verbalizer_lattice is None:
                raise ValueError(f"No permutations were generated from tokens {s}")
            output += " " + self.select_verbalizer(verbalizer_lattice)
        output = SPACE_DUP.sub(" ", output[1:])
 
        if self.lang == "en" and hasattr(self, "post_processor"):
            output = self.post_process(output)
 
        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                output = self.processor.moses_detokenizer.detokenize([output], unescape=False)
                output = post_process_punct(input=original_text, normalized_text=output)
            else:
                print("DAMO_NLP collection is not available: skipping punctuation post_processing")
 
        return output
 
    def split_text_into_sentences(self, text: str) -> List[str]:
        """
        Split text into sentences.
 
        Args:
            text: input text
 
        Returns list of sentences
        """
        lower_case_unicode = ""
        upper_case_unicode = ""
        if self.lang == "ru":
            lower_case_unicode = "\u0430-\u04FF"
            upper_case_unicode = "\u0410-\u042F"
 
        # Read and split transcript by utterance (roughly, sentences)
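        # The pattern splits on whitespace that follows sentence-final punctuation (".", "?", "!",
        # optionally followed by a closing quote). The negative lookbehinds avoid splitting after
        # abbreviations such as "e.g.", single-letter initials, and capitalized words ending in a
        # period (e.g. "Mr."), and the negative lookahead avoids splitting before enumerations like "1.".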
        split_pattern = rf"(?<!\w\.\w.)(?<![A-Z{upper_case_unicode}][a-z{lower_case_unicode}]+\.)(?<![A-Z{upper_case_unicode}]\.)(?<=\.|\?|\!|\.”|\?”|\!”)\s(?![0-9]+[a-z]*\.)"
 
        sentences = regex.split(split_pattern, text)
        return sentences
 
    def _permute(self, d: OrderedDict) -> List[str]:
        """
        Creates reorderings of dictionary elements and serializes as strings
 
        Args:
            d: (nested) dictionary of key value pairs
 
        Return permutations of different string serializations of key value pairs
        """
        l = []
        if PRESERVE_ORDER_KEY in d.keys():
            d_permutations = [d.items()]
        else:
            d_permutations = itertools.permutations(d.items())
        for perm in d_permutations:
            subl = [""]
            for k, v in perm:
                if isinstance(v, str):
                    subl = ["".join(x) for x in itertools.product(subl, [f'{k}: "{v}" '])]
                elif isinstance(v, OrderedDict):
                    rec = self._permute(v)
                    subl = [
                        "".join(x) for x in itertools.product(subl, [f" {k} {{ "], rec, [f" }} "])
                    ]
                elif isinstance(v, bool):
                    subl = ["".join(x) for x in itertools.product(subl, [f"{k}: true "])]
                else:
                    raise ValueError(f"Unsupported value type {type(v)} for key '{k}'")
            l.extend(subl)
        return l
 
    def generate_permutations(self, tokens: List[dict]):
        """
        Generates permutations of string serializations of list of dictionaries
 
        Args:
            tokens: list of dictionaries
 
        Returns string serialization of list of dictionaries
        """
 
        def _helper(prefix: str, tokens: List[dict], idx: int):
            """
            Generates permutations of string serializations of given dictionary
 
            Args:
                tokens: list of dictionaries
                prefix: prefix string
                idx:    index of next dictionary
 
            Returns string serialization of dictionary
            """
            if idx == len(tokens):
                yield prefix
                return
            token_options = self._permute(tokens[idx])
            for token_option in token_options:
                yield from _helper(prefix + token_option, tokens, idx + 1)
 
        return _helper("", tokens, 0)
 
    def find_tags(self, text: str) -> "pynini.FstLike":
        """
        Given text, uses the tagger FST to tag the text
 
        Args:
            text: sentence
 
        Returns: tagged lattice
        """
        lattice = text @ self.tagger.fst
        return lattice
 
    def select_tag(self, lattice: "pynini.FstLike") -> str:
        """
        Given tagged lattice return shortest path
 
        Args:
            lattice: tagged lattice
 
        Returns: shortest path
        """
        tagged_text = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
        return tagged_text
 
    def find_verbalizer(self, tagged_text: str) -> "pynini.FstLike":
        """
        Given tagged text, creates a verbalization lattice.
        This is context-independent.
 
        Args:
            tagged_text: input text
 
        Returns: verbalized lattice
        """
        lattice = tagged_text @ self.verbalizer.fst
        return lattice
 
    def select_verbalizer(self, lattice: "pynini.FstLike") -> str:
        """
        Given verbalized lattice return shortest path
 
        Args:
            lattice: verbalization lattice
 
        Returns: shortest path
        """
        output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
        # lattice = output @ self.verbalizer.punct_graph
        # output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
        return output
 
    def post_process(self, normalized_text: "pynini.FstLike") -> str:
        """
        Runs post processing graph on normalized text
 
        Args:
            normalized_text: normalized text
 
        Returns: shortest path
        """
        normalized_text = normalized_text.strip()
        if not normalized_text:
            return normalized_text
        normalized_text = pynini.escape(normalized_text)
 
        if self.post_processor is not None:
            normalized_text = top_rewrite(normalized_text, self.post_processor.fst)
        return normalized_text
 
 
def parse_args():
    parser = ArgumentParser()
    input = parser.add_mutually_exclusive_group()
    input.add_argument("--text", dest="input_string", help="input string", type=str)
    input.add_argument("--input_file", dest="input_file", help="input file path", type=str)
    parser.add_argument("--output_file", dest="output_file", help="output file path", type=str)
    parser.add_argument(
        "--language", help="language", choices=["en", "de", "es", "zh"], default="en", type=str
    )
    parser.add_argument(
        "--input_case",
        help="input capitalization",
        choices=["lower_cased", "cased"],
        default="cased",
        type=str,
    )
    parser.add_argument("--verbose", help="print info for debugging", action="store_true")
    parser.add_argument(
        "--punct_post_process",
        help="set to True to enable punctuation post processing to match input.",
        action="store_true",
    )
    parser.add_argument(
        "--punct_pre_process",
        help="set to True to enable punctuation pre processing",
        action="store_true",
    )
    parser.add_argument(
        "--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true"
    )
    parser.add_argument(
        "--whitelist", help="path to a file with with whitelist", default=None, type=str
    )
    parser.add_argument(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )
    return parser.parse_args()
 
 
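# Example invocation (script and file names are illustrative):
#   python normalize.py --input_file sentences.txt --output_file normalized.txt --language en --input_case cased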
if __name__ == "__main__":
    start_time = perf_counter()
 
    args = parse_args()
    whitelist = os.path.abspath(args.whitelist) if args.whitelist else None
 
    if not args.input_string and not args.input_file:
        raise ValueError("Either `--text` or `--input_file` required")
 
    normalizer = Normalizer(
        input_case=args.input_case,
        cache_dir=args.cache_dir,
        overwrite_cache=args.overwrite_cache,
        whitelist=whitelist,
        lang=args.language,
    )
    if args.input_string:
        print(
            normalizer.normalize(
                args.input_string,
                verbose=args.verbose,
                punct_pre_process=args.punct_pre_process,
                punct_post_process=args.punct_post_process,
            )
        )
    elif args.input_file:
        print("Loading data: " + args.input_file)
        data = load_file(args.input_file)
 
        print("- Data: " + str(len(data)) + " sentences")
        normalizer_prediction = normalizer.normalize_list(
            data,
            verbose=args.verbose,
            punct_pre_process=args.punct_pre_process,
            punct_post_process=args.punct_post_process,
        )
        if args.output_file:
            write_file(args.output_file, normalizer_prediction)
            print(f"- Normalized. Writing out to {args.output_file}")
        else:
            print(normalizer_prediction)
 
    print(f"Execution time: {perf_counter() - start_time:.02f} sec")