游雁
2024-10-16 6e6475cd2afebd5db41beef633645f154bb4cf05
fun_text_processing/text_normalization/normalize_with_audio.py
@@ -1,4 +1,3 @@
import json
import os
import time
@@ -72,7 +71,7 @@
    def __init__(
        self,
        input_case: str,
        lang: str = 'en',
        lang: str = "en",
        cache_dir: str = None,
        overwrite_cache: bool = False,
        whitelist: str = None,
@@ -92,7 +91,13 @@
        )
        self.lm = lm
    def normalize(self, text: str, n_tagged: int, punct_post_process: bool = True, verbose: bool = False,) -> str:
    def normalize(
        self,
        text: str,
        n_tagged: int,
        punct_post_process: bool = True,
        verbose: bool = False,
    ) -> str:
        """
        Main function. Normalizes tokens from written to spoken form
            e.g. 12 kg -> twelve kilograms
@@ -161,7 +166,8 @@
            if self.processor:
                normalized_texts = [self.processor.detokenize([t]) for t in normalized_texts]
                normalized_texts = [
                    post_process_punct(input=original_text, normalized_text=t) for t in normalized_texts
                    post_process_punct(input=original_text, normalized_text=t)
                    for t in normalized_texts
                ]
        if self.lm:
@@ -199,9 +205,13 @@
                else:
                    try:
                        # try self.tagger graph that produces output without digits
                        tagged_texts = rewrite.top_rewrites(text, self.tagger.fst_no_digits, nshortest=n_tagged)
                        tagged_texts = rewrite.top_rewrites(
                            text, self.tagger.fst_no_digits, nshortest=n_tagged
                        )
                    except pynini.lib.rewrite.Error:
                        tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
                        tagged_texts = rewrite.top_rewrites(
                            text, self.tagger.fst, nshortest=n_tagged
                        )
            else:
                tagged_texts = rewrite.top_rewrites(text, self.tagger.fst, nshortest=n_tagged)
        return tagged_texts
@@ -266,14 +276,16 @@
            return input_text, cer
        if verbose:
            print('-' * 30)
            print("-" * 30)
            for option in normalized_texts:
                print(option)
            print('-' * 30)
            print("-" * 30)
        return normalized_text, cer
def calculate_cer(normalized_texts: List[str], pred_text: str, remove_punct=False) -> List[Tuple[str, float]]:
def calculate_cer(
    normalized_texts: List[str], pred_text: str, remove_punct=False
) -> List[Tuple[str, float]]:
    """
    Calculates character error rate (CER)
