#!/usr/bin/python
|
# -*- coding: utf-8 -*-
|
|
from argparse import ArgumentParser
|
from time import perf_counter
|
from typing import List
|
|
from fun_text_processing.text_normalization.data_loader_utils import load_file, write_file
|
from fun_text_processing.text_normalization.normalize import Normalizer
|
from fun_text_processing.text_normalization.token_parser import TokenParser
|
|
|
class InverseNormalizer(Normalizer):
    """
    Inverse normalizer that converts text from spoken to written form. Useful for ASR postprocessing.
    Input is expected to have no punctuation outside of apostrophe (') and dash (-) and be lower cased.

    Args:
        lang: language specifying the ITN
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        enable_standalone_number: stored on the instance as ``convert_number``; it is not
            consulted anywhere in this class (presumably read by the inherited
            normalization code -- confirm against ``Normalizer``).
        enable_0_to_9: stored on the instance as ``enable_0_to_9``; likewise not consulted
            in this class.

    Raises:
        ValueError: if ``lang`` is not one of the supported language codes.
    """

    # Language codes for which an ``inverse_text_normalization.<lang>`` grammar
    # package is imported by this class.
    _ITN_LANGUAGES = ("en", "id", "ja", "es", "pt", "ru", "de", "fr", "vi", "ko", "zh", "tl")

    def __init__(
        self,
        lang: str = "en",
        cache_dir: str = None,
        overwrite_cache: bool = False,
        enable_standalone_number: bool = True,
        enable_0_to_9: bool = True,
    ):
        # Fail early with a readable message instead of an UnboundLocalError
        # further down when no grammar was imported for ``lang``.
        if lang not in self._ITN_LANGUAGES:
            raise ValueError(
                f"Unsupported language {lang!r} for inverse text normalization; "
                f"expected one of {list(self._ITN_LANGUAGES)}"
            )

        # The grammars are imported lazily so that only the requested
        # language's modules are loaded. Every language package exposes the
        # same two entry points under the same relative module paths.
        from importlib import import_module

        package = f"fun_text_processing.inverse_text_normalization.{lang}"
        ClassifyFst = import_module(f"{package}.taggers.tokenize_and_classify").ClassifyFst
        VerbalizeFinalFst = import_module(
            f"{package}.verbalizers.verbalize_final"
        ).VerbalizeFinalFst

        # NOTE(review): Normalizer.__init__ is not invoked here, exactly as in
        # the previous implementation -- confirm this is intentional.
        self.tagger = ClassifyFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache)
        self.verbalizer = VerbalizeFinalFst()
        self.parser = TokenParser()
        self.lang = lang
        self.convert_number = enable_standalone_number
        self.enable_0_to_9 = enable_0_to_9

    def inverse_normalize_list(self, texts: List[str], verbose=False) -> List[str]:
        """
        NeMo inverse text normalizer

        Delegates to the inherited ``normalize_list``.

        Args:
            texts: list of input strings
            verbose: whether to print intermediate meta information

        Returns converted list of input strings
        """
        return self.normalize_list(texts=texts, verbose=verbose)

    def inverse_normalize(self, text: str, verbose: bool = False) -> str:
        """
        Main function. Inverse normalizes tokens from spoken to written form
            e.g. twelve kilograms -> 12 kg

        Delegates to the inherited ``normalize``.

        Args:
            text: string that may include semiotic classes
            verbose: whether to print intermediate meta information

        Returns: written form
        """
        # The raw input is only echoed in verbose mode so that library callers
        # do not get unsolicited output on stdout.
        if verbose:
            print(text)
        return self.normalize(text=text, verbose=verbose)
