From 6e6475cd2afebd5db41beef633645f154bb4cf05 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 16 Oct 2024 14:35:56 +0800
Subject: [PATCH] funasr tables
---
fun_text_processing/text_normalization/normalize_with_audio.py | 112 ++++++++++++++++++++++++++++++++++++++++----------------
1 file changed, 80 insertions(+), 32 deletions(-)
diff --git a/fun_text_processing/text_normalization/normalize_with_audio.py b/fun_text_processing/text_normalization/normalize_with_audio.py
index 61e32f2..973a1d2 100644
--- a/fun_text_processing/text_normalization/normalize_with_audio.py
+++ b/fun_text_processing/text_normalization/normalize_with_audio.py
@@ -1,4 +1,3 @@
-
import json
import os
import time
@@ -72,7 +71,7 @@
def __init__(
self,
input_case: str,
- lang: str = 'en',
+ lang: str = "en",
cache_dir: str = None,
overwrite_cache: bool = False,
whitelist: str = None,
@@ -92,7 +91,13 @@
)
self.lm = lm
- def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False,) -> str:
+ def normalize(
+ self,
+ text: str,
+ n_tagged: int,
+ punct_post_process: bool = True,
+ verbose: bool = False,
+ ) -> str:
"""
Main function. Normalizes tokens from written to spoken form
e.g. 12 kg -> twelve kilograms
@@ -161,7 +166,8 @@
if self.processor:
normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts]
normalized_texts = [
- post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts
+ post_process_punct(input=original_text, normalized_text=t)
+ for t in normalized_texts
]
if self.lm:
@@ -199,9 +205,13 @@
else:
try:
# try self.tagger graph that produces output without digits
- tagged_texts = rewrite.top_rewrites(text, self.tagger.fst_no_digits, nshortest=n_tagged)
+ tagged_texts = rewrite.top_rewrites(
+ text, self.tagger.fst_no_digits, nshortest=n_tagged
+ )
except pynini.lib.rewrite.Error:
- tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
+ tagged_texts = rewrite.top_rewrites(
+ text, self.tagger.fst, nshortest=n_tagged
+ )
else:
tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
return tagged_texts
@@ -266,14 +276,16 @@
return input_text, cer
if verbose:
- print('-' * 30)
+ print("-" * 30)
for option in normalized_texts:
print(option)
- print('-' * 30)
+ print("-" * 30)
return normalized_text, cer
-def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]:
+def calculate_cer(
+ normalized_texts: List[str], pred_text: str, remove_punct=False
+) -> List[Tuple[str, float]]:
"""
Calculates character error rate (CER)
@@ -285,7 +297,7 @@
"""
normalized_options = []
for text in normalized_texts:
- text_clean = text.replace('-', ' ').lower()
+ text_clean = text.replace("-", " ").lower()
if remove_punct:
for punct in "!?:;,.-()*+-/<=>@^_":
text_clean = text_clean.replace(punct, "")
@@ -307,23 +319,38 @@
asr_model = ASRModel.from_pretrained(asr_model)
else:
raise ValueError(
- f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}'
+ f"Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}"
)
return asr_model
def parse_args():
parser = ArgumentParser()
- parser.add_argument("--text", help="input string or path to a .txt file", default=None, type=str)
parser.add_argument(
- "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
+ "--text", help="input string or path to a .txt file", default=None, type=str
)
parser.add_argument(
- "--language", help="Select target language", choices=["en", "ru", "de", "es"], default="en", type=str
+ "--input_case",
+ help="input capitalization",
+ choices=["lower_cased", "cased"],
+ default="cased",
+ type=str,
)
- parser.add_argument("--audio_data", default=None, help="path to an audio file or .json manifest")
parser.add_argument(
- '--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint'
+ "--language",
+ help="Select target language",
+ choices=["en", "ru", "de", "es"],
+ default="en",
+ type=str,
+ )
+ parser.add_argument(
+ "--audio_data", default=None, help="path to an audio file or .json manifest"
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="QuartzNet15x5Base-En",
+ help="Pre-trained model name or path to model checkpoint",
)
parser.add_argument(
"--n_tagged",
@@ -338,19 +365,29 @@
action="store_true",
)
parser.add_argument(
- "--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true"
+ "--no_punct_post_process",
+ help="set to True to disable punctuation post processing",
+ action="store_true",
)
- parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
- parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
+ parser.add_argument(
+ "--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true"
+ )
+ parser.add_argument(
+ "--whitelist", help="path to a file with with whitelist", default=None, type=str
+ )
parser.add_argument(
"--cache_dir",
help="path to a dir with .far grammar file. Set to None to avoid using cache",
default=None,
type=str,
)
- parser.add_argument("--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs")
parser.add_argument(
- "--lm", action="store_true", help="Set to True for WFST+LM. Only available for English right now."
+ "--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs"
+ )
+ parser.add_argument(
+ "--lm",
+ action="store_true",
+ help="Set to True for WFST+LM. Only available for English right now.",
)
parser.add_argument(
"--cer_threshold",
@@ -358,18 +395,29 @@
type=int,
help="if CER for pred_text is above the cer_threshold, no normalization will be performed",
)
- parser.add_argument("--batch_size", default=200, type=int, help="Number of examples for each process")
+ parser.add_argument(
+ "--batch_size", default=200, type=int, help="Number of examples for each process"
+ )
return parser.parse_args()
def _normalize_line(
- normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold
+ normalizer: NormalizerWithAudio,
+ n_tagged,
+ verbose,
+ line: str,
+ remove_punct,
+ punct_post_process,
+ cer_threshold,
):
line = json.loads(line)
pred_text = line["pred_text"]
normalized_texts = normalizer.normalize(
- text=line["text"], verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process,
+ text=line["text"],
+ verbose=verbose,
+ n_tagged=n_tagged,
+ punct_post_process=punct_post_process,
)
normalized_texts = set(normalized_texts)
@@ -424,15 +472,15 @@
with open(f"{dir_name}/{batch_idx:05}.json", "w") as f_out:
for line in normalized_lines:
- f_out.write(json.dumps(line, ensure_ascii=False) + '\n')
+ f_out.write(json.dumps(line, ensure_ascii=False) + "\n")
print(f"Batch -- {batch_idx} -- is complete")
- manifest_out = audio_data.replace('.json', '_normalized.json')
- with open(audio_data, 'r') as f:
+ manifest_out = audio_data.replace(".json", "_normalized.json")
+ with open(audio_data, "r") as f:
lines = f.readlines()
- print(f'Normalizing {len(lines)} lines of {audio_data}...')
+ print(f"Normalizing {len(lines)} lines of {audio_data}...")
# to save intermediate results to a file
batch = min(len(lines), batch_size)
@@ -452,7 +500,7 @@
lines = f_in.read()
f_out.write(lines)
- print(f'Normalized version saved at {manifest_out}')
+ print(f"Normalized version saved at {manifest_out}")
if __name__ == "__main__":
@@ -473,7 +521,7 @@
)
if os.path.exists(args.text):
- with open(args.text, 'r') as f:
+ with open(args.text, "r") as f:
args.text = f.read().strip()
normalized_texts = normalizer.normalize(
text=args.text,
@@ -503,7 +551,7 @@
print(norm_text)
elif not os.path.exists(args.audio_data):
raise ValueError(f"{args.audio_data} not found.")
- elif args.audio_data.endswith('.json'):
+ elif args.audio_data.endswith(".json"):
normalizer = NormalizerWithAudio(
input_case=args.input_case,
lang=args.language,
@@ -527,4 +575,4 @@
+ "'--audio_data' path to audio file and '--text' path to a text file OR"
"'--text' string text (for debugging without audio)"
)
- print(f'Execution time: {round((time.time() - start)/60, 2)} min.')
+ print(f"Execution time: {round((time.time() - start)/60, 2)} min.")
--
Gitblit v1.9.1