From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat normalize_with_audio.py to black style (double quotes, wrapped call/argument lists); no functional change

---
 fun_text_processing/text_normalization/normalize_with_audio.py |  112 ++++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 80 insertions(+), 32 deletions(-)

diff --git a/fun_text_processing/text_normalization/normalize_with_audio.py b/fun_text_processing/text_normalization/normalize_with_audio.py
index 61e32f2..973a1d2 100644
--- a/fun_text_processing/text_normalization/normalize_with_audio.py
+++ b/fun_text_processing/text_normalization/normalize_with_audio.py
@@ -1,4 +1,3 @@
-
 import json
 import os
 import time
@@ -72,7 +71,7 @@
     def __init__(
         self,
         input_case: str,
-        lang: str = 'en',
+        lang: str = "en",
         cache_dir: str = None,
         overwrite_cache: bool = False,
         whitelist: str = None,
@@ -92,7 +91,13 @@
         )
         self.lm = lm
 
-    def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False,) -> str:
+    def normalize(
+        self,
+        text: str,
+        n_tagged: int,
+        punct_post_process: bool = True,
+        verbose: bool = False,
+    ) -> str:
         """
         Main function. Normalizes tokens from written to spoken form
             e.g. 12 kg -> twelve kilograms
@@ -161,7 +166,8 @@
             if self.processor:
                 normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts]
                 normalized_texts = [
-                    post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts
+                    post_process_punct(input=original_text, normalized_text=t)
+                    for t in normalized_texts
                 ]
 
         if self.lm:
@@ -199,9 +205,13 @@
                 else:
                     try:
                         # try self.tagger graph that produces output without digits
-                        tagged_texts = rewrite.top_rewrites(text, self.tagger.fst_no_digits, nshortest=n_tagged)
+                        tagged_texts = rewrite.top_rewrites(
+                            text, self.tagger.fst_no_digits, nshortest=n_tagged
+                        )
                     except pynini.lib.rewrite.Error:
-                        tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
+                        tagged_texts = rewrite.top_rewrites(
+                            text, self.tagger.fst, nshortest=n_tagged
+                        )
             else:
                 tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
         return tagged_texts
@@ -266,14 +276,16 @@
             return input_text, cer
 
         if verbose:
-            print('-' * 30)
+            print("-" * 30)
             for option in normalized_texts:
                 print(option)
-            print('-' * 30)
+            print("-" * 30)
         return normalized_text, cer
 
 
-def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]:
+def calculate_cer(
+    normalized_texts: List[str], pred_text: str, remove_punct=False
+) -> List[Tuple[str, float]]:
     """
     Calculates character error rate (CER)
 
