From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 examples/aishell/paraformer/utils/textnorm_zh.py |  447 ++++++++++++++++++++++++++++++++-----------------------
 1 files changed, 262 insertions(+), 185 deletions(-)

diff --git a/examples/aishell/paraformer/utils/textnorm_zh.py b/examples/aishell/paraformer/utils/textnorm_zh.py
index 79feb83..9de8e81 100755
--- a/examples/aishell/paraformer/utils/textnorm_zh.py
+++ b/examples/aishell/paraformer/utils/textnorm_zh.py
@@ -14,49 +14,58 @@
 # ================================================================================ #
 #                                    basic constant
 # ================================================================================ #
-CHINESE_DIGIS = u'零一二三四五六七八九'
-BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
-BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
-SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
-SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
-LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
-LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
-SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
-SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
+CHINESE_DIGIS = "零一二三四五六七八九"
+BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
+BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
+SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
+SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
+LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
+LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
+SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
+SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
 
-ZERO_ALT = u'〇'
-ONE_ALT = u'幺'
-TWO_ALTS = [u'两', u'兩']
+ZERO_ALT = "〇"
+ONE_ALT = "幺"
+TWO_ALTS = ["两", "兩"]
 
-POSITIVE = [u'正', u'正']
-NEGATIVE = [u'负', u'負']
-POINT = [u'点', u'點']
+POSITIVE = ["正", "正"]
+NEGATIVE = ["负", "負"]
+POINT = ["点", "點"]
 # PLUS = [u'加', u'加']
 # SIL = [u'杠', u'槓']
 
-FILLER_CHARS = ['呃', '啊']
-ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \
-             '胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \
-             '儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \
-             '佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)'
+FILLER_CHARS = ["呃", "啊"]
+ER_WHITELIST = (
+    "(儿女|儿子|儿孙|女儿|儿媳|妻儿|"
+    "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|"
+    "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|"
+    "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)"
+)
 
 # 涓枃鏁板瓧绯荤粺绫诲瀷
-NUMBERING_TYPES = ['low', 'mid', 'high']
+NUMBERING_TYPES = ["low", "mid", "high"]
 
-CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
-                 '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
-CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
-COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
-                  '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
-                  '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
-                  '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
-                  '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
-                  '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
+CURRENCY_NAMES = (
+    "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
+    "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
+)
+CURRENCY_UNITS = (
+    "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
+)
+COM_QUANTIFIERS = (
+    "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
+    "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
+    "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
+    "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
+    "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
+    "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)"
+)
 
 # punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
-CHINESE_PUNC_STOP = '！？｡。'
-CHINESE_PUNC_NON_STOP = '＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏'
+CHINESE_PUNC_STOP = "！？｡。"
+CHINESE_PUNC_NON_STOP = "＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏"
 CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
+
 
 # ================================================================================ #
 #                                    basic class
@@ -72,7 +81,7 @@
     def __init__(self, simplified, traditional):
         self.simplified = simplified
         self.traditional = traditional
-        #self.__repr__ = self.__str__
+        # self.__repr__ = self.__str__
 
     def __str__(self):
         return self.simplified or self.traditional or None
@@ -95,26 +104,49 @@
         self.big_t = big_t
 
     def __str__(self):
-        return '10^{}'.format(self.power)
+        return "10^{}".format(self.power)
 
     @classmethod
     def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
 
         if small_unit:
