# Copyright NeMo (https://github.com/NVIDIA/NeMo). All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import os import time from argparse import ArgumentParser from glob import glob from typing import List, Tuple import pynini from joblib import Parallel, delayed from fun_text_processing.text_normalization.data_loader_utils import post_process_punct, pre_process from fun_text_processing.text_normalization.normalize import Normalizer from pynini.lib import rewrite from tqdm import tqdm try: from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.models import ASRModel ASR_AVAILABLE = True except (ModuleNotFoundError, ImportError): ASR_AVAILABLE = False """ The script provides multiple normalization options and chooses the best one that minimizes CER of the ASR output (most of the semiotic classes use deterministic=False flag). To run this script with a .json manifest file, the manifest file should contain the following fields: "audio_data" - path to the audio file "text" - raw text "pred_text" - ASR model prediction See https://github.com/NVIDIA/NeMo/blob/main/examples/asr/transcribe_speech.py on how to add ASR predictions When the manifest is ready, run: python normalize_with_audio.py \ --audio_data PATH/TO/MANIFEST.JSON \ --language en To run with a single audio file, specify path to audio and text with: python normalize_with_audio.py \ --audio_data PATH/TO/AUDIO.WAV \ --language en \ --text raw text OR PATH/TO/.TXT/FILE --model QuartzNet15x5Base-En \ --verbose To see possible normalization options for a text input without an audio file (could be used for debugging), run: python python normalize_with_audio.py --text "RAW TEXT" Specify `--cache_dir` to generate .far grammars once and re-used them for faster inference """ class NormalizerWithAudio(Normalizer): """ Normalizer class that converts text from written to spoken form. Useful for TTS preprocessing. Args: input_case: expected input capitalization lang: language cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache. overwrite_cache: set to True to overwrite .far files whitelist: path to a file with whitelist replacements post_process: WFST-based post processing, e.g. to remove extra spaces added during TN. Note: punct_post_process flag in normalize() supports all languages. """ def __init__( self, input_case: str, lang: str = 'en', cache_dir: str = None, overwrite_cache: bool = False, whitelist: str = None, lm: bool = False, post_process: bool = True, ): super().__init__( input_case=input_case, lang=lang, deterministic=False, cache_dir=cache_dir, overwrite_cache=overwrite_cache, whitelist=whitelist, lm=lm, post_process=post_process, ) self.lm = lm def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False,) -> str: """ Main function. Normalizes tokens from written to spoken form e.g. 12 kg -> twelve kilograms Args: text: string that may include semiotic classes n_tagged: number of tagged options to consider, -1 - to get all possible tagged options punct_post_process: whether to normalize punctuation verbose: whether to print intermediate meta information Returns: normalized text options (usually there are multiple ways of normalizing a given semiotic class) """ if len(text.split()) > 500: raise ValueError( "Your input is too long. Please split up the input into sentences, " "or strings with fewer than 500 words" ) original_text = text text = pre_process(text) # to handle [] text = text.strip() if not text: if verbose: print(text) return text text = pynini.escape(text) print(text) if self.lm: if self.lang not in ["en"]: raise ValueError(f"{self.lang} is not supported in LM mode") if self.lang == "en": # this to keep arpabet phonemes in the list of options if "[" in text and "]" in text: lattice = rewrite.rewrite_lattice(text, self.tagger.fst) else: try: lattice = rewrite.rewrite_lattice(text, self.tagger.fst_no_digits) except pynini.lib.rewrite.Error: lattice = rewrite.rewrite_lattice(text, self.tagger.fst) lattice = rewrite.lattice_to_nshortest(lattice, n_tagged) tagged_texts = [(x[1], float(x[2])) for x in lattice.paths().items()] tagged_texts.sort(key=lambda x: x[1]) tagged_texts, weights = list(zip(*tagged_texts)) else: tagged_texts = self._get_tagged_text(text, n_tagged) # non-deterministic Eng normalization uses tagger composed with verbalizer, no permutation in between if self.lang == "en": normalized_texts = tagged_texts normalized_texts = [self.post_process(text) for text in normalized_texts] else: normalized_texts = [] for tagged_text in tagged_texts: self._verbalize(tagged_text, normalized_texts, verbose=verbose) if len(normalized_texts) == 0: raise ValueError() if punct_post_process: # do post-processing based on Moses detokenizer if self.processor: normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts] normalized_texts = [ post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts ] if self.lm: remove_dup = sorted(list(set(zip(normalized_texts, weights))), key=lambda x: x[1]) normalized_texts, weights = zip(*remove_dup) return list(normalized_texts), weights normalized_texts = set(normalized_texts) return normalized_texts def _get_tagged_text(self, text, n_tagged): """ Returns text after tokenize and classify Args; text: input text n_tagged: number of tagged options to consider, -1 - return all possible tagged options """ if n_tagged == -1: if self.lang == "en": # this to keep arpabet phonemes in the list of options if "[" in text and "]" in text: tagged_texts = rewrite.rewrites(text, self.tagger.fst) else: try: tagged_texts = rewrite.rewrites(text, self.tagger.fst_no_digits) except pynini.lib.rewrite.Error: tagged_texts = rewrite.rewrites(text, self.tagger.fst) else: tagged_texts = rewrite.rewrites(text, self.tagger.fst) else: if self.lang == "en": # this to keep arpabet phonemes in the list of options if "[" in text and "]" in text: tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged) else: try: # try self.tagger graph that produces output without digits tagged_texts = rewrite.top_rewrites(text, self.tagger.fst_no_digits, nshortest=n_tagged) except pynini.lib.rewrite.Error: tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged) else: tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged) return tagged_texts def _verbalize(self, tagged_text: str, normalized_texts: List[str], verbose: bool = False): """ Verbalizes tagged text Args: tagged_text: text with tags normalized_texts: list of possible normalization options verbose: if true prints intermediate classification results """ def get_verbalized_text(tagged_text): return rewrite.rewrites(tagged_text, self.verbalizer.fst) self.parser(tagged_text) tokens = self.parser.parse() tags_reordered = self.generate_permutations(tokens) for tagged_text_reordered in tags_reordered: try: tagged_text_reordered = pynini.escape(tagged_text_reordered) normalized_texts.extend(get_verbalized_text(tagged_text_reordered)) if verbose: print(tagged_text_reordered) except pynini.lib.rewrite.Error: continue def select_best_match( self, normalized_texts: List[str], input_text: str, pred_text: str, verbose: bool = False, remove_punct: bool = False, cer_threshold: int = 100, ): """ Selects the best normalization option based on the lowest CER Args: normalized_texts: normalized text options input_text: input text pred_text: ASR model transcript of the audio file corresponding to the normalized text verbose: whether to print intermediate meta information remove_punct: whether to remove punctuation before calculating CER cer_threshold: if CER for pred_text is above the cer_threshold, no normalization will be performed Returns: normalized text with the lowest CER and CER value """ if pred_text == "": return input_text, cer_threshold normalized_texts_cer = calculate_cer(normalized_texts, pred_text, remove_punct) normalized_texts_cer = sorted(normalized_texts_cer, key=lambda x: x[1]) normalized_text, cer = normalized_texts_cer[0] if cer > cer_threshold: return input_text, cer if verbose: print('-' * 30) for option in normalized_texts: print(option) print('-' * 30) return normalized_text, cer def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]: """ Calculates character error rate (CER) Args: normalized_texts: normalized text options pred_text: ASR model output Returns: normalized options with corresponding CER """ normalized_options = [] for text in normalized_texts: text_clean = text.replace('-', ' ').lower() if remove_punct: for punct in "!?:;,.-()*+-/<=>@^_": text_clean = text_clean.replace(punct, "") cer = round(word_error_rate([pred_text], [text_clean], use_cer=True) * 100, 2) normalized_options.append((text, cer)) return normalized_options def get_asr_model(asr_model): """ Returns ASR Model Args: asr_model: NeMo ASR model """ if os.path.exists(args.model): asr_model = ASRModel.restore_from(asr_model) elif args.model in ASRModel.get_available_model_names(): asr_model = ASRModel.from_pretrained(asr_model) else: raise ValueError( f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}' ) return asr_model def parse_args(): parser = ArgumentParser() parser.add_argument("--text", help="input string or path to a .txt file", default=None, type=str) parser.add_argument( "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str ) parser.add_argument( "--language", help="Select target language", choices=["en", "ru", "de", "es"], default="en", type=str ) parser.add_argument("--audio_data", default=None, help="path to an audio file or .json manifest") parser.add_argument( '--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint' ) parser.add_argument( "--n_tagged", type=int, default=30, help="number of tagged options to consider, -1 - return all possible tagged options", ) parser.add_argument("--verbose", help="print info for debugging", action="store_true") parser.add_argument( "--no_remove_punct_for_cer", help="Set to True to NOT remove punctuation before calculating CER", action="store_true", ) parser.add_argument( "--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true" ) parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true") parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str) parser.add_argument( "--cache_dir", help="path to a dir with .far grammar file. Set to None to avoid using cache", default=None, type=str, ) parser.add_argument("--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs") parser.add_argument( "--lm", action="store_true", help="Set to True for WFST+LM. Only available for English right now." ) parser.add_argument( "--cer_threshold", default=100, type=int, help="if CER for pred_text is above the cer_threshold, no normalization will be performed", ) parser.add_argument("--batch_size", default=200, type=int, help="Number of examples for each process") return parser.parse_args() def _normalize_line( normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold ): line = json.loads(line) pred_text = line["pred_text"] normalized_texts = normalizer.normalize( text=line["text"], verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process, ) normalized_texts = set(normalized_texts) normalized_text, cer = normalizer.select_best_match( normalized_texts=normalized_texts, input_text=line["text"], pred_text=pred_text, verbose=verbose, remove_punct=remove_punct, cer_threshold=cer_threshold, ) line["nemo_normalized"] = normalized_text line["CER_nemo_normalized"] = cer return line def normalize_manifest( normalizer, audio_data: str, n_jobs: int, n_tagged: int, remove_punct: bool, punct_post_process: bool, batch_size: int, cer_threshold: int, ): """ Args: args.audio_data: path to .json manifest file. """ def __process_batch(batch_idx: int, batch: List[str], dir_name: str): """ Normalizes batch of text sequences Args: batch: list of texts batch_idx: batch index dir_name: path to output directory to save results """ normalized_lines = [ _normalize_line( normalizer, n_tagged, verbose=False, line=line, remove_punct=remove_punct, punct_post_process=punct_post_process, cer_threshold=cer_threshold, ) for line in tqdm(batch) ] with open(f"{dir_name}/{batch_idx:05}.json", "w") as f_out: for line in normalized_lines: f_out.write(json.dumps(line, ensure_ascii=False) + '\n') print(f"Batch -- {batch_idx} -- is complete") manifest_out = audio_data.replace('.json', '_normalized.json') with open(audio_data, 'r') as f: lines = f.readlines() print(f'Normalizing {len(lines)} lines of {audio_data}...') # to save intermediate results to a file batch = min(len(lines), batch_size) tmp_dir = manifest_out.replace(".json", "_parts") os.makedirs(tmp_dir, exist_ok=True) Parallel(n_jobs=n_jobs)( delayed(__process_batch)(idx, lines[i : i + batch], tmp_dir) for idx, i in enumerate(range(0, len(lines), batch)) ) # aggregate all intermediate files with open(manifest_out, "w") as f_out: for batch_f in sorted(glob(f"{tmp_dir}/*.json")): with open(batch_f, "r") as f_in: lines = f_in.read() f_out.write(lines) print(f'Normalized version saved at {manifest_out}') if __name__ == "__main__": args = parse_args() if not ASR_AVAILABLE and args.audio_data: raise ValueError("NeMo ASR collection is not installed.") start = time.time() args.whitelist = os.path.abspath(args.whitelist) if args.whitelist else None if args.text is not None: normalizer = NormalizerWithAudio( input_case=args.input_case, lang=args.language, cache_dir=args.cache_dir, overwrite_cache=args.overwrite_cache, whitelist=args.whitelist, lm=args.lm, ) if os.path.exists(args.text): with open(args.text, 'r') as f: args.text = f.read().strip() normalized_texts = normalizer.normalize( text=args.text, verbose=args.verbose, n_tagged=args.n_tagged, punct_post_process=not args.no_punct_post_process, ) if not normalizer.lm: normalized_texts = set(normalized_texts) if args.audio_data: asr_model = get_asr_model(args.model) pred_text = asr_model.transcribe([args.audio_data])[0] normalized_text, cer = normalizer.select_best_match( normalized_texts=normalized_texts, pred_text=pred_text, input_text=args.text, verbose=args.verbose, remove_punct=not args.no_remove_punct_for_cer, cer_threshold=args.cer_threshold, ) print(f"Transcript: {pred_text}") print(f"Normalized: {normalized_text}") else: print("Normalization options:") for norm_text in normalized_texts: print(norm_text) elif not os.path.exists(args.audio_data): raise ValueError(f"{args.audio_data} not found.") elif args.audio_data.endswith('.json'): normalizer = NormalizerWithAudio( input_case=args.input_case, lang=args.language, cache_dir=args.cache_dir, overwrite_cache=args.overwrite_cache, whitelist=args.whitelist, ) normalize_manifest( normalizer=normalizer, audio_data=args.audio_data, n_jobs=args.n_jobs, n_tagged=args.n_tagged, remove_punct=not args.no_remove_punct_for_cer, punct_post_process=not args.no_punct_post_process, batch_size=args.batch_size, cer_threshold=args.cer_threshold, ) else: raise ValueError( "Provide either path to .json manifest in '--audio_data' OR " + "'--audio_data' path to audio file and '--text' path to a text file OR" "'--text' string text (for debugging without audio)" ) print(f'Execution time: {round((time.time() - start)/60, 2)} min.')