| | |
| | | |
| | | import json |
| | | import os |
| | | import time |
| | |
def __init__(
    self,
    input_case: str,
    lang: str = "en",
    cache_dir: str = None,
    overwrite_cache: bool = False,
    whitelist: str = None,
    lm: bool = False,
):
    """Initialize the normalizer.

    Args:
        input_case: expected capitalization of the input ('lower_cased' or 'cased').
        lang: language code (default "en").
        cache_dir: path to a dir with .far grammar files; None disables caching.
        overwrite_cache: set to True to re-create .far grammar files.
        whitelist: path to a whitelist file, or None.
        lm: enable WFST+LM rescoring. NOTE(review): restored from the body
            (`self.lm = lm` requires it); default value is assumed — TODO
            confirm against the upstream script.
    """
    # NOTE(review): the source of this method was corrupted by an interleaved
    # diff: the `lang` parameter appeared twice, the `:` after the parameter
    # list was missing, and other body statements (e.g. a super().__init__()
    # call) were likely dropped from this view and are NOT reconstructed here.
    self.lm = lm
| | | |
# NOTE(review): the region below is corrupted — it is a flattened side-by-side
# diff. Each statement appears twice (pre- and post-reformatting), the
# docstring opened at the first `"""` is never closed in this view, the body
# of the `if self.lm:` branch is missing, and `return input_text, cer` /
# `return normalized_text, cer` are fragments of at least one OTHER function
# spliced into this one. The original control flow cannot be recovered from
# this view; do not attempt to run or restyle it. Comments only — code bytes
# unchanged.
| | | def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False,) -> str: |
| | | def normalize( |
| | | self, |
| | | text: str, |
| | | n_tagged: int, |
| | | punct_post_process: bool = True, |
| | | verbose: bool = False, |
| | | ) -> str: |
| | | """ |
| | | Main function. Normalizes tokens from written to spoken form |
| | | e.g. 12 kg -> twelve kilograms |
| | |
| | | if self.processor: |
| | | normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts] |
| | | normalized_texts = [ |
| | | post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts |
| | | post_process_punct(input=original_text, normalized_text=t) |
| | | for t in normalized_texts |
| | | ] |
| | | |
| | | if self.lm: |
| | |
# NOTE(review): body of the `if self.lm:` branch was dropped here.
| | | else: |
| | | try: |
| | | # try self.tagger graph that produces output without digits |
| | | tagged_texts = rewrite.top_rewrites(text, self.tagger.fst_no_digits, nshortest=n_tagged) |
| | | tagged_texts = rewrite.top_rewrites( |
| | | text, self.tagger.fst_no_digits, nshortest=n_tagged |
| | | ) |
| | | except pynini.lib.rewrite.Error: |
| | | tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged) |
| | | tagged_texts = rewrite.top_rewrites( |
| | | text, self.tagger.fst, nshortest=n_tagged |
| | | ) |
| | | else: |
| | | tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged) |
| | | return tagged_texts |
| | |
# NOTE(review): the two returns below belong to different functions whose
# definitions were dropped from this view.
| | | return input_text, cer |
| | | |
| | | if verbose: |
| | | print('-' * 30) |
| | | print("-" * 30) |
| | | for option in normalized_texts: |
| | | print(option) |
| | | print('-' * 30) |
| | | print("-" * 30) |
| | | return normalized_text, cer |
| | | |
| | | |
# NOTE(review): corrupted region — duplicated single-line/reformatted pairs,
# and the body of calculate_cer is cut off: after the punctuation-stripping
# loop the text jumps into the middle of a DIFFERENT function (an ASR-model
# loader ending in `return asr_model`) whose `def` line was dropped. The
# original code cannot be recovered from this view. Comments only — code
# bytes unchanged.
| | | def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]: |
| | | def calculate_cer( |
| | | normalized_texts: List[str], pred_text: str, remove_punct=False |
| | | ) -> List[Tuple[str, float]]: |
| | | """ |
| | | Calculates character error rate (CER) |
| | | |
| | |
| | | """ |
| | | normalized_options = [] |
| | | for text in normalized_texts: |
| | | text_clean = text.replace('-', ' ').lower() |
| | | text_clean = text.replace("-", " ").lower() |
| | | if remove_punct: |
| | | for punct in "!?:;,.-()*+-/<=>@^_": |
| | | text_clean = text_clean.replace(punct, "") |
| | |
# NOTE(review): lines below belong to a dropped `get_asr_model`-style helper.
| | | asr_model = ASRModel.from_pretrained(asr_model) |
| | | else: |
| | | raise ValueError( |
| | | f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}' |
| | | f"Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}" |
| | | ) |
| | | return asr_model |
| | | |
| | | |
def parse_args():
    """Build and parse the command-line arguments for the normalization script.

    Returns:
        argparse.Namespace with the parsed options.

    NOTE(review): reconstructed from a source corrupted by an interleaved
    diff — every add_argument call appeared twice (once per formatting style)
    and two original lines were dropped. The dropped spots are marked with
    TODOs below; confirm them against the upstream script before relying on
    their defaults.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--text", help="input string or path to a .txt file", default=None, type=str
    )
    parser.add_argument(
        "--input_case",
        help="input capitalization",
        choices=["lower_cased", "cased"],
        default="cased",
        type=str,
    )
    parser.add_argument(
        "--language",
        help="Select target language",
        choices=["en", "ru", "de", "es"],
        default="en",
        type=str,
    )
    parser.add_argument(
        "--audio_data", default=None, help="path to an audio file or .json manifest"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="QuartzNet15x5Base-En",
        help="Pre-trained model name or path to model checkpoint",
    )
    # NOTE(review): the original --n_tagged lines were partially lost. It is
    # used as an integer count elsewhere (`normalize(..., n_tagged: int, ...)`),
    # so it cannot be a store_true flag; default assumed — TODO confirm upstream.
    parser.add_argument(
        "--n_tagged",
        type=int,
        default=30,
        help="number of tagged normalization options to return",
    )
    # NOTE(review): an orphaned `action="store_true"` survived next to
    # --n_tagged; it most plausibly belonged to a dropped --verbose flag
    # (normalize() and _normalize_line() both take `verbose`) — TODO confirm.
    parser.add_argument(
        "--verbose", help="print info for debugging", action="store_true"
    )
    parser.add_argument(
        "--no_punct_post_process",
        help="set to True to disable punctuation post processing",
        action="store_true",
    )
    parser.add_argument(
        "--overwrite_cache",
        help="set to True to re-create .far grammar files",
        action="store_true",
    )
    parser.add_argument(
        "--whitelist", help="path to a file with a whitelist", default=None, type=str
    )
    parser.add_argument(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )
    parser.add_argument(
        "--n_jobs",
        default=-2,
        type=int,
        help="The maximum number of concurrently running jobs",
    )
    parser.add_argument(
        "--lm",
        action="store_true",
        help="Set to True for WFST+LM. Only available for English right now.",
    )
    # NOTE(review): the original default for --cer_threshold was dropped;
    # 100 (i.e. effectively disabled) assumed — TODO confirm upstream.
    parser.add_argument(
        "--cer_threshold",
        default=100,
        type=int,
        help="if CER for pred_text is above the cer_threshold, no normalization will be performed",
    )
    parser.add_argument(
        "--batch_size", default=200, type=int, help="Number of examples for each process"
    )
    return parser.parse_args()
