From e6fe602db3eb1209543e55f1aafa2932dfda3310 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 10 Jan 2025 10:14:30 +0800
Subject: [PATCH] reformat data_loader_utils.py (black-style quoting and line wrapping)

---
 fun_text_processing/text_normalization/data_loader_utils.py |   83 +++++++++++++++++++++++------------------
 1 file changed, 47 insertions(+), 36 deletions(-)

diff --git a/fun_text_processing/text_normalization/data_loader_utils.py b/fun_text_processing/text_normalization/data_loader_utils.py
index 0b9de52..9b14540 100644
--- a/fun_text_processing/text_normalization/data_loader_utils.py
+++ b/fun_text_processing/text_normalization/data_loader_utils.py
@@ -1,6 +1,3 @@
-
-
-
 import json
 import re
 import string
@@ -14,7 +11,7 @@
 EOS_TYPE = "EOS"
 PUNCT_TYPE = "PUNCT"
 PLAIN_TYPE = "PLAIN"
-Instance = namedtuple('Instance', 'token_type un_normalized normalized')
+Instance = namedtuple("Instance", "token_type un_normalized normalized")
 known_types = [
     "PLAIN",
     "DATE",
@@ -38,7 +35,7 @@
     """
     https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
     Loads text file in the Kaggle Google text normalization file format: <semiotic class>\t<unnormalized text>\t<`self` if trivial class or normalized text>
-    E.g. 
+    E.g.
     PLAIN   Brillantaisia   <self>
     PLAIN   is      <self>
     PLAIN   a       <self>
@@ -54,10 +51,10 @@
     Args:
         file_path: file path to text file
 
-    Returns: flat list of instances 
+    Returns: flat list of instances
     """
     res = []
-    with open(file_path, 'r') as fp:
+    with open(file_path, "r") as fp:
         for line in fp:
             parts = line.strip().split("\t")
             if parts[0] == "<eos>":
@@ -68,9 +65,13 @@
                 l_normalized = l_normalized.lower()
 
                 if l_type == PLAIN_TYPE:
-                    res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_token))
+                    res.append(
+                        Instance(token_type=l_type, un_normalized=l_token, normalized=l_token)
+                    )
                 elif l_type != PUNCT_TYPE:
-                    res.append(Instance(token_type=l_type, un_normalized=l_token, normalized=l_normalized))
+                    res.append(
+                        Instance(token_type=l_type, un_normalized=l_token, normalized=l_normalized)
+                    )
     return res
 
 
@@ -78,7 +79,7 @@
     """
     Load given list of text files using the `load_func` function.
 
-    Args: 
+    Args:
         file_paths: list of file paths
         load_func: loading function
 
@@ -104,9 +105,11 @@
     return text
 
 
-def evaluate(preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True) -> float:
+def evaluate(
+    preds: List[str], labels: List[str], input: Optional[List[str]] = None, verbose: bool = True
+) -> float:
     """
-    Evaluates accuracy given predictions and labels. 
+    Evaluates accuracy given predictions and labels.
 
     Args:
         preds: predictions
@@ -174,8 +177,12 @@
         else:
             sentence.append(instance)
             sentence_categories.update([instance.token_type])
-    un_normalized = [" ".join([instance.un_normalized for instance in sentence]) for sentence in sentences]
-    normalized = [" ".join([instance.normalized for instance in sentence]) for sentence in sentences]
+    un_normalized = [
+        " ".join([instance.un_normalized for instance in sentence]) for sentence in sentences
+    ]
+    normalized = [
+        " ".join([instance.normalized for instance in sentence]) for sentence in sentences
+    ]
     return un_normalized, normalized, categories
 
 
@@ -189,14 +196,14 @@
     Returns: text with normalized spaces and quotes
     """
     text = (
-        text.replace('( ', '(')
-        .replace(' )', ')')
-        .replace('{ ', '{')
-        .replace(' }', '}')
-        .replace('[ ', '[')
-        .replace(' ]', ']')
-        .replace('  ', ' ')
-        .replace('鈥�', '"')
+        text.replace("( ", "(")
+        .replace(" )", ")")
+        .replace("{ ", "{")
+        .replace(" }", "}")
+        .replace("[ ", "[")
+        .replace(" ]", "]")
+        .replace("  ", " ")
+        .replace("鈥�", '"')
         .replace("鈥�", "'")
         .replace("禄", '"')
         .replace("芦", '"')
@@ -204,14 +211,14 @@
         .replace("鈥�", '"')
         .replace("麓", "'")
         .replace("鈥�", "'")
-        .replace('鈥�', '"')
+        .replace("鈥�", '"')
         .replace("鈥�", "'")
-        .replace('`', "'")
-        .replace('- -', "--")
+        .replace("`", "'")
+        .replace("- -", "--")
     )
 
     for punct in "!,.:;?":
-        text = text.replace(f' {punct}', punct)
+        text = text.replace(f" {punct}", punct)
     return text.strip()
 
 
@@ -224,12 +231,12 @@
 
     Returns: text with spaces around punctuation marks
     """
-    space_both = '[]'
+    space_both = "[]"
     for punct in space_both:
-        text = text.replace(punct, ' ' + punct + ' ')
+        text = text.replace(punct, " " + punct + " ")
 
     # remove extra space
-    text = re.sub(r' +', ' ', text)
+    text = re.sub(r" +", " ", text)
     return text
 
 
@@ -237,13 +244,13 @@
     """
     Loads given text file with separate lines into list of string.
 
-    Args: 
+    Args:
         file_path: file path
 
     Returns: flat list of string
     """
     res = []
-    with open(file_path, 'r') as fp:
+    with open(file_path, "r") as fp:
         for line in fp:
             res.append(line)
     return res
@@ -256,11 +263,11 @@
     Args:
         file_path: file path
         data: list of string
-        
+
     """
-    with open(file_path, 'w') as fp:
+    with open(file_path, "w") as fp:
         for line in data:
-            fp.write(line + '\n')
+            fp.write(line + "\n")
 
 
 def post_process_punct(input: str, normalized_text: str, add_unicode_punct: bool = False):
@@ -309,7 +316,11 @@
                 def _is_valid(idx_out, idx_in, normalized_text, input):
                     """Check if previous or next word match (for cases when punctuation marks are part of
                     semiotic token, i.e. some punctuation can be missing in the normalized text)"""
-                    return (idx_out > 0 and idx_in > 0 and normalized_text[idx_out - 1] == input[idx_in - 1]) or (
+                    return (
+                        idx_out > 0
+                        and idx_in > 0
+                        and normalized_text[idx_out - 1] == input[idx_in - 1]
+                    ) or (
                         idx_out < len(normalized_text) - 1
                         and idx_in < len(input) - 1
                         and normalized_text[idx_out + 1] == input[idx_in + 1]
@@ -336,4 +347,4 @@
             logging.debug(f"Skipping post-processing of {''.join(normalized_text)} for '{punct}'")
 
     normalized_text = "".join(normalized_text)
-    return re.sub(r' +', ' ', normalized_text)
+    return re.sub(r" +", " ", normalized_text)

--
Gitblit v1.9.1