Legend: '-' removed line, '+' added line, ' ' unchanged context, '@@' hunk header
2024-12-15 5e7a8d1ccae80e54f2e2ecfffdf8e4294800b5c3
--- a/fun_text_processing/text_normalization/normalize.py
+++ b/fun_text_processing/text_normalization/normalize.py
@@ -1,5 +1,3 @@
 import itertools
 import os
 import re
@@ -29,7 +27,7 @@
 except (ModuleNotFoundError, ImportError) as e:
     NLP_AVAILABLE = False
-SPACE_DUP = re.compile(' {2,}')
+SPACE_DUP = re.compile(" {2,}")
 class Normalizer:
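Aside from the quote-style change, the regex is untouched: it matches runs of two or more spaces so the normalizer can collapse them later. In isolation:

    import re

    SPACE_DUP = re.compile(" {2,}")
    print(SPACE_DUP.sub(" ", "ten  thirty   a m"))  # -> "ten thirty a m"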
@@ -50,7 +48,7 @@
     def __init__(
         self,
         input_case: str,
-        lang: str = 'en',
+        lang: str = "en",
         deterministic: bool = True,
         cache_dir: str = None,
         overwrite_cache: bool = False,
@@ -63,36 +61,62 @@
         self.post_processor = None
         if lang == "en":
-            from fun_text_processing.text_normalization.en.verbalizers.verbalize_final import VerbalizeFinalFst
-            from fun_text_processing.text_normalization.en.verbalizers.post_processing import PostProcessingFst
+            from fun_text_processing.text_normalization.en.verbalizers.verbalize_final import (
+                VerbalizeFinalFst,
+            )
+            from fun_text_processing.text_normalization.en.verbalizers.post_processing import (
+                PostProcessingFst,
+            )
             if post_process:
-                self.post_processor = PostProcessingFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache)
+                self.post_processor = PostProcessingFst(
+                    cache_dir=cache_dir, overwrite_cache=overwrite_cache
+                )
             if deterministic:
-                from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
+                from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify import (
+                    ClassifyFst,
+                )
             else:
                 if lm:
-                    from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify_lm import ClassifyFst
+                    from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify_lm import (
+                        ClassifyFst,
+                    )
                 else:
                     from fun_text_processing.text_normalization.en.taggers.tokenize_and_classify_with_audio import (
                         ClassifyFst,
                     )
-        elif lang == 'ru':
+        elif lang == "ru":
             # Ru TN only support non-deterministic cases and produces multiple normalization options
             # use normalize_with_audio.py
-            from fun_text_processing.text_normalization.ru.taggers.tokenize_and_classify import ClassifyFst
-            from fun_text_processing.text_normalization.ru.verbalizers.verbalize_final import VerbalizeFinalFst
-        elif lang == 'de':
-            from fun_text_processing.text_normalization.de.taggers.tokenize_and_classify import ClassifyFst
-            from fun_text_processing.text_normalization.de.verbalizers.verbalize_final import VerbalizeFinalFst
-        elif lang == 'es':
-            from fun_text_processing.text_normalization.es.taggers.tokenize_and_classify import ClassifyFst
-            from fun_text_processing.text_normalization.es.verbalizers.verbalize_final import VerbalizeFinalFst
-        elif lang == 'zh':
-            from fun_text_processing.text_normalization.zh.taggers.tokenize_and_classify import ClassifyFst
-            from fun_text_processing.text_normalization.zh.verbalizers.verbalize_final import VerbalizeFinalFst
+            from fun_text_processing.text_normalization.ru.taggers.tokenize_and_classify import (
+                ClassifyFst,
+            )
+            from fun_text_processing.text_normalization.ru.verbalizers.verbalize_final import (
+                VerbalizeFinalFst,
+            )
+        elif lang == "de":
+            from fun_text_processing.text_normalization.de.taggers.tokenize_and_classify import (
+                ClassifyFst,
+            )
+            from fun_text_processing.text_normalization.de.verbalizers.verbalize_final import (
+                VerbalizeFinalFst,
+            )
+        elif lang == "es":
+            from fun_text_processing.text_normalization.es.taggers.tokenize_and_classify import (
+                ClassifyFst,
+            )
+            from fun_text_processing.text_normalization.es.verbalizers.verbalize_final import (
+                VerbalizeFinalFst,
+            )
+        elif lang == "zh":
+            from fun_text_processing.text_normalization.zh.taggers.tokenize_and_classify import (
+                ClassifyFst,
+            )
+            from fun_text_processing.text_normalization.zh.verbalizers.verbalize_final import (
+                VerbalizeFinalFst,
+            )
         self.tagger = ClassifyFst(
             input_case=input_case,
             deterministic=deterministic,
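This whole block is line-wrapping of the lazy per-language imports: only the grammars for the requested lang are imported, and ClassifyFst/VerbalizeFinalFst are then instantiated below. A minimal usage sketch, assuming the package and its pynini grammars are installed (argument values are illustrative only):

    from fun_text_processing.text_normalization.normalize import Normalizer

    # First use builds (or loads cached) .far grammar files for the language.
    normalizer = Normalizer(input_case="cased", lang="en", cache_dir=None)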
@@ -144,7 +168,9 @@
         try:
             normalized_texts = Parallel(n_jobs=n_jobs)(
-                delayed(self.process_batch)(texts[i : i + batch], verbose, punct_pre_process, punct_post_process)
+                delayed(self.process_batch)(
+                    texts[i : i + batch], verbose, punct_pre_process, punct_post_process
+                )
                 for i in range(0, len(texts), batch)
             )
         except BaseException as e:
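Only the wrapping changed here, but the batching idiom is worth spelling out: texts is sliced into chunks of batch, and each chunk goes to a joblib worker. The same pattern in isolation, with a toy process_batch (assumes joblib is installed):

    from joblib import Parallel, delayed

    def process_batch(chunk):
        return [s.strip() for s in chunk]

    texts = ["one ", " two", "three ", " four", "five "]
    batch = 2
    # One delayed call per slice; results come back as one list per chunk, in order.
    chunks = Parallel(n_jobs=2)(
        delayed(process_batch)(texts[i : i + batch]) for i in range(0, len(texts), batch)
    )
    flat = [t for chunk in chunks for t in chunk]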
@@ -164,7 +190,10 @@
        """
        normalized_lines = [
            self.normalize(
                text, verbose=verbose, punct_pre_process=punct_pre_process, punct_post_process=punct_post_process
                text,
                verbose=verbose,
                punct_pre_process=punct_pre_process,
                punct_post_process=punct_post_process,
            )
            for text in tqdm(batch)
        ]
@@ -238,7 +267,11 @@
         return splits
     def normalize(
-        self, text: str, verbose: bool = False, punct_pre_process: bool = False, punct_post_process: bool = False
+        self,
+        text: str,
+        verbose: bool = False,
+        punct_pre_process: bool = False,
+        punct_post_process: bool = False,
     ) -> str:
         """
         Main function. Normalizes tokens from written to spoken form
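With one parameter per line, keyword calls mirror the signature. Continuing the hypothetical normalizer from the sketch above:

    normalized = normalizer.normalize(
        "The alarm went off at 10:30.",
        verbose=False,
        punct_pre_process=False,
        punct_post_process=True,
    )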
@@ -286,10 +319,10 @@
                     break
             if verbalizer_lattice is None:
                 raise ValueError(f"No permutations were generated from tokens {s}")
-            output += ' ' + self.select_verbalizer(verbalizer_lattice)
-        output = SPACE_DUP.sub(' ', output[1:])
+            output += " " + self.select_verbalizer(verbalizer_lattice)
+        output = SPACE_DUP.sub(" ", output[1:])
 
-        if self.lang == "en" and hasattr(self, 'post_processor'):
+        if self.lang == "en" and hasattr(self, "post_processor"):
             output = self.post_process(output)
         if punct_post_process:
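The loop prepends a space before every verbalized sentence, so output[1:] strips the leading one and SPACE_DUP collapses any doubles. The idiom in miniature:

    import re

    SPACE_DUP = re.compile(" {2,}")
    output = ""
    for sentence in ["ten thirty", "a m"]:
        output += " " + sentence
    print(SPACE_DUP.sub(" ", output[1:]))  # -> "ten thirty a m"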
@@ -311,11 +344,11 @@
         Returns list of sentences
         """
-        lower_case_unicode = ''
-        upper_case_unicode = ''
+        lower_case_unicode = ""
+        upper_case_unicode = ""
         if self.lang == "ru":
-            lower_case_unicode = '\u0430-\u04FF'
-            upper_case_unicode = '\u0410-\u042F'
+            lower_case_unicode = "\u0430-\u04FF"
+            upper_case_unicode = "\u0410-\u042F"
         # Read and split transcript by utterance (roughly, sentences)
         split_pattern = f"(?<!\w\.\w.)(?<![A-Z{upper_case_unicode}][a-z{lower_case_unicode}]+\.)(?<![A-Z{upper_case_unicode}]\.)(?<=\.|\?|\!|\.”|\?”\!”)\s(?![0-9]+[a-z]*\.)"
@@ -341,10 +374,12 @@
            subl = [""]
            for k, v in perm:
                if isinstance(v, str):
                    subl = ["".join(x) for x in itertools.product(subl, [f"{k}: \"{v}\" "])]
                    subl = ["".join(x) for x in itertools.product(subl, [f'{k}: "{v}" '])]
                elif isinstance(v, OrderedDict):
                    rec = self._permute(v)
                    subl = ["".join(x) for x in itertools.product(subl, [f" {k} {{ "], rec, [f" }} "])]
                    subl = [
                        "".join(x) for x in itertools.product(subl, [f" {k} {{ "], rec, [f" }} "])
                    ]
                elif isinstance(v, bool):
                    subl = ["".join(x) for x in itertools.product(subl, [f"{k}: true "])]
                else:
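The reflowed comprehensions are all the same cross-product trick: every partial serialization in subl is extended with every rendering of the next field via itertools.product. In isolation (field names and values are hypothetical):

    import itertools

    subl = ["money { ", "date { "]
    options = ['integer: "two" ', 'integer: "2" ']
    subl = ["".join(x) for x in itertools.product(subl, options)]
    # ['money { integer: "two" ', 'money { integer: "2" ',
    #  'date { integer: "two" ', 'date { integer: "2" ']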
@@ -382,7 +417,7 @@
        return _helper("", tokens, 0)
    def find_tags(self, text: str) -> 'pynini.FstLike':
    def find_tags(self, text: str) -> "pynini.FstLike":
        """
        Given text use tagger Fst to tag text
@@ -394,7 +429,7 @@
         lattice = text @ self.tagger.fst
         return lattice
-    def select_tag(self, lattice: 'pynini.FstLike') -> str:
+    def select_tag(self, lattice: "pynini.FstLike") -> str:
         """
         Given tagged lattice return shortest path
@@ -406,7 +441,7 @@
         tagged_text = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
         return tagged_text
-    def find_verbalizer(self, tagged_text: str) -> 'pynini.FstLike':
+    def find_verbalizer(self, tagged_text: str) -> "pynini.FstLike":
         """
         Given tagged text creates verbalization lattice
         This is context-independent.
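The "pynini.FstLike" annotations are string literals, so pynini stays a runtime-only dependency and is not needed just to import the module. Behind find_tags/select_tag (and their verbalizer twins) is plain FST composition plus a shortest-path search; a minimal stand-alone sketch, assuming pynini is installed:

    import pynini

    # Toy 'grammar': rewrites digits to words.
    fst = pynini.string_map([("2", "two"), ("3", "three")])
    lattice = pynini.accep("2") @ fst  # compose input text with the grammar
    print(pynini.shortestpath(lattice, nshortest=1, unique=True).string())  # "two"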
@@ -419,7 +454,7 @@
         lattice = tagged_text @ self.verbalizer.fst
         return lattice
-    def select_verbalizer(self, lattice: 'pynini.FstLike') -> str:
+    def select_verbalizer(self, lattice: "pynini.FstLike") -> str:
         """
         Given verbalized lattice return shortest path
@@ -433,7 +468,7 @@
         # output = pynini.shortestpath(lattice, nshortest=1, unique=True).string()
         return output
-    def post_process(self, normalized_text: 'pynini.FstLike') -> str:
+    def post_process(self, normalized_text: "pynini.FstLike") -> str:
         """
         Runs post processing graph on normalized text
@@ -457,22 +492,34 @@
     input = parser.add_mutually_exclusive_group()
     input.add_argument("--text", dest="input_string", help="input string", type=str)
     input.add_argument("--input_file", dest="input_file", help="input file path", type=str)
-    parser.add_argument('--output_file', dest="output_file", help="output file path", type=str)
-    parser.add_argument("--language", help="language", choices=["en", "de", "es", "zh"], default="en", type=str)
+    parser.add_argument("--output_file", dest="output_file", help="output file path", type=str)
     parser.add_argument(
-        "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
+        "--language", help="language", choices=["en", "de", "es", "zh"], default="en", type=str
     )
-    parser.add_argument("--verbose", help="print info for debugging", action='store_true')
+    parser.add_argument(
+        "--input_case",
+        help="input capitalization",
+        choices=["lower_cased", "cased"],
+        default="cased",
+        type=str,
+    )
+    parser.add_argument("--verbose", help="print info for debugging", action="store_true")
     parser.add_argument(
         "--punct_post_process",
         help="set to True to enable punctuation post processing to match input.",
         action="store_true",
     )
     parser.add_argument(
-        "--punct_pre_process", help="set to True to enable punctuation pre processing", action="store_true"
+        "--punct_pre_process",
+        help="set to True to enable punctuation pre processing",
+        action="store_true",
     )
-    parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
-    parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
+    parser.add_argument(
+        "--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true"
+    )
+    parser.add_argument(
+        "--whitelist", help="path to a file with with whitelist", default=None, type=str
+    )
     parser.add_argument(
         "--cache_dir",
         help="path to a dir with .far grammar file. Set to None to avoid using cache",