@@ -285,7 +297,7 @@
     """
     normalized_options = []
     for text in normalized_texts:
-        text_clean = text.replace('-', ' ').lower()
+        text_clean = text.replace("-", " ").lower()
         if remove_punct:
             for punct in "!?:;,.-()*+-/<=>@^_":
                 text_clean = text_clean.replace(punct, "")
@@ -307,23 +319,38 @@
         asr_model = ASRModel.from_pretrained(asr_model)
     else:
         raise ValueError(
-            f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}'
+            f"Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}"
         )
     return asr_model
 
 
 def parse_args():
     parser = ArgumentParser()
-    parser.add_argument("--text", help="input string or path to a .txt file", default=None, type=str)
     parser.add_argument(
-        "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
+        "--text", help="input string or path to a .txt file", default=None, type=str
     )
     parser.add_argument(
-        "--language", help="Select target language", choices=["en", "ru", "de", "es"], default="en", type=str
+        "--input_case",
+        help="input capitalization",
+        choices=["lower_cased", "cased"],
+        default="cased",
+        type=str,
     )
-    parser.add_argument("--audio_data", default=None, help="path to an audio file or .json manifest")
     parser.add_argument(
-        '--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint'
+        "--language",
+        help="Select target language",
+        choices=["en", "ru", "de", "es"],
+        default="en",
+        type=str,
+    )
+    parser.add_argument(
+        "--audio_data", default=None, help="path to an audio file or .json manifest"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="QuartzNet15x5Base-En",
+        help="Pre-trained model name or path to model checkpoint",
     )
     parser.add_argument(
         "--n_tagged",
@@ -338,19 +365,29 @@
         action="store_true",
     )
     parser.add_argument(
-        "--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true"
+        "--no_punct_post_process",
+        help="set to True to disable punctuation post processing",
+        action="store_true",
     )
-    parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
-    parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
+    parser.add_argument(
+        "--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true"
+    )
+    parser.add_argument(
+        "--whitelist", help="path to a file with with whitelist", default=None, type=str
+    )
     parser.add_argument(
         "--cache_dir",
         help="path to a dir with .far grammar file. Set to None to avoid using cache",
         default=None,
         type=str,
     )
-    parser.add_argument("--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs")
     parser.add_argument(
-        "--lm", action="store_true", help="Set to True for WFST+LM. Only available for English right now."
+        "--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs"
+    )
+    parser.add_argument(
+        "--lm",
+        action="store_true",
+        help="Set to True for WFST+LM. Only available for English right now.",
     )
     parser.add_argument(
         "--cer_threshold",
@@ -358,18 +395,29 @@
         type=int,
         help="if CER for pred_text is above the cer_threshold, no normalization will be performed",
     )
-    parser.add_argument("--batch_size", default=200, type=int, help="Number of examples for each process")
+    parser.add_argument(
+        "--batch_size", default=200, type=int, help="Number of examples for each process"
+    )
     return parser.parse_args()
 
 
 def _normalize_line(
-    normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold
+    normalizer: NormalizerWithAudio,
+    n_tagged,
+    verbose,
+    line: str,
+    remove_punct,
+    punct_post_process,
+    cer_threshold,
 ):
     line = json.loads(line)
     pred_text = line["pred_text"]
 
     normalized_texts = normalizer.normalize(
-        text=line["text"], verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process,
+        text=line["text"],
+        verbose=verbose,
+        n_tagged=n_tagged,
+        punct_post_process=punct_post_process,
     )
 
     normalized_texts = set(normalized_texts)
@@ -424,15 +472,15 @@
 
         with open(f"{dir_name}/{batch_idx:05}.json", "w") as f_out:
             for line in normalized_lines:
-                f_out.write(json.dumps(line, ensure_ascii=False) + '\n')
+                f_out.write(json.dumps(line, ensure_ascii=False) + "\n")
 
         print(f"Batch -- {batch_idx} -- is complete")
 
-    manifest_out = audio_data.replace('.json', '_normalized.json')
-    with open(audio_data, 'r') as f:
+    manifest_out = audio_data.replace(".json", "_normalized.json")
+    with open(audio_data, "r") as f:
         lines = f.readlines()
 
-    print(f'Normalizing {len(lines)} lines of {audio_data}...')
+    print(f"Normalizing {len(lines)} lines of {audio_data}...")
 
     # to save intermediate results to a file
     batch = min(len(lines), batch_size)
@@ -452,7 +500,7 @@
                 lines = f_in.read()
             f_out.write(lines)
 
-    print(f'Normalized version saved at {manifest_out}')
+    print(f"Normalized version saved at {manifest_out}")
 
 
 if __name__ == "__main__":
@@ -473,7 +521,7 @@
         )
 
         if os.path.exists(args.text):
-            with open(args.text, 'r') as f:
+            with open(args.text, "r") as f:
                 args.text = f.read().strip()
         normalized_texts = normalizer.normalize(
             text=args.text,
@@ -503,7 +551,7 @@
                 print(norm_text)
     elif not os.path.exists(args.audio_data):
         raise ValueError(f"{args.audio_data} not found.")
-    elif args.audio_data.endswith('.json'):
+    elif args.audio_data.endswith(".json"):
         normalizer = NormalizerWithAudio(
             input_case=args.input_case,
             lang=args.language,
@@ -527,4 +575,4 @@
             + "'--audio_data' path to audio file and '--text' path to a text file OR"
             "'--text' string text (for debugging without audio)"
         )
-    print(f'Execution time: {round((time.time() - start)/60, 2)} min.')
+    print(f"Execution time: {round((time.time() - start)/60, 2)} min.")

--
Gitblit v1.9.1