|
|
|
def str2bool(s, default=False):
    """
    Convert a textual flag to a boolean.

    Args:
        s: value to interpret. "true" / "false" are matched case-insensitively
            and with surrounding whitespace ignored. A ``bool`` is returned
            unchanged, so the function is safe to call on a value that has
            already been converted.
        default: value returned when ``s`` is not recognized (any other
            string, ``None``, or any other non-string object).

    Returns:
        True or False for a recognized value, otherwise ``default``.
    """
    # Already a boolean: nothing to parse.
    if isinstance(s, bool):
        return s
    # None and other non-string objects have no textual form to inspect;
    # treat them as "unrecognized" rather than failing on ``.lower()``.
    if not isinstance(s, str):
        return default

    token = s.strip().lower()
    if token == "true":
        return True
    if token == "false":
        return False
    return default
|
|
|
def parse_args():
    """
    Define the command-line interface and parse the process arguments.

    ``--text`` and ``--input_file`` are registered in a mutually exclusive
    group (neither is required). ``--enable_standalone_number`` and
    ``--enable_0_to_9`` are kept as strings (default ``"True"``); converting
    them to booleans is left to the caller.

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args()``.
    """
    arg_parser = ArgumentParser()

    # Input source: either a literal string or a file, never both.
    source_group = arg_parser.add_mutually_exclusive_group()
    for flag, destination, description in (
        ("--text", "input_string", "input string"),
        ("--input_file", "input_file", "input file path"),
    ):
        source_group.add_argument(flag, dest=destination, help=description, type=str)

    # Remaining options, registered in the order they appear in --help.
    option_table = (
        ("--output_file", dict(dest="output_file", help="output file path", type=str)),
        (
            "--language",
            dict(
                help="language",
                choices=["en", "id", "ja", "de", "es", "pt", "ru", "fr", "vi", "ko", "zh", "tl"],
                default="en",
                type=str,
            ),
        ),
        ("--verbose", dict(help="print info for debugging", action="store_true")),
        (
            "--overwrite_cache",
            dict(help="set to True to re-create .far grammar files", action="store_true"),
        ),
        (
            "--cache_dir",
            dict(
                help="path to a dir with .far grammar file. Set to None to avoid using cache",
                default=None,
                type=str,
            ),
        ),
        (
            "--enable_standalone_number",
            dict(type=str, default="True", help="enable standalone number"),
        ),
        (
            "--enable_0_to_9",
            dict(type=str, default="True", help="enable convert number 0 to 9"),
        ),
    )
    for flag, options in option_table:
        arg_parser.add_argument(flag, **options)

    return arg_parser.parse_args()
|
|
|
if __name__ == "__main__":
    # Command-line entry point: build the grammars for the requested language,
    # then inverse-normalize either a single string or every line of a file.
    args = parse_args()

    start_time = perf_counter()
    normalizer_kwargs = {
        "lang": args.language,
        "cache_dir": args.cache_dir,
        "overwrite_cache": args.overwrite_cache,
    }
    # The two number-handling switches are only forwarded for Japanese; every
    # other language is constructed with the class defaults.
    if args.language == "ja":
        normalizer_kwargs["enable_standalone_number"] = str2bool(args.enable_standalone_number)
        normalizer_kwargs["enable_0_to_9"] = str2bool(args.enable_0_to_9)
    inverse_normalizer = InverseNormalizer(**normalizer_kwargs)
    print(f"Time to generate graph: {round(perf_counter() - start_time, 2)} sec")

    if args.input_string:
        # Single-string mode: print the written form.
        result = inverse_normalizer.inverse_normalize(args.input_string, verbose=args.verbose)
        print(result)
    elif args.input_file:
        # File mode: convert all loaded sentences in one call.
        print(f"Loading data: {args.input_file}")
        data = load_file(args.input_file)

        print(f"- Data: {len(data)} sentences")
        prediction = inverse_normalizer.inverse_normalize_list(data, verbose=args.verbose)
        if not args.output_file:
            # No destination given: dump the converted list to stdout.
            print(prediction)
        else:
            write_file(args.output_file, prediction)
            print(f"- Denormalized. Writing out to {args.output_file}")
|