-            return ChineseNumberUnit(power=index + 1,
-                                     simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
+            return ChineseNumberUnit(
+                power=index + 1,
+                simplified=value[0],
+                traditional=value[1],
+                big_s=value[1],
+                big_t=value[1],
+            )
         elif numbering_type == NUMBERING_TYPES[0]:
-            return ChineseNumberUnit(power=index + 8,
-                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+            return ChineseNumberUnit(
+                power=index + 8,
+                simplified=value[0],
+                traditional=value[1],
+                big_s=value[0],
+                big_t=value[1],
+            )
         elif numbering_type == NUMBERING_TYPES[1]:
-            return ChineseNumberUnit(power=(index + 2) * 4,
-                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+            return ChineseNumberUnit(
+                power=(index + 2) * 4,
+                simplified=value[0],
+                traditional=value[1],
+                big_s=value[0],
+                big_t=value[1],
+            )
         elif numbering_type == NUMBERING_TYPES[2]:
-            return ChineseNumberUnit(power=pow(2, index + 3),
-                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+            return ChineseNumberUnit(
+                power=pow(2, index + 3),
+                simplified=value[0],
+                traditional=value[1],
+                big_s=value[0],
+                big_t=value[1],
+            )
         else:
             raise ValueError(
-                'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
+                "Counting type should be in {0} ({1} provided).".format(
+                    NUMBERING_TYPES, numbering_type
+                )
+            )
 
 
 class ChineseNumberDigit(ChineseChar):
@@ -158,6 +190,7 @@
     """
     涓枃鏁板瓧绯荤粺
     """
+
     pass
 
 
@@ -207,27 +240,27 @@
 
     # chinese number units of '亿' and larger
     all_larger_units = zip(
-        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
-    larger_units = [CNU.create(i, v, numbering_type, False)
-                    for i, v in enumerate(all_larger_units)]
+        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
+    )
+    larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)]
     # chinese number units of '十, 百, 千, 万'
     all_smaller_units = zip(
-        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
-    smaller_units = [CNU.create(i, v, small_unit=True)
-                     for i, v in enumerate(all_smaller_units)]
+        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
+    )
+    smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)]
     # digis
-    chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
-                        BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
+    chinese_digis = zip(
+        CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL
+    )
     digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
     digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
     digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
     digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
 
     # symbols
-    positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
-    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
-    point_cn = CM(POINT[0], POINT[1], '.', lambda x,
-                  y: float(str(x) + '.' + str(y)))
+    positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
+    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
+    point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
     # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
     system = NumberSystem()
     system.units = smaller_units + larger_units
@@ -251,13 +284,14 @@
                 return m
 
     def string2symbols(chinese_string, system):
-        int_string, dec_string = chinese_string, ''
+        int_string, dec_string = chinese_string, ""
         for p in [system.math.point.simplified, system.math.point.traditional]:
             if p in chinese_string:
                 int_string, dec_string = chinese_string.split(p)
                 break
