From 559cc2c6e296bc80917a7408911f671dfcc2b68b Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期五, 12 五月 2023 17:25:54 +0800
Subject: [PATCH] update repo

---
 egs/aishell2/transformer/utils/textnorm_zh.py |  834 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 834 insertions(+), 0 deletions(-)

diff --git a/egs/aishell2/transformer/utils/textnorm_zh.py b/egs/aishell2/transformer/utils/textnorm_zh.py
new file mode 100755
index 0000000..79feb83
--- /dev/null
+++ b/egs/aishell2/transformer/utils/textnorm_zh.py
@@ -0,0 +1,834 @@
+#!/usr/bin/env python3
+# coding=utf-8
+
+# Authors:
+#   2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
+#   2019.9 Jiayu DU
+#
+# requirements:
+#   - python 3.X
+# notes: python 2.X WILL fail or produce misleading results
+
+import sys, os, argparse, codecs, string, re
+
+# ================================================================================ #
+#                                    basic constant
+# ================================================================================ #
+CHINESE_DIGIS = u'闆朵竴浜屼笁鍥涗簲鍏竷鍏節'
+BIG_CHINESE_DIGIS_SIMPLIFIED = u'闆跺9璐板弫鑲嗕紞闄嗘煉鎹岀帠'
+BIG_CHINESE_DIGIS_TRADITIONAL = u'闆跺9璨冲弮鑲嗕紞闄告煉鎹岀帠'
+SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'鍗佺櫨鍗冧竾'
+SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'鎷句桨浠熻惉'
+LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'浜垮厗浜灀绉┌娌熸锭姝h浇'
+LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'鍎勫厗浜灀绉┌婧濇緱姝h級'
+SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'鍗佺櫨鍗冧竾'
+SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'鎷句桨浠熻惉'
+
+ZERO_ALT = u'銆�'
+ONE_ALT = u'骞�'
+TWO_ALTS = [u'涓�', u'鍏�']
+
+POSITIVE = [u'姝�', u'姝�']
+NEGATIVE = [u'璐�', u'璨�']
+POINT = [u'鐐�', u'榛�']
+# PLUS = [u'鍔�', u'鍔�']
+# SIL = [u'鏉�', u'妲�']
+
+FILLER_CHARS = ['鍛�', '鍟�']
+ER_WHITELIST = '(鍎垮コ|鍎垮瓙|鍎垮瓩|濂冲効|鍎垮|濡诲効|' \
+             '鑳庡効|濠村効|鏂扮敓鍎縷濠村辜鍎縷骞煎効|灏戝効|灏忓効|鍎挎瓕|鍎跨|鍎跨|鎵樺効鎵�|瀛ゅ効|' \
+             '鍎挎垙|鍎垮寲|鍙板効搴剕楣垮効宀泑姝e効鍏粡|鍚婂効閮庡綋|鐢熷効鑲插コ|鎵樺効甯﹀コ|鍏诲効闃茶�亅鐥村効鍛嗗コ|' \
+             '浣冲効浣冲|鍎挎�滃吔鎵皘鍎挎棤甯哥埗|鍎夸笉瀚屾瘝涓憒鍎胯鍗冮噷姣嶆媴蹇鍎垮ぇ涓嶇敱鐖穦鑻忎篂鍎�)'
+
+# 涓枃鏁板瓧绯荤粺绫诲瀷
+NUMBERING_TYPES = ['low', 'mid', 'high']
+
+CURRENCY_NAMES = '(浜烘皯甯亅缇庡厓|鏃ュ厓|鑻遍晳|娆у厓|椹厠|娉曢儙|鍔犳嬁澶у厓|婢冲厓|娓竵|鍏堜护|鑺叞椹厠|鐖卞皵鍏伴晳|' \
+                 '閲屾媺|鑽峰叞鐩緗鍩冩柉搴撳|姣斿濉攟鍗板凹鐩緗鏋楀悏鐗箌鏂拌タ鍏板厓|姣旂储|鍗㈠竷|鏂板姞鍧″厓|闊╁厓|娉伴摙)'
+CURRENCY_UNITS = '((浜縷鍗冧竾|鐧句竾|涓噟鍗億鐧�)|(浜縷鍗冧竾|鐧句竾|涓噟鍗億鐧緗)鍏億(浜縷鍗冧竾|鐧句竾|涓噟鍗億鐧緗)鍧梶瑙抾姣泑鍒�)'
+COM_QUANTIFIERS = '(鍖箌寮爘搴鍥瀨鍦簗灏緗鏉涓獆棣東闃檤闃祙缃憒鐐畖椤秥涓榺妫祙鍙獆鏀瘄琚瓅杈唡鎸憒鎷厊棰梶澹硘绐爘鏇瞸澧檤缇鑵攟' \
+                  '鐮搴瀹璐瘄鎵巪鎹唡鍒�|浠鎵搢鎵媩缃梶鍧灞眧宀瓅姹焲婧獆閽焲闃焲鍗晐鍙寍瀵箌鍑簗鍙澶磡鑴殀鏉縷璺硘鏋潀浠秥璐磡' \
+                  '閽坾绾縷绠鍚峾浣峾韬珅鍫倈璇緗鏈瑋椤祙瀹秥鎴穦灞倈涓潀姣珅鍘榺鍒唡閽眧涓鏂鎷厊閾鐭硘閽閿眧蹇絴(鍗億姣珅寰�)鍏媩' \
+                  '姣珅鍘榺鍒唡瀵竱灏簗涓坾閲寍瀵粅甯竱閾簗绋媩(鍗億鍒唡鍘榺姣珅寰�)绫硘鎾畖鍕簗鍚坾鍗噟鏂梶鐭硘鐩榺纰梶纰焲鍙爘妗秥绗紎鐩唡' \
+                  '鐩抾鏉瘄閽焲鏂泑閿厊绨媩绡畖鐩榺妗秥缃恷鐡秥澹秥鍗畖鐩弢绠﹟绠眧鐓瞸鍟東琚媩閽祙骞磡鏈坾鏃瀛鍒粅鏃秥鍛▅澶﹟绉抾鍒唡鏃瑋' \
+                  '绾獆宀亅涓東鏇磡澶渱鏄澶弢绉媩鍐瑋浠浼弢杈坾涓竱娉绮抾棰梶骞鍫唡鏉鏍箌鏀瘄閬搢闈鐗噟寮爘棰梶鍧�)'
+
+# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
+CHINESE_PUNC_STOP = '锛侊紵锝°��'
+CHINESE_PUNC_NON_STOP = '锛傦純锛勶紖锛嗭紘锛堬級锛婏紜锛岋紞锛忥細锛涳紲锛濓紴锛狅蓟锛硷冀锛撅伎锝�锝涳綔锝濓綖锝燂綘锝剑锝ゃ�併�冦�嬨�屻�嶃�庛�忋�愩�戙�斻�曘�栥�椼�樸�欍�氥�涖�溿�濄�炪�熴�般�俱�库�撯�斺�樷�欌�涒�溾�濃�炩�熲�︹�э箯'
+CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
+
+# ================================================================================ #
+#                                    basic class
+# ================================================================================ #
+class ChineseChar(object):
+    """
+    涓枃瀛楃
+    姣忎釜瀛楃瀵瑰簲绠�浣撳拰绻佷綋,
+    e.g. 绠�浣� = '璐�', 绻佷綋 = '璨�'
+    杞崲鏃跺彲杞崲涓虹畝浣撴垨绻佷綋
+    """
+
+    def __init__(self, simplified, traditional):
+        self.simplified = simplified
+        self.traditional = traditional
+        #self.__repr__ = self.__str__
+
+    def __str__(self):
+        return self.simplified or self.traditional or None
+
+    def __repr__(self):
+        return self.__str__()
+
+
+class ChineseNumberUnit(ChineseChar):
+    """
+    涓枃鏁板瓧/鏁颁綅瀛楃
+    姣忎釜瀛楃闄ょ箒绠�浣撳杩樻湁涓�涓澶栫殑澶у啓瀛楃
+    e.g. '闄�' 鍜� '闄�'
+    """
+
+    def __init__(self, power, simplified, traditional, big_s, big_t):
+        super(ChineseNumberUnit, self).__init__(simplified, traditional)
+        self.power = power
+        self.big_s = big_s
+        self.big_t = big_t
+
+    def __str__(self):
+        return '10^{}'.format(self.power)
+
+    @classmethod
+    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
+
+        if small_unit:
+            return ChineseNumberUnit(power=index + 1,
+                                     simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
+        elif numbering_type == NUMBERING_TYPES[0]:
+            return ChineseNumberUnit(power=index + 8,
+                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+        elif numbering_type == NUMBERING_TYPES[1]:
+            return ChineseNumberUnit(power=(index + 2) * 4,
+                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+        elif numbering_type == NUMBERING_TYPES[2]:
+            return ChineseNumberUnit(power=pow(2, index + 3),
+                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
+        else:
+            raise ValueError(
+                'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
+
+
+class ChineseNumberDigit(ChineseChar):
+    """
+    涓枃鏁板瓧瀛楃
+    """
+
+    def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
+        super(ChineseNumberDigit, self).__init__(simplified, traditional)
+        self.value = value
+        self.big_s = big_s
+        self.big_t = big_t
+        self.alt_s = alt_s
+        self.alt_t = alt_t
+
+    def __str__(self):
+        return str(self.value)
+
+    @classmethod
+    def create(cls, i, v):
+        return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
+
+
+class ChineseMath(ChineseChar):
+    """
+    涓枃鏁颁綅瀛楃
+    """
+
+    def __init__(self, simplified, traditional, symbol, expression=None):
+        super(ChineseMath, self).__init__(simplified, traditional)
+        self.symbol = symbol
+        self.expression = expression
+        self.big_s = simplified
+        self.big_t = traditional
+
+
+CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
+class NumberSystem(object):
+    """
+    涓枃鏁板瓧绯荤粺
+    """
+    pass
+
+
+class MathSymbol(object):
+    """
+    鐢ㄤ簬涓枃鏁板瓧绯荤粺鐨勬暟瀛︾鍙� (绻�/绠�浣�), e.g.
+    positive = ['姝�', '姝�']
+    negative = ['璐�', '璨�']
+    point = ['鐐�', '榛�']
+    """
+
+    def __init__(self, positive, negative, point):
+        self.positive = positive
+        self.negative = negative
+        self.point = point
+
+    def __iter__(self):
+        for v in self.__dict__.values():
+            yield v
+
+
+# class OtherSymbol(object):
+#     """
+#     鍏朵粬绗﹀彿
+#     """
+#
+#     def __init__(self, sil):
+#         self.sil = sil
+#
+#     def __iter__(self):
+#         for v in self.__dict__.values():
+#             yield v
+
+
+# ================================================================================ #
+#                                    basic utils
+# ================================================================================ #
+def create_system(numbering_type=NUMBERING_TYPES[1]):
+    """
+    鏍规嵁鏁板瓧绯荤粺绫诲瀷杩斿洖鍒涘缓鐩稿簲鐨勬暟瀛楃郴缁燂紝榛樿涓� mid
+    NUMBERING_TYPES = ['low', 'mid', 'high']: 涓枃鏁板瓧绯荤粺绫诲瀷
+        low:  '鍏�' = '浜�' * '鍗�' = $10^{9}$,  '浜�' = '鍏�' * '鍗�', etc.
+        mid:  '鍏�' = '浜�' * '涓�' = $10^{12}$, '浜�' = '鍏�' * '涓�', etc.
+        high: '鍏�' = '浜�' * '浜�' = $10^{16}$, '浜�' = '鍏�' * '鍏�', etc.
+    杩斿洖瀵瑰簲鐨勬暟瀛楃郴缁�
+    """
+
+    # chinese number units of '浜�' and larger
+    all_larger_units = zip(
+        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
+    larger_units = [CNU.create(i, v, numbering_type, False)
+                    for i, v in enumerate(all_larger_units)]
+    # chinese number units of '鍗�, 鐧�, 鍗�, 涓�'
+    all_smaller_units = zip(
+        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
+    smaller_units = [CNU.create(i, v, small_unit=True)
+                     for i, v in enumerate(all_smaller_units)]
+    # digis
+    chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
+                        BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
+    digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
+    digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
+    digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
+    digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
+
+    # symbols
+    positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
+    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
+    point_cn = CM(POINT[0], POINT[1], '.', lambda x,
+                  y: float(str(x) + '.' + str(y)))
+    # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
+    system = NumberSystem()
+    system.units = smaller_units + larger_units
+    system.digits = digits
+    system.math = MathSymbol(positive_cn, negative_cn, point_cn)
+    # system.symbols = OtherSymbol(sil_cn)
+    return system
+
+
+def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
+
+    def get_symbol(char, system):
+        for u in system.units:
+            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
+                return u
+        for d in system.digits:
+            if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
+                return d
+        for m in system.math:
+            if char in [m.traditional, m.simplified]:
+                return m
+
+    def string2symbols(chinese_string, system):
+        int_string, dec_string = chinese_string, ''
+        for p in [system.math.point.simplified, system.math.point.traditional]:
+            if p in chinese_string:
+                int_string, dec_string = chinese_string.split(p)
+                break
+        return [get_symbol(c, system) for c in int_string], \
+               [get_symbol(c, system) for c in dec_string]
+
+    def correct_symbols(integer_symbols, system):
+        """
+        涓�鐧惧叓 to 涓�鐧惧叓鍗�
+        涓�浜夸竴鍗冧笁鐧句竾 to 涓�浜� 涓�鍗冧竾 涓夌櫨涓�
+        """
+
+        if integer_symbols and isinstance(integer_symbols[0], CNU):
+            if integer_symbols[0].power == 1:
+                integer_symbols = [system.digits[1]] + integer_symbols
+
+        if len(integer_symbols) > 1:
+            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
+                integer_symbols.append(
+                    CNU(integer_symbols[-2].power - 1, None, None, None, None))
+
+        result = []
+        unit_count = 0
+        for s in integer_symbols:
+            if isinstance(s, CND):
+                result.append(s)
+                unit_count = 0
+            elif isinstance(s, CNU):
+                current_unit = CNU(s.power, None, None, None, None)
+                unit_count += 1
+
+            if unit_count == 1:
+                result.append(current_unit)
+            elif unit_count > 1:
+                for i in range(len(result)):
+                    if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
+                        result[-i - 1] = CNU(result[-i - 1].power +
+                                             current_unit.power, None, None, None, None)
+        return result
+
+    def compute_value(integer_symbols):
+        """
+        Compute the value.
+        When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
+        e.g. '涓ゅ崈涓�' = 2000 * 10000 not 2000 + 10000
+        """
+        value = [0]
+        last_power = 0
+        for s in integer_symbols:
+            if isinstance(s, CND):
+                value[-1] = s.value
+            elif isinstance(s, CNU):
+                value[-1] *= pow(10, s.power)
+                if s.power > last_power:
+                    value[:-1] = list(map(lambda v: v *
+                                                    pow(10, s.power), value[:-1]))
+                    last_power = s.power
+                value.append(0)
+        return sum(value)
+
+    system = create_system(numbering_type)
+    int_part, dec_part = string2symbols(chinese_string, system)
+    int_part = correct_symbols(int_part, system)
+    int_str = str(compute_value(int_part))
+    dec_str = ''.join([str(d.value) for d in dec_part])
+    if dec_part:
+        return '{0}.{1}'.format(int_str, dec_str)
+    else:
+        return int_str
+
+
+def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
+            traditional=False, alt_zero=False, alt_one=False, alt_two=True,
+            use_zeros=True, use_units=True):
+
+    def get_value(value_string, use_zeros=True):
+
+        striped_string = value_string.lstrip('0')
+
+        # record nothing if all zeros
+        if not striped_string:
+            return []
+
+        # record one digits
+        elif len(striped_string) == 1:
+            if use_zeros and len(value_string) != len(striped_string):
+                return [system.digits[0], system.digits[int(striped_string)]]
+            else:
+                return [system.digits[int(striped_string)]]
+
+        # recursively record multiple digits
+        else:
+            result_unit = next(u for u in reversed(
+                system.units) if u.power < len(striped_string))
+            result_string = value_string[:-result_unit.power]
+            return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
+
+    system = create_system(numbering_type)
+
+    int_dec = number_string.split('.')
+    if len(int_dec) == 1:
+        int_string = int_dec[0]
+        dec_string = ""
+    elif len(int_dec) == 2:
+        int_string = int_dec[0]
+        dec_string = int_dec[1]
+    else:
+        raise ValueError(
+            "invalid input num string with more than one dot: {}".format(number_string))
+
+    if use_units and len(int_string) > 1:
+        result_symbols = get_value(int_string)
+    else:
+        result_symbols = [system.digits[int(c)] for c in int_string]
+    dec_symbols = [system.digits[int(c)] for c in dec_string]
+    if dec_string:
+        result_symbols += [system.math.point] + dec_symbols
+
+    if alt_two:
+        liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
+                    system.digits[2].big_s, system.digits[2].big_t)
+        for i, v in enumerate(result_symbols):
+            if isinstance(v, CND) and v.value == 2:
+                next_symbol = result_symbols[i +
+                                             1] if i < len(result_symbols) - 1 else None
+                previous_symbol = result_symbols[i - 1] if i > 0 else None
+                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
+                    if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
+                        result_symbols[i] = liang
+
+    # if big is True, '涓�' will not be used and `alt_two` has no impact on output
+    if big:
+        attr_name = 'big_'
+        if traditional:
+            attr_name += 't'
+        else:
+            attr_name += 's'
+    else:
+        if traditional:
+            attr_name = 'traditional'
+        else:
+            attr_name = 'simplified'
+
+    result = ''.join([getattr(s, attr_name) for s in result_symbols])
+
+    # if not use_zeros:
+    #     result = result.strip(getattr(system.digits[0], attr_name))
+
+    if alt_zero:
+        result = result.replace(
+            getattr(system.digits[0], attr_name), system.digits[0].alt_s)
+
+    if alt_one:
+        result = result.replace(
+            getattr(system.digits[1], attr_name), system.digits[1].alt_s)
+
+    for i, p in enumerate(POINT):
+        if result.startswith(p):
+            return CHINESE_DIGIS[0] + result
+
+    # ^10, 11, .., 19
+    if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+                                          SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
+            result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
+        result = result[1:]
+
+    return result
+
+
+# ================================================================================ #
+#                          different types of rewriters
+# ================================================================================ #
+class Cardinal:
+    """
+    CARDINAL绫�
+    """
+
+    def __init__(self, cardinal=None, chntext=None):
+        self.cardinal = cardinal
+        self.chntext = chntext
+
+    def chntext2cardinal(self):
+        return chn2num(self.chntext)
+
+    def cardinal2chntext(self):
+        return num2chn(self.cardinal)
+
+class Digit:
+    """
+    DIGIT绫�
+    """
+
+    def __init__(self, digit=None, chntext=None):
+        self.digit = digit
+        self.chntext = chntext
+
+    # def chntext2digit(self):
+    #     return chn2num(self.chntext)
+
+    def digit2chntext(self):
+        return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+class TelePhone:
+    """
+    TELEPHONE绫�
+    """
+
+    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+        self.telephone = telephone
+        self.raw_chntext = raw_chntext
+        self.chntext = chntext
+
+    # def chntext2telephone(self):
+    #     sil_parts = self.raw_chntext.split('<SIL>')
+    #     self.telephone = '-'.join([
+    #         str(chn2num(p)) for p in sil_parts
+    #     ])
+    #     return self.telephone
+
+    def telephone2chntext(self, fixed=False):
+
+        if fixed:
+            sil_parts = self.telephone.split('-')
+            self.raw_chntext = '<SIL>'.join([
+                num2chn(part, alt_two=False, use_units=False) for part in sil_parts
+            ])
+            self.chntext = self.raw_chntext.replace('<SIL>', '')
+        else:
+            sp_parts = self.telephone.strip('+').split()
+            self.raw_chntext = '<SP>'.join([
+                num2chn(part, alt_two=False, use_units=False) for part in sp_parts
+            ])
+            self.chntext = self.raw_chntext.replace('<SP>', '')
+        return self.chntext
+
+
+class Fraction:
+    """
+    FRACTION绫�
+    """
+
+    def __init__(self, fraction=None, chntext=None):
+        self.fraction = fraction
+        self.chntext = chntext
+
+    def chntext2fraction(self):
+        denominator, numerator = self.chntext.split('鍒嗕箣')
+        return chn2num(numerator) + '/' + chn2num(denominator)
+
+    def fraction2chntext(self):
+        numerator, denominator = self.fraction.split('/')
+        return num2chn(denominator) + '鍒嗕箣' + num2chn(numerator)
+
+
+class Date:
+    """
+    DATE绫�
+    """
+
+    def __init__(self, date=None, chntext=None):
+        self.date = date
+        self.chntext = chntext
+
+    # def chntext2date(self):
+    #     chntext = self.chntext
+    #     try:
+    #         year, other = chntext.strip().split('骞�', maxsplit=1)
+    #         year = Digit(chntext=year).digit2chntext() + '骞�'
+    #     except ValueError:
+    #         other = chntext
+    #         year = ''
+    #     if other:
+    #         try:
+    #             month, day = other.strip().split('鏈�', maxsplit=1)
+    #             month = Cardinal(chntext=month).chntext2cardinal() + '鏈�'
+    #         except ValueError:
+    #             day = chntext
+    #             month = ''
+    #         if day:
+    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+    #     else:
+    #         month = ''
+    #         day = ''
+    #     date = year + month + day
+    #     self.date = date
+    #     return self.date
+
+    def date2chntext(self):
+        date = self.date
+        try:
+            year, other = date.strip().split('骞�', 1)
+            year = Digit(digit=year).digit2chntext() + '骞�'
+        except ValueError:
+            other = date
+            year = ''
+        if other:
+            try:
+                month, day = other.strip().split('鏈�', 1)
+                month = Cardinal(cardinal=month).cardinal2chntext() + '鏈�'
+            except ValueError:
+                day = date
+                month = ''
+            if day:
+                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+        else:
+            month = ''
+            day = ''
+        chntext = year + month + day
+        self.chntext = chntext
+        return self.chntext
+
+
+class Money:
+    """
+    MONEY绫�
+    """
+
+    def __init__(self, money=None, chntext=None):
+        self.money = money
+        self.chntext = chntext
+
+    # def chntext2money(self):
+    #     return self.money
+
+    def money2chntext(self):
+        money = self.money
+        pattern = re.compile(r'(\d+(\.\d+)?)')
+        matchers = pattern.findall(money)
+        if matchers:
+            for matcher in matchers:
+                money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
+        self.chntext = money
+        return self.chntext
+
+
+class Percentage:
+    """
+    PERCENTAGE绫�
+    """
+
+    def __init__(self, percentage=None, chntext=None):
+        self.percentage = percentage
+        self.chntext = chntext
+
+    def chntext2percentage(self):
+        return chn2num(self.chntext.strip().strip('鐧惧垎涔�')) + '%'
+
+    def percentage2chntext(self):
+        return '鐧惧垎涔�' + num2chn(self.percentage.strip().strip('%'))
+
+
+def remove_erhua(text, er_whitelist):
+    """
+    鍘婚櫎鍎垮寲闊宠瘝涓殑鍎�:
+    浠栧コ鍎垮湪閭h竟鍎� -> 浠栧コ鍎垮湪閭h竟
+    """
+
+    er_pattern = re.compile(er_whitelist)
+    new_str=''
+    while re.search('鍎�',text):
+        a = re.search('鍎�',text).span()
+        remove_er_flag = 0
+
+        if er_pattern.search(text):
+            b = er_pattern.search(text).span()
+            if b[0] <= a[0]:
+                remove_er_flag = 1
+
+        if remove_er_flag == 0 :
+            new_str = new_str + text[0:a[0]]
+            text = text[a[1]:]
+        else:
+            new_str = new_str + text[0:b[1]]
+            text = text[b[1]:]
+
+    text = new_str + text
+    return text
+
+# ================================================================================ #
+#                            NSW Normalizer
+# ================================================================================ #
+class NSWNormalizer:
+    def __init__(self, raw_text):
+        self.raw_text = '^' + raw_text + '$'
+        self.norm_text = ''
+
+    def _particular(self):
+        text = self.norm_text
+        pattern = re.compile(r"(([a-zA-Z]+)浜�([a-zA-Z]+))")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('particular')
+            for matcher in matchers:
+                text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
+        self.norm_text = text
+        return self.norm_text
+
+    def normalize(self):
+        text = self.raw_text
+
+        # 瑙勮寖鍖栨棩鏈�
+        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})骞�)?(\d{1,2}鏈�(\d{1,2}[鏃ュ彿])?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('date')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+        # 瑙勮寖鍖栭噾閽�
+        pattern = re.compile(r"\D+((\d+(\.\d+)?)[澶氫綑鍑燷?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('money')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
+
+        # 瑙勮寖鍖栧浐璇�/鎵嬫満鍙风爜
+        # 鎵嬫満
+        # http://www.jihaoba.com/news/show/13680
+        # 绉诲姩锛�139銆�138銆�137銆�136銆�135銆�134銆�159銆�158銆�157銆�150銆�151銆�152銆�188銆�187銆�182銆�183銆�184銆�178銆�198
+        # 鑱旈�氾細130銆�131銆�132銆�156銆�155銆�186銆�185銆�176
+        # 鐢典俊锛�133銆�153銆�189銆�180銆�181銆�177
+        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('telephone')
+            for matcher in matchers:
+                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
+        # 鍥鸿瘽
+        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fixed telephone')
+            for matcher in matchers:
+                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
+
+        # 瑙勮寖鍖栧垎鏁�
+        pattern = re.compile(r"(\d+/\d+)")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('fraction')
+            for matcher in matchers:
+                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
+
+        # 瑙勮寖鍖栫櫨鍒嗘暟
+        text = text.replace('锛�', '%')
+        pattern = re.compile(r"(\d+(\.\d+)?%)")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('percentage')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
+
+        # 瑙勮寖鍖栫函鏁�+閲忚瘝
+        pattern = re.compile(r"(\d+(\.\d+)?)[澶氫綑鍑燷?" + COM_QUANTIFIERS)
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('cardinal+quantifier')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        # 瑙勮寖鍖栨暟瀛楃紪鍙�
+        pattern = re.compile(r"(\d{4,32})")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('digit')
+            for matcher in matchers:
+                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+        # 瑙勮寖鍖栫函鏁�
+        pattern = re.compile(r"(\d+(\.\d+)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            #print('cardinal')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        self.norm_text = text
+        self._particular()
+
+        return self.norm_text.lstrip('^').rstrip('$')
+
+
+def nsw_test_case(raw_text):
+    print('I:' + raw_text)
+    print('O:' + NSWNormalizer(raw_text).normalize())
+    print('')
+
+
+def nsw_test():
+    nsw_test_case('鍥鸿瘽锛�0595-23865596鎴�23880880銆�')
+    nsw_test_case('鍥鸿瘽锛�0595-23865596鎴�23880880銆�')
+    nsw_test_case('鎵嬫満锛�+86 19859213959鎴�15659451527銆�')
+    nsw_test_case('鍒嗘暟锛�32477/76391銆�')
+    nsw_test_case('鐧惧垎鏁帮細80.03%銆�')
+    nsw_test_case('缂栧彿锛�31520181154418銆�')
+    nsw_test_case('绾暟锛�2983.07鍏嬫垨12345.60绫炽��')
+    nsw_test_case('鏃ユ湡锛�1999骞�2鏈�20鏃ユ垨09骞�3鏈�15鍙枫��')
+    nsw_test_case('閲戦挶锛�12鍧�5锛�34.5鍏冿紝20.1涓�')
+    nsw_test_case('鐗规畩锛歄2O鎴朆2C銆�')
+    nsw_test_case('3456涓囧惃')
+    nsw_test_case('2938涓�')
+    nsw_test_case('938')
+    nsw_test_case('浠婂ぉ鍚冧簡115涓皬绗煎寘231涓澶�')
+    nsw_test_case('鏈�62锛呯殑姒傜巼')
+
+
+if __name__ == '__main__':
+    #nsw_test()
+
+    p = argparse.ArgumentParser()
+    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
+    p.add_argument('ofile', help='output filename')
+    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
+    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
+    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
+    p.add_argument('--remove_fillers', type=bool, default=True, help='remove filler chars such as "鍛�, 鍟�"')
+    p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "杩欏効"')
+    p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
+    args = p.parse_args()
+
+    ifile = codecs.open(args.ifile, 'r', 'utf8')
+    ofile = codecs.open(args.ofile, 'w+', 'utf8')
+
+    n = 0
+    for l in ifile:
+        key = ''
+        text = ''
+        if args.has_key:
+            cols = l.split(maxsplit=1)
+            key = cols[0]
+            if len(cols) == 2:
+                text = cols[1].strip()
+            else:
+                text = ''
+        else:
+            text = l.strip()
+
+        # cases
+        if args.to_upper and args.to_lower:
+            sys.stderr.write('text norm: to_upper OR to_lower?')
+            exit(1)
+        if args.to_upper:
+            text = text.upper()
+        if args.to_lower:
+            text = text.lower()
+
+        # Filler chars removal
+        if args.remove_fillers:
+            for ch in FILLER_CHARS:
+                text = text.replace(ch, '')
+
+        if args.remove_erhua:
+            text = remove_erhua(text, ER_WHITELIST)
+
+        # NSW(Non-Standard-Word) normalization
+        text = NSWNormalizer(text).normalize()
+
+        # Punctuations removal
+        old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations
+        new_chars = ' ' * len(old_chars)
+        del_chars = ''
+        text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
+
+        #
+        if args.has_key:
+            ofile.write(key + '\t' + text + '\n')
+        else:
+            ofile.write(text + '\n')
+
+        n += 1
+        if n % args.log_interval == 0:
+            sys.stderr.write("text norm: {} lines done.\n".format(n))
+
+    sys.stderr.write("text norm: {} lines done in total.\n".format(n))
+
+    ifile.close()
+    ofile.close()

--
Gitblit v1.9.1