@@ -285,7 +297,7 @@
    """
    normalized_options = []
    for text in normalized_texts:
        text_clean = text.replace('-', ' ').lower()
        text_clean = text.replace("-", " ").lower()
        if remove_punct:
            for punct in "!?:;,.-()*+-/<=>@^_":
                text_clean = text_clean.replace(punct, "")
@@ -307,23 +319,38 @@
        asr_model = ASRModel.from_pretrained(asr_model)
    else:
        raise ValueError(
            f'Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}'
            f"Provide path to the pretrained checkpoint or choose from {ASRModel.get_available_model_names()}"
        )
    return asr_model
def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--text", help="input string or path to a .txt file", default=None, type=str)
    parser.add_argument(
        "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
        "--text", help="input string or path to a .txt file", default=None, type=str
    )
    parser.add_argument(
        "--language", help="Select target language", choices=["en", "ru", "de", "es"], default="en", type=str
        "--input_case",
        help="input capitalization",
        choices=["lower_cased", "cased"],
        default="cased",
        type=str,
    )
    parser.add_argument("--audio_data", default=None, help="path to an audio file or .json manifest")
    parser.add_argument(
        '--model', type=str, default='QuartzNet15x5Base-En', help='Pre-trained model name or path to model checkpoint'
        "--language",
        help="Select target language",
        choices=["en", "ru", "de", "es"],
        default="en",
        type=str,
    )
    parser.add_argument(
        "--audio_data", default=None, help="path to an audio file or .json manifest"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="QuartzNet15x5Base-En",
        help="Pre-trained model name or path to model checkpoint",
    )
    parser.add_argument(
        "--n_tagged",
@@ -338,19 +365,29 @@
        action="store_true",
    )
    parser.add_argument(
        "--no_punct_post_process", help="set to True to disable punctuation post processing", action="store_true"
        "--no_punct_post_process",
        help="set to True to disable punctuation post processing",
        action="store_true",
    )
    parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
    parser.add_argument("--whitelist", help="path to a file with with whitelist", default=None, type=str)
    parser.add_argument(
        "--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true"
    )
    parser.add_argument(
        "--whitelist", help="path to a file with with whitelist", default=None, type=str
    )
    parser.add_argument(
        "--cache_dir",
        help="path to a dir with .far grammar file. Set to None to avoid using cache",
        default=None,
        type=str,
    )
    parser.add_argument("--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs")
    parser.add_argument(
        "--lm", action="store_true", help="Set to True for WFST+LM. Only available for English right now."
        "--n_jobs", default=-2, type=int, help="The maximum number of concurrently running jobs"
    )
    parser.add_argument(
        "--lm",
        action="store_true",
        help="Set to True for WFST+LM. Only available for English right now.",
    )
    parser.add_argument(
        "--cer_threshold",
@@ -358,18 +395,29 @@
        type=int,
        help="if CER for pred_text is above the cer_threshold, no normalization will be performed",
    )
    parser.add_argument("--batch_size", default=200, type=int, help="Number of examples for each process")
    parser.add_argument(
        "--batch_size", default=200, type=int, help="Number of examples for each process"
    )
    return parser.parse_args()
def _normalize_line(
    normalizer: NormalizerWithAudio, n_tagged, verbose, line: str, remove_punct, punct_post_process, cer_threshold
    normalizer: NormalizerWithAudio,
    n_tagged,
    verbose,
    line: str,
    remove_punct,
    punct_post_process,
    cer_threshold,
):
    line = json.loads(line)
    pred_text = line["pred_text"]
    normalized_texts = normalizer.normalize(
        text=line["text"], verbose=verbose, n_tagged=n_tagged, punct_post_process=punct_post_process,
        text=line["text"],
        verbose=verbose,
        n_tagged=n_tagged,
        punct_post_process=punct_post_process,
    )
    normalized_texts = set(normalized_texts)
@@ -424,15 +472,15 @@
        with open(f"{dir_name}/{batch_idx:05}.json", "w") as f_out:
            for line in normalized_lines:
                f_out.write(json.dumps(line, ensure_ascii=False) + '\n')
                f_out.write(json.dumps(line, ensure_ascii=False) + "\n")
        print(f"Batch -- {batch_idx} -- is complete")
    manifest_out = audio_data.replace('.json', '_normalized.json')
    with open(audio_data, 'r') as f:
    manifest_out = audio_data.replace(".json", "_normalized.json")
    with open(audio_data, "r") as f:
        lines = f.readlines()
    print(f'Normalizing {len(lines)} lines of {audio_data}...')
    print(f"Normalizing {len(lines)} lines of {audio_data}...")
    # to save intermediate results to a file
    batch = min(len(lines), batch_size)
@@ -452,7 +500,7 @@
                lines = f_in.read()
            f_out.write(lines)
    print(f'Normalized version saved at {manifest_out}')
    print(f"Normalized version saved at {manifest_out}")
if __name__ == "__main__":
@@ -473,7 +521,7 @@
        )
        if os.path.exists(args.text):
            with open(args.text, 'r') as f:
            with open(args.text, "r") as f:
                args.text = f.read().strip()
        normalized_texts = normalizer.normalize(
            text=args.text,
@@ -503,7 +551,7 @@
                print(norm_text)
    elif not os.path.exists(args.audio_data):
        raise ValueError(f"{args.audio_data} not found.")
    elif args.audio_data.endswith('.json'):
    elif args.audio_data.endswith(".json"):
        normalizer = NormalizerWithAudio(
            input_case=args.input_case,
            lang=args.language,
@@ -527,4 +575,4 @@
            + "'--audio_data' path to audio file and '--text' path to a text file OR"
            "'--text' string text (for debugging without audio)"
        )
    print(f'Execution time: {round((time.time() - start)/60, 2)} min.')
    print(f"Execution time: {round((time.time() - start)/60, 2)} min.")