-        return [get_symbol(c, system) for c in int_string], \
-               [get_symbol(c, system) for c in dec_string]
+        return [get_symbol(c, system) for c in int_string], [
+            get_symbol(c, system) for c in dec_string
+        ]
 
     def correct_symbols(integer_symbols, system):
         """
@@ -271,8 +305,7 @@
 
         if len(integer_symbols) > 1:
             if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
-                integer_symbols.append(
-                    CNU(integer_symbols[-2].power - 1, None, None, None, None))
+                integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))
 
         result = []
         unit_count = 0
@@ -288,9 +321,13 @@
                 result.append(current_unit)
             elif unit_count > 1:
                 for i in range(len(result)):
-                    if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
-                        result[-i - 1] = CNU(result[-i - 1].power +
-                                             current_unit.power, None, None, None, None)
+                    if (
+                        isinstance(result[-i - 1], CNU)
+                        and result[-i - 1].power < current_unit.power
+                    ):
+                        result[-i - 1] = CNU(
+                            result[-i - 1].power + current_unit.power, None, None, None, None
+                        )
         return result
 
     def compute_value(integer_symbols):
@@ -307,8 +344,7 @@
             elif isinstance(s, CNU):
                 value[-1] *= pow(10, s.power)
                 if s.power > last_power:
-                    value[:-1] = list(map(lambda v: v *
-                                                    pow(10, s.power), value[:-1]))
+                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
                     last_power = s.power
                 value.append(0)
         return sum(value)
@@ -317,20 +353,28 @@
     int_part, dec_part = string2symbols(chinese_string, system)
     int_part = correct_symbols(int_part, system)
     int_str = str(compute_value(int_part))
-    dec_str = ''.join([str(d.value) for d in dec_part])
+    dec_str = "".join([str(d.value) for d in dec_part])
     if dec_part:
-        return '{0}.{1}'.format(int_str, dec_str)
+        return "{0}.{1}".format(int_str, dec_str)
     else:
         return int_str
 
 
-def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
-            traditional=False, alt_zero=False, alt_one=False, alt_two=True,
-            use_zeros=True, use_units=True):
+def num2chn(
+    number_string,
+    numbering_type=NUMBERING_TYPES[1],
+    big=False,
+    traditional=False,
+    alt_zero=False,
+    alt_one=False,
+    alt_two=True,
+    use_zeros=True,
+    use_units=True,
+):
 
     def get_value(value_string, use_zeros=True):
 
-        striped_string = value_string.lstrip('0')
+        striped_string = value_string.lstrip("0")
 
         # record nothing if all zeros
         if not striped_string:
@@ -345,14 +389,17 @@
 
         # recursively record multiple digits
         else:
-            result_unit = next(u for u in reversed(
-                system.units) if u.power < len(striped_string))
-            result_string = value_string[:-result_unit.power]
-            return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
+            result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
+            result_string = value_string[: -result_unit.power]
+            return (
+                get_value(result_string)
+                + [result_unit]
+                + get_value(striped_string[-result_unit.power :])
+            )
 
     system = create_system(numbering_type)
 
-    int_dec = number_string.split('.')
+    int_dec = number_string.split(".")
     if len(int_dec) == 1:
         int_string = int_dec[0]
         dec_string = ""
@@ -361,7 +408,8 @@
         dec_string = int_dec[1]
     else:
         raise ValueError(
-            "invalid input num string with more than one dot: {}".format(number_string))
+            "invalid input num string with more than one dot: {}".format(number_string)
+        )
 
     if use_units and len(int_string) > 1:
         result_symbols = get_value(int_string)
@@ -372,51 +420,62 @@
         result_symbols += [system.math.point] + dec_symbols
 
     if alt_two:
-        liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
-                    system.digits[2].big_s, system.digits[2].big_t)
+        liang = CND(
+            2,
+            system.digits[2].alt_s,
+            system.digits[2].alt_t,
+            system.digits[2].big_s,
+            system.digits[2].big_t,
+        )
         for i, v in enumerate(result_symbols):
             if isinstance(v, CND) and v.value == 2:
-                next_symbol = result_symbols[i +
-                                             1] if i < len(result_symbols) - 1 else None
+                next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                 previous_symbol = result_symbols[i - 1] if i > 0 else None
                 if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
-                    if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
+                    if next_symbol.power != 1 and (
+                        (previous_symbol is None) or (previous_symbol.power != 1)
+                    ):
                         result_symbols[i] = liang
 
     # if big is True, '两' will not be used and `alt_two` has no impact on output
     if big:
-        attr_name = 'big_'
+        attr_name = "big_"
         if traditional:
-            attr_name += 't'
+            attr_name += "t"
         else:
-            attr_name += 's'
+            attr_name += "s"
     else:
         if traditional:
-            attr_name = 'traditional'
+            attr_name = "traditional"
         else:
-            attr_name = 'simplified'
+            attr_name = "simplified"
 
-    result = ''.join([getattr(s, attr_name) for s in result_symbols])
+    result = "".join([getattr(s, attr_name) for s in result_symbols])
 
     # if not use_zeros:
     #     result = result.strip(getattr(system.digits[0], attr_name))
 
     if alt_zero:
-        result = result.replace(
-            getattr(system.digits[0], attr_name), system.digits[0].alt_s)
+        result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)
 
     if alt_one:
-        result = result.replace(
-            getattr(system.digits[1], attr_name), system.digits[1].alt_s)
+        result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)
 
     for i, p in enumerate(POINT):
         if result.startswith(p):
             return CHINESE_DIGIS[0] + result
 
     # ^10, 11, .., 19
-    if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
-                                          SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
-            result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
+    if (
+        len(result) >= 2
+        and result[1]
+        in [
+            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
+        ]
+        and result[0]
+        in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
+    ):
         result = result[1:]
 
     return result
@@ -439,6 +498,7 @@
 
     def cardinal2chntext(self):
         return num2chn(self.cardinal)
+
 
 class Digit:
     """
@@ -476,17 +536,17 @@
     def telephone2chntext(self, fixed=False):
 
         if fixed:
