zhifu gao
2024-05-28 ba3a3bf4e67e861b833092d05d7c3842ea670cbc
fun_text_processing/text_normalization/data_loader_utils.py
@@ -1,6 +1,3 @@
import json
import re
import string
@@ -14,7 +11,7 @@
EOS_TYPE = "EOS"
PUNCT_TYPE = "PUNCT"
PLAIN_TYPE = "PLAIN"
Instance = namedtuple('Instance', 'token_type un_normalized normalized')
Instance = namedtuple("Instance", "token_type un_normalized normalized")
known_types = [
    "PLAIN",
    "DATE",
@@ -38,7 +35,7 @@
    """
    https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
    Loads text file in the Kaggle Google text normalization file format: <semiotic class>\t<unnormalized text>\t<`self` if trivial class or normalized text>
    E.g.
    E.g.
    PLAIN   Brillantaisia   <self>
    PLAIN   is      <self>
    PLAIN   a       <self>
@@ -54,10 +51,10 @@
    Args:
        file_path: file path to text file
    Returns: flat list of instances
    Returns: flat list of instances
    """
    res = []
    with open(file_path, 'r') as fp:
    with open(file_path, "r") as fp:
        for line in fp:
            parts = line.strip().split("\t")
            if parts[0] == "<eos>":
@@ -68,9 +65,13 @@
                l_normalized = l_normalized.lower()
                if l_type == PLAIN_TYPE:
                    res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_token))
                    res.append(
                        Instance(token_type=l_type, un_normalized=l_token, normalized=l_token)
                    )
                elif l_type != PUNCT_TYPE:
                    res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_normalized))
                    res.append(
                        Instance(token_type=l_type, un_normalized=l_token, normalized=l_normalized)
                    )
    return res
@@ -78,7 +79,7 @@
    """
    Load given list of text files using the `load_func` function.
    Args:
    Args:
        file_paths: list of file paths
        load_func: loading function
@@ -104,9 +105,11 @@
    return text
def evaluate(preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True) -> float:
def evaluate(
    preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True
) -> float:
    """
    Evaluates accuracy given predictions and labels.
    Evaluates accuracy given predictions and labels.
    Args:
        preds: predictions
@@ -174,8 +177,12 @@
        else:
            sentence.append(instance)
            sentence_categories.update([instance.token_type])
    un_normalized = [" ".join([instance.un_normalized for instance in sentence]) for sentence in sentences]
    normalized = [" ".join([instance.normalized for instance in sentence]) for sentence in sentences]
    un_normalized = [
        " ".join([instance.un_normalized for instance in sentence]) for sentence in sentences
    ]
    normalized = [
        " ".join([instance.normalized for instance in sentence]) for sentence in sentences
    ]
    return un_normalized, normalized, categories
@@ -189,14 +196,14 @@
    Returns: text with normalized spaces and quotes
    """
    text = (
        text.replace('( ', '(')
        .replace(' )', ')')
        .replace('{ ', '{')
        .replace(' }', '}')
        .replace('[ ', '[')
        .replace(' ]', ']')
        .replace('  ', ' ')
        .replace('”', '"')
        text.replace("( ", "(")
        .replace(" )", ")")
        .replace("{ ", "{")
        .replace(" }", "}")
        .replace("[ ", "[")
        .replace(" ]", "]")
        .replace("  ", " ")
        .replace("”", '"')
        .replace("’", "'")
        .replace("»", '"')
        .replace("«", '"')
@@ -204,14 +211,14 @@
        .replace("„", '"')
        .replace("´", "'")
        .replace("’", "'")
        .replace('“', '"')
        .replace("“", '"')
        .replace("‘", "'")
        .replace('`', "'")
        .replace('- -', "--")
        .replace("`", "'")
        .replace("- -", "--")
    )
    for punct in "!,.:;?":
        text = text.replace(f' {punct}', punct)
        text = text.replace(f" {punct}", punct)
    return text.strip()
@@ -224,12 +231,12 @@
    Returns: text with spaces around punctuation marks
    """
    space_both = '[]'
    space_both = "[]"
    for punct in space_both:
        text = text.replace(punct, ' ' + punct + ' ')
        text = text.replace(punct, " " + punct + " ")
    # remove extra space
    text = re.sub(r' +', ' ', text)
    text = re.sub(r" +", " ", text)
    return text
@@ -237,13 +244,13 @@
    """
    Loads given text file with separate lines into list of string.
    Args:
    Args:
        file_path: file path
    Returns: flat list of string
    """
    res = []
    with open(file_path, 'r') as fp:
    with open(file_path, "r") as fp:
        for line in fp:
            res.append(line)
    return res
@@ -256,11 +263,11 @@
    Args:
        file_path: file path
        data: list of string
    """
    with open(file_path, 'w') as fp:
    with open(file_path, "w") as fp:
        for line in data:
            fp.write(line + '\n')
            fp.write(line + "\n")
def post_process_punct(input: str, normalized_text: str, add_unicode_punct: bool = False):
@@ -309,7 +316,11 @@
                def _is_valid(idx_out, idx_in, normalized_text, input):
                    """Check if previous or next word match (for cases when punctuation marks are part of
                    semiotic token, i.e. some punctuation can be missing in the normalized text)"""
                    return (idx_out > 0 and idx_in > 0 and normalized_text[idx_out - 1] == input[idx_in - 1]) or (
                    return (
                        idx_out > 0
                        and idx_in > 0
                        and normalized_text[idx_out - 1] == input[idx_in - 1]
                    ) or (
                        idx_out < len(normalized_text) - 1
                        and idx_in < len(input) - 1
                        and normalized_text[idx_out + 1] == input[idx_in + 1]
@@ -336,4 +347,4 @@
            logging.debug(f"Skipping post-processing of {''.join(normalized_text)} for '{punct}'")
    normalized_text = "".join(normalized_text)
    return re.sub(r' +', ' ', normalized_text)
    return re.sub(r" +", " ", normalized_text)