import json
import os
import time
from argparse import ArgumentParser
from glob import glob  # no longer used below; kept so existing `from ... import glob` users do not break
from typing import List, Set, Tuple, Union

import pynini
from joblib import Parallel, delayed
from fun_text_processing.text_normalization.data_loader_utils import post_process_punct, pre_process
from fun_text_processing.text_normalization.normalize import Normalizer
from pynini.lib import rewrite
from tqdm import tqdm

try:
    from nemo.collections.asr.metrics.wer import word_error_rate
    from nemo.collections.asr.models import ASRModel

    ASR_AVAILABLE = True
except (ModuleNotFoundError, ImportError):
    ASR_AVAILABLE = False

"""
The script provides multiple normalization options and chooses the best one that minimizes CER of the ASR output
(most of the semiotic classes use deterministic=False flag).

To run this script with a .json manifest file, the manifest file should contain the following fields:
    "audio_data" - path to the audio file
    "text" - raw text
    "pred_text" - ASR model prediction

    See https://github.com/NVIDIA/NeMo/blob/main/examples/asr/transcribe_speech.py on how to add ASR predictions

    When the manifest is ready, run:
        python normalize_with_audio.py \
            --audio_data PATH/TO/MANIFEST.JSON \
            --language en

To run with a single audio file, specify path to audio and text with:
    python normalize_with_audio.py \
        --audio_data PATH/TO/AUDIO.WAV \
        --language en \
        --text raw text OR PATH/TO/.TXT/FILE
        --model QuartzNet15x5Base-En \
        --verbose

To see possible normalization options for a text input without an audio file (could be used for debugging), run:
    python normalize_with_audio.py --text "RAW TEXT"

Specify `--cache_dir` to generate .far grammars once and reuse them for faster inference
"""

# Characters stripped from a normalization option before CER is computed (when punctuation removal is requested).
_CER_PUNCT_TABLE = str.maketrans("", "", "!?:;,.-()*+-/<=>@^_")