-            sil_parts = self.telephone.split('-')
-            self.raw_chntext = '<SIL>'.join([
-                num2chn(part, alt_two=False, use_units=False) for part in sil_parts
-            ])
-            self.chntext = self.raw_chntext.replace('<SIL>', '')
+            sil_parts = self.telephone.split("-")
+            self.raw_chntext = "<SIL>".join(
+                [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
+            )
+            self.chntext = self.raw_chntext.replace("<SIL>", "")
         else:
-            sp_parts = self.telephone.strip('+').split()
-            self.raw_chntext = '<SP>'.join([
-                num2chn(part, alt_two=False, use_units=False) for part in sp_parts
-            ])
-            self.chntext = self.raw_chntext.replace('<SP>', '')
+            sp_parts = self.telephone.strip("+").split()
+            self.raw_chntext = "<SP>".join(
+                [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
+            )
+            self.chntext = self.raw_chntext.replace("<SP>", "")
         return self.chntext
 
 
@@ -500,12 +560,12 @@
         self.chntext = chntext
 
     def chntext2fraction(self):
-        denominator, numerator = self.chntext.split('分之')
-        return chn2num(numerator) + '/' + chn2num(denominator)
+        denominator, numerator = self.chntext.split("分之")
+        return chn2num(numerator) + "/" + chn2num(denominator)
 
     def fraction2chntext(self):
-        numerator, denominator = self.fraction.split('/')
-        return num2chn(denominator) + '分之' + num2chn(numerator)
+        numerator, denominator = self.fraction.split("/")
+        return num2chn(denominator) + "分之" + num2chn(numerator)
 
 
 class Date:
@@ -544,23 +604,23 @@
     def date2chntext(self):
         date = self.date
         try:
-            year, other = date.strip().split('年', 1)
-            year = Digit(digit=year).digit2chntext() + '年'
+            year, other = date.strip().split("年", 1)
+            year = Digit(digit=year).digit2chntext() + "年"
         except ValueError:
             other = date
-            year = ''
+            year = ""
         if other:
             try:
-                month, day = other.strip().split('月', 1)
-                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
+                month, day = other.strip().split("月", 1)
+                month = Cardinal(cardinal=month).cardinal2chntext() + "月"
             except ValueError:
                 day = date
-                month = ''
+                month = ""
             if day:
                 day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
         else:
-            month = ''
-            day = ''
+            month = ""
+            day = ""
         chntext = year + month + day
         self.chntext = chntext
         return self.chntext
@@ -580,7 +640,7 @@
 
     def money2chntext(self):
         money = self.money
-        pattern = re.compile(r'(\d+(\.\d+)?)')
+        pattern = re.compile(r"(\d+(\.\d+)?)")
         matchers = pattern.findall(money)
         if matchers:
             for matcher in matchers:
@@ -599,10 +659,10 @@
         self.chntext = chntext
 
     def chntext2percentage(self):
-        return chn2num(self.chntext.strip().strip('百分之')) + '%'
+        return chn2num(self.chntext.strip().strip("百分之")) + "%"
 
     def percentage2chntext(self):
-        return '百分之' + num2chn(self.percentage.strip().strip('%'))
+        return "百分之" + num2chn(self.percentage.strip().strip("%"))
 
 
 def remove_erhua(text, er_whitelist):
@@ -612,9 +672,9 @@
     """
 
     er_pattern = re.compile(er_whitelist)
-    new_str=''
-    while re.search('儿',text):
-        a = re.search('儿',text).span()
+    new_str = ""
+    while re.search("儿", text):
+        a = re.search("儿", text).span()
         remove_er_flag = 0
 
         if er_pattern.search(text):
@@ -622,23 +682,24 @@
             if b[0] <= a[0]:
                 remove_er_flag = 1
 
-        if remove_er_flag == 0 :
-            new_str = new_str + text[0:a[0]]
-            text = text[a[1]:]
+        if remove_er_flag == 0:
+            new_str = new_str + text[0 : a[0]]
+            text = text[a[1] :]
         else:
-            new_str = new_str + text[0:b[1]]
-            text = text[b[1]:]
+            new_str = new_str + text[0 : b[1]]
+            text = text[b[1] :]
 
     text = new_str + text
     return text
+
 
 # ================================================================================ #
 #                            NSW Normalizer
 # ================================================================================ #
 class NSWNormalizer:
     def __init__(self, raw_text):
-        self.raw_text = '^' + raw_text + '$'
-        self.norm_text = ''
+        self.raw_text = "^" + raw_text + "$"
+        self.norm_text = ""
 
     def _particular(self):
         text = self.norm_text
@@ -647,7 +708,7 @@
         if matchers:
             # print('particular')
             for matcher in matchers:
-                text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
+                text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
         self.norm_text = text
         return self.norm_text
 
@@ -658,15 +719,17 @@
         pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
         matchers = pattern.findall(text)
         if matchers:
-            #print('date')
+            # print('date')
             for matcher in matchers:
                 text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
 
         # 规范化金钱
-        pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
+        pattern = re.compile(
+            r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)"
+        )
         matchers = pattern.findall(text)
         if matchers:
-            #print('money')
+            # print('money')
             for matcher in matchers:
                 text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
 
@@ -679,39 +742,45 @@
         pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
         matchers = pattern.findall(text)
         if matchers:
-            #print('telephone')
+            # print('telephone')
             for matcher in matchers:
-                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
+                text = text.replace(
+                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
+                )
         # 固话
         pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
         matchers = pattern.findall(text)
         if matchers:
             # print('fixed telephone')
             for matcher in matchers:
-                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
+                text = text.replace(
+                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1
+                )
 
         # 规范化分数
         pattern = re.compile(r"(\d+/\d+)")
         matchers = pattern.findall(text)
         if matchers:
-            #print('fraction')
+            # print('fraction')
             for matcher in matchers:
                 text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
 
         # 规范化百分数
-        text = text.replace('％', '%')
+        text = text.replace("％", "%")
         pattern = re.compile(r"(\d+(\.\d+)?%)")
         matchers = pattern.findall(text)
         if matchers:
-            #print('percentage')
+            # print('percentage')
             for matcher in matchers:
-                text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
+                text = text.replace(
+                    matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1
+                )
 
         # 规范化纯数+量词
         pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
         matchers = pattern.findall(text)
         if matchers:
-            #print('cardinal+quantifier')
+            # print('cardinal+quantifier')
             for matcher in matchers:
                 text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
 
@@ -719,7 +788,7 @@
         pattern = re.compile(r"(\d{4,32})")
         matchers = pattern.findall(text)
         if matchers:
-            #print('digit')
+            # print('digit')
             for matcher in matchers:
                 text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
 
@@ -727,74 +796,82 @@
         pattern = re.compile(r"(\d+(\.\d+)?)")
         matchers = pattern.findall(text)
         if matchers:
-            #print('cardinal')
+            # print('cardinal')
             for matcher in matchers:
                 text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
 
         self.norm_text = text
         self._particular()
 
-        return self.norm_text.lstrip('^').rstrip('$')
+        return self.norm_text.lstrip("^").rstrip("$")
 
 
 def nsw_test_case(raw_text):
-    print('I:' + raw_text)
-    print('O:' + NSWNormalizer(raw_text).normalize())
-    print('')
+    print("I:" + raw_text)
+    print("O:" + NSWNormalizer(raw_text).normalize())
+    print("")
 
 
 def nsw_test():
-    nsw_test_case('固话：0595-23865596或23880880。')
-    nsw_test_case('固话：0595-23865596或23880880。')
-    nsw_test_case('手机：+86 19859213959或15659451527。')
-    nsw_test_case('分数：32477/76391。')
-    nsw_test_case('百分数：80.03%。')
-    nsw_test_case('编号：31520181154418。')
-    nsw_test_case('纯数：2983.07克或12345.60米。')
-    nsw_test_case('日期：1999年2月20日或09年3月15号。')
-    nsw_test_case('金钱：12块5，34.5元，20.1万')
-    nsw_test_case('特殊：O2O或B2C。')
-    nsw_test_case('3456万吨')
-    nsw_test_case('2938个')
-    nsw_test_case('938')
-    nsw_test_case('今天吃了115个小笼包231个馒头')
-    nsw_test_case('有62％的概率')
+    nsw_test_case("固话：0595-23865596或23880880。")
+    nsw_test_case("固话：0595-23865596或23880880。")
+    nsw_test_case("手机：+86 19859213959或15659451527。")
+    nsw_test_case("分数：32477/76391。")
+    nsw_test_case("百分数：80.03%。")
+    nsw_test_case("编号：31520181154418。")
+    nsw_test_case("纯数：2983.07克或12345.60米。")
+    nsw_test_case("日期：1999年2月20日或09年3月15号。")
+    nsw_test_case("金钱：12块5，34.5元，20.1万")
+    nsw_test_case("特殊：O2O或B2C。")
+    nsw_test_case("3456万吨")
+    nsw_test_case("2938个")
+    nsw_test_case("938")
+    nsw_test_case("今天吃了115个小笼包231个馒头")
+    nsw_test_case("有62％的概率")
 
 
-if __name__ == '__main__':
-    #nsw_test()
+if __name__ == "__main__":
+    # nsw_test()
 
     p = argparse.ArgumentParser()
-    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
-    p.add_argument('ofile', help='output filename')
-    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
-    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
-    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
-    p.add_argument('--remove_fillers', type=bool, default=True, help='remove filler chars such as "呃, 啊"')
-    p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "这儿"')
-    p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
+    p.add_argument("ifile", help="input filename, assume utf-8 encoding")
+    p.add_argument("ofile", help="output filename")
+    p.add_argument("--to_upper", action="store_true", help="convert to upper case")
+    p.add_argument("--to_lower", action="store_true", help="convert to lower case")
+    p.add_argument(
+        "--has_key", action="store_true", help="input text has Kaldi's key as first field."
+    )
+    p.add_argument(
+        "--remove_fillers", type=bool, default=True, help='remove filler chars such as "呃, 啊"'
+    )
+    p.add_argument(
+        "--remove_erhua", type=bool, default=True, help='remove erhua chars such as "这儿"'
+    )
+    p.add_argument(
+        "--log_interval", type=int, default=10000, help="log interval in number of processed lines"
+    )
     args = p.parse_args()
 
-    ifile = codecs.open(args.ifile, 'r', 'utf8')
-    ofile = codecs.open(args.ofile, 'w+', 'utf8')
+    ifile = codecs.open(args.ifile, "r", "utf8")
+    ofile = codecs.open(args.ofile, "w+", "utf8")
 
     n = 0
     for l in ifile:
-        key = ''
-        text = ''
+        key = ""
+        text = ""
         if args.has_key:
             cols = l.split(maxsplit=1)
             key = cols[0]
             if len(cols) == 2:
                 text = cols[1].strip()
             else:
-                text = ''
+                text = ""
         else:
             text = l.strip()
 
         # cases
         if args.to_upper and args.to_lower:
-            sys.stderr.write('text norm: to_upper OR to_lower?')
+            sys.stderr.write("text norm: to_upper OR to_lower?")
             exit(1)
         if args.to_upper:
             text = text.upper()
@@ -804,7 +881,7 @@
         # Filler chars removal
         if args.remove_fillers:
             for ch in FILLER_CHARS:
-                text = text.replace(ch, '')
+                text = text.replace(ch, "")
 
         if args.remove_erhua:
             text = remove_erhua(text, ER_WHITELIST)
@@ -813,16 +890,16 @@
         text = NSWNormalizer(text).normalize()
 
         # Punctuations removal
-        old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations
-        new_chars = ' ' * len(old_chars)
-        del_chars = ''
+        old_chars = CHINESE_PUNC_LIST + string.punctuation  # includes all CN and EN punctuations
+        new_chars = " " * len(old_chars)
+        del_chars = ""
         text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
 
         #
         if args.has_key:
-            ofile.write(key + '\t' + text + '\n')
+            ofile.write(key + "\t" + text + "\n")
         else:
-            ofile.write(text + '\n')
+            ofile.write(text + "\n")
 
         n += 1
         if n % args.log_interval == 0:

--
Gitblit v1.9.1