| | | |
| | | |
# NOTE(review): corrupted region — the signature and the normalizer.normalize
# call each appear twice (single-line and reformatted forms), and the function
# body is cut off after `set(normalized_texts)`: the CER computation and the
# return statement that must have followed were dropped from this view.
# Comments only — code bytes unchanged.
| | | def _normalize_line( |
| | | normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold |
| | | normalizer: NormalizerWithAudio, |
| | | n_tagged, |
| | | verbose, |
| | | line: str, |
| | | remove_punct, |
| | | punct_post_process, |
| | | cer_threshold, |
| | | ): |
| | | line = json.loads(line) |
| | | pred_text = line["pred_text"] |
| | | |
| | | normalized_texts = normalizer.normalize( |
| | | text=line["text"], verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process, |
| | | text=line["text"], |
| | | verbose=verbose, |
| | | n_tagged=n_tagged, |
| | | punct_post_process=punct_post_process, |
| | | ) |
| | | |
| | | normalized_texts = set(normalized_texts) |
| | |
| | | |
# NOTE(review): corrupted region — interior fragments of one or more
# manifest-processing functions whose `def` lines were dropped: a per-batch
# JSON writer, then manifest path setup and line counting (each statement
# duplicated in single-quote and double-quote forms), then a tail that merges
# intermediate files (`f_in` / `f_out` are opened by dropped lines). The
# enclosing definitions cannot be recovered from this view. Comments only —
# code bytes unchanged.
| | | with open(f"{dir_name}/{batch_idx:05}.json", "w") as f_out: |
| | | for line in normalized_lines: |
| | | f_out.write(json.dumps(line, ensure_ascii=False) + '\n') |
| | | f_out.write(json.dumps(line, ensure_ascii=False) + "\n") |
| | | |
| | | print(f"Batch -- {batch_idx} -- is complete") |
| | | |
| | | manifest_out = audio_data.replace('.json', '_normalized.json') |
| | | with open(audio_data, 'r') as f: |
| | | manifest_out = audio_data.replace(".json", "_normalized.json") |
| | | with open(audio_data, "r") as f: |
| | | lines = f.readlines() |
| | | |
| | | print(f'Normalizing {len(lines)} lines of {audio_data}...') |
| | | print(f"Normalizing {len(lines)} lines of {audio_data}...") |
| | | |
| | | # to save intermediate results to a file |
| | | batch = min(len(lines), batch_size) |
| | |
| | | lines = f_in.read() |
| | | f_out.write(lines) |
| | | |
| | | print(f'Normalized version saved at {manifest_out}') |
| | | print(f"Normalized version saved at {manifest_out}") |
| | | |
# NOTE(review): corrupted region — the script entry point with most of its
# body dropped: an orphaned `)` with no matching open call, an `if/elif`
# chain whose branch bodies are partially missing, a NormalizerWithAudio
# construction cut off mid-arguments, and a stranded error-message
# continuation (`+ "'--audio_data' ..."`). `args` and `start` are assigned by
# dropped lines. The original flow cannot be recovered from this view.
# Comments only — code bytes unchanged.
| | | if __name__ == "__main__": |
| | |
| | | ) |
| | | |
| | | if os.path.exists(args.text): |
| | | with open(args.text, 'r') as f: |
| | | with open(args.text, "r") as f: |
| | | args.text = f.read().strip() |
| | | normalized_texts = normalizer.normalize( |
| | | text=args.text, |
| | |
| | | print(norm_text) |
| | | elif not os.path.exists(args.audio_data): |
| | | raise ValueError(f"{args.audio_data} not found.") |
| | | elif args.audio_data.endswith('.json'): |
| | | elif args.audio_data.endswith(".json"): |
| | | normalizer = NormalizerWithAudio( |
| | | input_case=args.input_case, |
| | | lang=args.language, |
| | |
| | | + "'--audio_data' path to audio file and '--text' path to a text file OR" |
| | | "'--text' string text (for debugging without audio)" |
| | | ) |
| | | print(f'Execution time: {round((time.time() - start)/60, 2)} min.') |
| | | print(f"Execution time: {round((time.time() - start)/60, 2)} min.") |