class NormalizerWithAudio(Normalizer):
    """
    Normalizer class that converts text from written to spoken form.
    Useful for TTS preprocessing. The parent normalizer is always built with
    ``deterministic=False`` so that several normalization options are produced.

    Args:
        input_case: expected input capitalization
        lang: language
        cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
        overwrite_cache: set to True to overwrite .far files
        whitelist: path to a file with whitelist replacements
        lm: set to True to return weighted options (WFST+LM mode, English only)
        post_process: WFST-based post processing, e.g. to remove extra spaces added during TN.
            Note: punct_post_process flag in normalize() supports all languages.
    """

    def __init__(
        self,
        input_case: str,
        lang: str = "en",
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
        lm: bool = False,
        post_process: bool = True,
    ):
        super().__init__(
            input_case=input_case,
            lang=lang,
            deterministic=False,
            cache_dir=cache_dir,
            overwrite_cache=overwrite_cache,
            whitelist=whitelist,
            lm=lm,
            post_process=post_process,
        )
        self.lm = lm

    def _tag_with_fallback(self, text: str, rewrite_fn):
        """
        Applies ``rewrite_fn(text, fst)`` with the tagger graph appropriate for the input.

        For English input without square brackets the digit-free tagger graph is tried first and the
        full tagger graph is used only if that rewrite fails. Bracketed input (arpabet phonemes) and
        all other languages go straight to the full tagger graph.

        Args:
            text: escaped input text
            rewrite_fn: callable taking (text, fst) and returning the rewrite result

        Returns: whatever ``rewrite_fn`` returns
        """
        # brackets mark arpabet phonemes, which must be kept in the list of options
        has_phonemes = "[" in text and "]" in text
        if self.lang == "en" and not has_phonemes:
            try:
                # try the tagger graph that produces output without digits
                return rewrite_fn(text, self.tagger.fst_no_digits)
            except rewrite.Error:
                pass
        return rewrite_fn(text, self.tagger.fst)

    def normalize(
        self,
        text: str,
        n_tagged: int,
        punct_post_process: bool = True,
        verbose: bool = False,
    ) -> Union[str, Set[str], Tuple[List[str], Tuple[float, ...]]]:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms

        Args:
            text: string that may include semiotic classes
            n_tagged: number of tagged options to consider, -1 - to get all possible tagged options
            punct_post_process: whether to normalize punctuation
            verbose: whether to print intermediate meta information

        Returns:
            normalized text options (usually there are multiple ways of normalizing a given semiotic class):
            a set of strings in the default mode, or a ``(options, weights)`` pair sorted by weight in LM mode.
            Input that is empty after pre-processing is returned as an empty string.

        Raises:
            ValueError: if the input has more than 500 words, if LM mode is requested for an unsupported
                language, or if no normalization option could be produced.
        """
        if len(text.split()) > 500:
            raise ValueError(
                "Your input is too long. Please split up the input into sentences, "
                "or strings with fewer than 500 words"
            )

        original_text = text
        text = pre_process(text)  # to handle []
        text = text.strip()
        if not text:
            if verbose:
                print(text)
            return text

        text = pynini.escape(text)
        if verbose:
            # debugging aid only; must not pollute stdout in normal runs
            print(text)

        weights = None
        if self.lm:
            if self.lang not in ["en"]:
                raise ValueError(f"{self.lang} is not supported in LM mode")

            lattice = self._tag_with_fallback(text, rewrite.rewrite_lattice)
            lattice = rewrite.lattice_to_nshortest(lattice, n_tagged)
            scored = [(x[1], float(x[2])) for x in lattice.paths().items()]
            if not scored:
                raise ValueError(f"No tagged options were found for the input: {original_text}")
            scored.sort(key=lambda x: x[1])
            tagged_texts, weights = list(zip(*scored))
        else:
            tagged_texts = self._get_tagged_text(text, n_tagged)

        # non-deterministic Eng normalization uses tagger composed with verbalizer, no permutation in between
        if self.lang == "en":
            normalized_texts = [self.post_process(option) for option in tagged_texts]
        else:
            normalized_texts = []
            for tagged_text in tagged_texts:
                self._verbalize(tagged_text, normalized_texts, verbose=verbose)

        if not normalized_texts:
            raise ValueError(f"No normalization options were produced for the input: {original_text}")

        if punct_post_process:
            # do post-processing based on Moses detokenizer
            if self.processor:
                normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts]
                normalized_texts = [
                    post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts
                ]

        if self.lm:
            # drop duplicated (option, weight) pairs and keep the lowest weight first
            remove_dup = sorted(set(zip(normalized_texts, weights)), key=lambda x: x[1])
            normalized_texts, weights = zip(*remove_dup)
            return list(normalized_texts), weights

        return set(normalized_texts)

    def _get_tagged_text(self, text, n_tagged):
        """
        Returns text after tokenize and classify

        Args:
            text: input text
            n_tagged: number of tagged options to consider, -1 - return all possible tagged options

        Returns: tagged options produced by the tagger graph
        """
        if n_tagged == -1:
            return self._tag_with_fallback(text, rewrite.rewrites)
        return self._tag_with_fallback(
            text, lambda txt, fst: rewrite.top_rewrites(txt, fst, nshortest=n_tagged)
        )

    def _verbalize(self, tagged_text: str, normalized_texts: List[str], verbose: bool = False):
        """
        Verbalizes tagged text and appends every verbalization to ``normalized_texts`` in place.

        Args:
            tagged_text: text with tags
            normalized_texts: list of possible normalization options (extended in place)
            verbose: if true prints intermediate classification results
        """
        self.parser(tagged_text)
        tokens = self.parser.parse()
        for permutation in self.generate_permutations(tokens):
            try:
                permutation = pynini.escape(permutation)
                normalized_texts.extend(rewrite.rewrites(permutation, self.verbalizer.fst))
                if verbose:
                    print(permutation)
            except rewrite.Error:
                # this token order cannot be verbalized; try the next permutation
                continue

    def select_best_match(
        self,
        normalized_texts: List[str],
        input_text: str,
        pred_text: str,
        verbose: bool = False,
        remove_punct: bool = False,
        cer_threshold: int = 100,
    ):
        """
        Selects the best normalization option based on the lowest CER

        Args:
            normalized_texts: normalized text options
            input_text: input text
            pred_text: ASR model transcript of the audio file corresponding to the normalized text
            verbose: whether to print intermediate meta information
            remove_punct: whether to remove punctuation before calculating CER
            cer_threshold: if CER for pred_text is above the cer_threshold, no normalization will be performed

        Returns:
            normalized text with the lowest CER and CER value. The unchanged ``input_text`` is returned when
            the transcript is empty, when there are no options to choose from (both with ``cer_threshold``
            as the CER), or when the best CER exceeds ``cer_threshold``.
        """
        if pred_text == "":
            return input_text, cer_threshold

        normalized_texts_cer = calculate_cer(normalized_texts, pred_text, remove_punct)
        if not normalized_texts_cer:
            # nothing to choose from (e.g. the input was empty after pre-processing)
            return input_text, cer_threshold

        # min() returns the first option among ties, matching a stable sort followed by [0]
        normalized_text, cer = min(normalized_texts_cer, key=lambda x: x[1])

        if cer > cer_threshold:
            return input_text, cer

        if verbose:
            print("-" * 30)
            for option in normalized_texts:
                print(option)
            print("-" * 30)
        return normalized_text, cer


def calculate_cer(
    normalized_texts: List[str], pred_text: str, remove_punct=False
) -> List[Tuple[str, float]]:
    """
    Calculates character error rate (CER)

    Args:
        normalized_texts: normalized text options
        pred_text: ASR model output
        remove_punct: whether to strip punctuation from the options before scoring

    Returns: normalized options with corresponding CER (in percent, rounded to 2 decimals)
    """
    normalized_options = []
    for text in normalized_texts:
        text_clean = text.replace("-", " ").lower()
        if remove_punct:
            text_clean = text_clean.translate(_CER_PUNCT_TABLE)
        cer = round(word_error_rate([pred_text], [text_clean], use_cer=True) * 100, 2)
        normalized_options.append((text, cer))
    return normalized_options


def get_asr_model(asr_model):
    """
    Returns ASR Model

    Args:
        asr_model: path to a NeMo ASR checkpoint or name of a pretrained NeMo ASR model

    Returns: the loaded ASR model

    Raises:
        ValueError: if ``asr_model`` is neither an existing path nor an available pretrained model name
    """
    # Use the argument rather than the script-level `args`, so the function also works when imported.
    if os.path.exists(asr_model):
        asr_model = ASRModel.restore_from(asr_model)
    elif asr_model in ASRModel.get_available_model_names():
        asr_model = ASRModel.from_pretrained(asr_model)
    else:
        raise ValueError(
            f"Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}"
        )
    return asr_model


def parse_args():
    """Builds the command line parser and returns the parsed arguments."""
    parser = ArgumentParser()
    parser.add_argument(
        "--text", help="input string or path to a .txt file", default=None, type=str
    )
    parser.add_argument(
        "--input_case",
        help="input capitalization",
        choices=["lower_cased", "cased"],
        default="cased",
        type=str,
    )
    parser.add_argument(
        "--language",
        help="Select target language",
        choices=["en", "ru", "de", "es"],
        default="en",
        type=str,
    )
    parser.add_argument(
        "--audio_data", default=None, help="path to an audio file or .json manifest"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="QuartzNet15x5Base-En",
        help="Pre-trained model name or path to model checkpoint",
    )
    parser.add_argument(
        "--n_tagged",
        type=int,
        default=30,
        help="number of tagged options to consider, -1 - return all possible tagged options",
    )
    parser.add_argument("--verbose", help="print info for debugging", action="store_true")
    parser.add_argument(
        "--no_remove_punct_for_cer",
        help="Set to True to NOT remove punctuation before calculating CER",
        action="store_true",
    )
    parser.add_argument(
        "--no_punct_post_process",
        help="set to True to disable punctuation post processing",
        action="store_true",
    )
    parser.add_argument(
        "--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true"
    )
    parser.add_argument(
        "--whitelist", help="path to a file with with whitelist", default=None, type=str
    )
    parser.add_argument(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )
    parser.add_argument(
        "--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs"
    )
    parser.add_argument(
        "--lm",
        action="store_true",
        help="Set to True for WFST+LM. Only available for English right now.",
    )
    parser.add_argument(
        "--cer_threshold",
        default=100,
        type=int,
        help="if CER for pred_text is above the cer_threshold, no normalization will be performed",
    )
    parser.add_argument(
        "--batch_size", default=200, type=int, help="Number of examples for each process"
    )
    return parser.parse_args()


def _normalize_line(
    normalizer: NormalizerWithAudio,
    n_tagged,
    verbose,
    line: str,
    remove_punct,
    punct_post_process,
    cer_threshold,
):
    """
    Normalizes one manifest line and picks the option closest to the ASR transcript.

    Args:
        normalizer: normalizer used to generate the options
        n_tagged: number of tagged options to consider, -1 - all possible tagged options
        verbose: whether to print intermediate meta information
        line: JSON string with at least the "text" and "pred_text" fields
        remove_punct: whether to remove punctuation before calculating CER
        punct_post_process: whether to normalize punctuation
        cer_threshold: if CER is above the threshold, the raw text is kept

    Returns: the parsed line (dict) with "nemo_normalized" and "CER_nemo_normalized" added
    """
    line = json.loads(line)
    pred_text = line["pred_text"]

    normalized_texts = normalizer.normalize(
        text=line["text"],
        verbose=verbose,
        n_tagged=n_tagged,
        punct_post_process=punct_post_process,
    )
    normalized_texts = set(normalized_texts)
    normalized_text, cer = normalizer.select_best_match(
        normalized_texts=normalized_texts,
        input_text=line["text"],
        pred_text=pred_text,
        verbose=verbose,
        remove_punct=remove_punct,
        cer_threshold=cer_threshold,
    )
    line["nemo_normalized"] = normalized_text
    line["CER_nemo_normalized"] = cer
    return line


def normalize_manifest(
    normalizer,
    audio_data: str,
    n_jobs: int,
    n_tagged: int,
    remove_punct: bool,
    punct_post_process: bool,
    batch_size: int,
    cer_threshold: int,
):
    """
    Normalizes every line of a .json manifest in parallel batches and writes the result next to the
    input as "<name>_normalized.json". Intermediate per-batch files are kept in "<name>_normalized_parts".

    Args:
        normalizer: normalizer used to generate the options
        audio_data: path to .json manifest file.
        n_jobs: the maximum number of concurrently running jobs
        n_tagged: number of tagged options to consider, -1 - all possible tagged options
        remove_punct: whether to remove punctuation before calculating CER
        punct_post_process: whether to normalize punctuation
        batch_size: number of examples for each process
        cer_threshold: if CER is above the threshold, the raw text is kept
    """

    def _process_batch(batch_idx: int, batch: List[str], dir_name: str):
        """
        Normalizes batch of text sequences

        Args:
            batch: list of texts
            batch_idx: batch index
            dir_name: path to output directory to save results
        """
        normalized_lines = [
            _normalize_line(
                normalizer,
                n_tagged,
                verbose=False,
                line=line,
                remove_punct=remove_punct,
                punct_post_process=punct_post_process,
                cer_threshold=cer_threshold,
            )
            for line in tqdm(batch)
        ]

        # ensure_ascii=False emits raw non-ASCII characters, so the encoding must be explicit
        with open(f"{dir_name}/{batch_idx:05}.json", "w", encoding="utf-8") as f_out:
            for line in normalized_lines:
                f_out.write(json.dumps(line, ensure_ascii=False) + "\n")

        print(f"Batch -- {batch_idx} -- is complete")

    manifest_out = audio_data.replace(".json", "_normalized.json")
    with open(audio_data, "r", encoding="utf-8") as f:
        # blank lines (e.g. a trailing newline) are not valid JSON records
        lines = [line for line in f if line.strip()]

    print(f"Normalizing {len(lines)} lines of {audio_data}...")

    # to save intermediate results to a file
    # range() rejects a zero step, so the step stays >= 1 even for an empty manifest
    batch = max(1, min(len(lines), batch_size))
    batch_starts = range(0, len(lines), batch)

    tmp_dir = manifest_out.replace(".json", "_parts")
    os.makedirs(tmp_dir, exist_ok=True)

    if lines:
        Parallel(n_jobs=n_jobs)(
            delayed(_process_batch)(idx, lines[i : i + batch], tmp_dir)
            for idx, i in enumerate(batch_starts)
        )

    # aggregate the intermediate files written by this run only; stale parts left in the
    # directory by an earlier, larger run must not leak into the output
    with open(manifest_out, "w", encoding="utf-8") as f_out:
        for idx in range(len(batch_starts)):
            with open(f"{tmp_dir}/{idx:05}.json", "r", encoding="utf-8") as f_in:
                f_out.write(f_in.read())

    print(f"Normalized version saved at {manifest_out}")


if __name__ == "__main__":
    args = parse_args()

    if not ASR_AVAILABLE and args.audio_data:
        raise ValueError("NeMo ASR collection is not installed.")
    start = time.time()
    args.whitelist = os.path.abspath(args.whitelist) if args.whitelist else None
    if args.text is not None:
        normalizer = NormalizerWithAudio(
            input_case=args.input_case,
            lang=args.language,
            cache_dir=args.cache_dir,
            overwrite_cache=args.overwrite_cache,
            whitelist=args.whitelist,
            lm=args.lm,
        )

        if os.path.exists(args.text):
            with open(args.text, "r") as f:
                args.text = f.read().strip()
        normalized_texts = normalizer.normalize(
            text=args.text,
            verbose=args.verbose,
            n_tagged=args.n_tagged,
            punct_post_process=not args.no_punct_post_process,
        )

        if normalizer.lm:
            # LM mode returns (options, weights); empty input returns an empty string.
            # Only the options can be scored against the transcript.
            options = list(normalized_texts[0]) if normalized_texts else []
        else:
            normalized_texts = set(normalized_texts)
            options = normalized_texts

        if args.audio_data:
            asr_model = get_asr_model(args.model)
            pred_text = asr_model.transcribe([args.audio_data])[0]
            normalized_text, cer = normalizer.select_best_match(
                normalized_texts=options,
                pred_text=pred_text,
                input_text=args.text,
                verbose=args.verbose,
                remove_punct=not args.no_remove_punct_for_cer,
                cer_threshold=args.cer_threshold,
            )
            print(f"Transcript: {pred_text}")
            print(f"Normalized: {normalized_text}")
        else:
            print("Normalization options:")
            for norm_text in normalized_texts:
                print(norm_text)
    elif args.audio_data is None:
        # neither --text nor --audio_data was given; os.path.exists(None) would raise a TypeError
        raise ValueError(
            "Provide either path to .json manifest in '--audio_data' OR "
            "'--audio_data' path to audio file and '--text' path to a text file OR "
            "'--text' string text (for debugging without audio)"
        )
    elif not os.path.exists(args.audio_data):
        raise ValueError(f"{args.audio_data} not found.")
    elif args.audio_data.endswith(".json"):
        normalizer = NormalizerWithAudio(
            input_case=args.input_case,
            lang=args.language,
            cache_dir=args.cache_dir,
            overwrite_cache=args.overwrite_cache,
            whitelist=args.whitelist,
        )
        normalize_manifest(
            normalizer=normalizer,
            audio_data=args.audio_data,
            n_jobs=args.n_jobs,
            n_tagged=args.n_tagged,
            remove_punct=not args.no_remove_punct_for_cer,
            punct_post_process=not args.no_punct_post_process,
            batch_size=args.batch_size,
            cer_threshold=args.cer_threshold,
        )
    else:
        raise ValueError(
            "Provide either path to .json manifest in '--audio_data' OR "
            "'--audio_data' path to audio file and '--text' path to a text file OR "
            "'--text' string text (for debugging without audio)"
        )
    print(f"Execution time: {round((time.time() - start)/60, 2)} min.")