kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
examples/aishell/paraformer/utils/textnorm_zh.py
@@ -14,49 +14,58 @@
# ================================================================================ #
#                                    basic constant
# ================================================================================ #
CHINESE_DIGIS = u'零一二三四五六七八九'
BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
CHINESE_DIGIS = "零一二三四五六七八九"
BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
ZERO_ALT = u'〇'
ONE_ALT = u'幺'
TWO_ALTS = [u'两', u'兩']
ZERO_ALT = "〇"
ONE_ALT = "幺"
TWO_ALTS = ["两", "兩"]
POSITIVE = [u'正', u'正']
NEGATIVE = [u'负', u'負']
POINT = [u'点', u'點']
POSITIVE = ["正", "正"]
NEGATIVE = ["负", "負"]
POINT = ["点", "點"]
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']
FILLER_CHARS = ['呃', '啊']
ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \
             '胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \
             '儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \
             '佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)'
FILLER_CHARS = ["呃", "啊"]
ER_WHITELIST = (
    "(儿女|儿子|儿孙|女儿|儿媳|妻儿|"
    "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|"
    "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|"
    "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)"
)
# 中文数字系统类型
NUMBERING_TYPES = ['low', 'mid', 'high']
NUMBERING_TYPES = ["low", "mid", "high"]
CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
                 '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
                  '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
                  '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
                  '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
                  '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
                  '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
CURRENCY_NAMES = (
    "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
    "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
)
CURRENCY_UNITS = (
    "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
)
COM_QUANTIFIERS = (
    "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
    "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
    "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
    "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
    "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
    "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)"
)
# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
CHINESE_PUNC_STOP = '!?。。'
CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏'
CHINESE_PUNC_STOP = "!?。。"
CHINESE_PUNC_NON_STOP = ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏"
CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
# ================================================================================ #
#                                    basic class
@@ -72,7 +81,7 @@
    def __init__(self, simplified, traditional):
        self.simplified = simplified
        self.traditional = traditional
        #self.__repr__ = self.__str__
        # self.__repr__ = self.__str__
    def __str__(self):
        return self.simplified or self.traditional or None
@@ -95,26 +104,49 @@
        self.big_t = big_t
    def __str__(self):
        return '10^{}'.format(self.power)
        return "10^{}".format(self.power)
    @classmethod
    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
        if small_unit:
            return ChineseNumberUnit(power=index + 1,
                                     simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
            return ChineseNumberUnit(
                power=index + 1,
                simplified=value[0],
                traditional=value[1],
                big_s=value[1],
                big_t=value[1],
            )
        elif numbering_type == NUMBERING_TYPES[0]:
            return ChineseNumberUnit(power=index + 8,
                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
            return ChineseNumberUnit(
                power=index + 8,
                simplified=value[0],
                traditional=value[1],
                big_s=value[0],
                big_t=value[1],
            )
        elif numbering_type == NUMBERING_TYPES[1]:
            return ChineseNumberUnit(power=(index + 2) * 4,
                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
            return ChineseNumberUnit(
                power=(index + 2) * 4,
                simplified=value[0],
                traditional=value[1],
                big_s=value[0],
                big_t=value[1],
            )
        elif numbering_type == NUMBERING_TYPES[2]:
            return ChineseNumberUnit(power=pow(2, index + 3),
                                     simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
            return ChineseNumberUnit(
                power=pow(2, index + 3),
                simplified=value[0],
                traditional=value[1],
                big_s=value[0],
                big_t=value[1],
            )
        else:
            raise ValueError(
                'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
                "Counting type should be in {0} ({1} provided).".format(
                    NUMBERING_TYPES, numbering_type
                )
            )
class ChineseNumberDigit(ChineseChar):
@@ -158,6 +190,7 @@
    """
    中文数字系统
    """
    pass
@@ -207,27 +240,27 @@
    # chinese number units of '亿' and larger
    all_larger_units = zip(
        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
    larger_units = [CNU.create(i, v, numbering_type, False)
                    for i, v in enumerate(all_larger_units)]
        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
    )
    larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)]
    # chinese number units of '十, 百, 千, 万'
    all_smaller_units = zip(
        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
    smaller_units = [CNU.create(i, v, small_unit=True)
                     for i, v in enumerate(all_smaller_units)]
        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
    )
    smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)]
    # digis
    chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
                        BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
    chinese_digis = zip(
        CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL
    )
    digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
    digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
    digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
    digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
    # symbols
    positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
    point_cn = CM(POINT[0], POINT[1], '.', lambda x,
                  y: float(str(x) + '.' + str(y)))
    positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
    point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
    # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
    system = NumberSystem()
    system.units = smaller_units + larger_units
@@ -251,13 +284,14 @@
                return m
    def string2symbols(chinese_string, system):
        int_string, dec_string = chinese_string, ''
        int_string, dec_string = chinese_string, ""
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                int_string, dec_string = chinese_string.split(p)
                break
        return [get_symbol(c, system) for c in int_string], \
               [get_symbol(c, system) for c in dec_string]
        return [get_symbol(c, system) for c in int_string], [
            get_symbol(c, system) for c in dec_string
        ]
    def correct_symbols(integer_symbols, system):
        """
@@ -271,8 +305,7 @@
        if len(integer_symbols) > 1:
            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
                integer_symbols.append(
                    CNU(integer_symbols[-2].power - 1, None, None, None, None))
                integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))
        result = []
        unit_count = 0
@@ -288,9 +321,13 @@
                result.append(current_unit)
            elif unit_count > 1:
                for i in range(len(result)):
                    if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
                        result[-i - 1] = CNU(result[-i - 1].power +
                                             current_unit.power, None, None, None, None)
                    if (
                        isinstance(result[-i - 1], CNU)
                        and result[-i - 1].power < current_unit.power
                    ):
                        result[-i - 1] = CNU(
                            result[-i - 1].power + current_unit.power, None, None, None, None
                        )
        return result
    def compute_value(integer_symbols):
@@ -307,8 +344,7 @@
            elif isinstance(s, CNU):
                value[-1] *= pow(10, s.power)
                if s.power > last_power:
                    value[:-1] = list(map(lambda v: v *
                                                    pow(10, s.power), value[:-1]))
                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
                    last_power = s.power
                value.append(0)
        return sum(value)
@@ -317,20 +353,28 @@
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    dec_str = ''.join([str(d.value) for d in dec_part])
    dec_str = "".join([str(d.value) for d in dec_part])
    if dec_part:
        return '{0}.{1}'.format(int_str, dec_str)
        return "{0}.{1}".format(int_str, dec_str)
    else:
        return int_str
def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
            traditional=False, alt_zero=False, alt_one=False, alt_two=True,
            use_zeros=True, use_units=True):
def num2chn(
    number_string,
    numbering_type=NUMBERING_TYPES[1],
    big=False,
    traditional=False,
    alt_zero=False,
    alt_one=False,
    alt_two=True,
    use_zeros=True,
    use_units=True,
):
    def get_value(value_string, use_zeros=True):
        striped_string = value_string.lstrip('0')
        striped_string = value_string.lstrip("0")
        # record nothing if all zeros
        if not striped_string:
@@ -345,14 +389,17 @@
        # recursively record multiple digits
        else:
            result_unit = next(u for u in reversed(
                system.units) if u.power < len(striped_string))
            result_string = value_string[:-result_unit.power]
            return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
            result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
            result_string = value_string[: -result_unit.power]
            return (
                get_value(result_string)
                + [result_unit]
                + get_value(striped_string[-result_unit.power :])
            )
    system = create_system(numbering_type)
    int_dec = number_string.split('.')
    int_dec = number_string.split(".")
    if len(int_dec) == 1:
        int_string = int_dec[0]
        dec_string = ""
@@ -361,7 +408,8 @@
        dec_string = int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string))
            "invalid input num string with more than one dot: {}".format(number_string)
        )
    if use_units and len(int_string) > 1:
        result_symbols = get_value(int_string)
@@ -372,51 +420,62 @@
        result_symbols += [system.math.point] + dec_symbols
    if alt_two:
        liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
                    system.digits[2].big_s, system.digits[2].big_t)
        liang = CND(
            2,
            system.digits[2].alt_s,
            system.digits[2].alt_t,
            system.digits[2].big_s,
            system.digits[2].big_t,
        )
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = result_symbols[i +
                                             1] if i < len(result_symbols) - 1 else None
                next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
                    if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
                    if next_symbol.power != 1 and (
                        (previous_symbol is None) or (previous_symbol.power != 1)
                    ):
                        result_symbols[i] = liang
    # if big is True, '两' will not be used and `alt_two` has no impact on output
    if big:
        attr_name = 'big_'
        attr_name = "big_"
        if traditional:
            attr_name += 't'
            attr_name += "t"
        else:
            attr_name += 's'
            attr_name += "s"
    else:
        if traditional:
            attr_name = 'traditional'
            attr_name = "traditional"
        else:
            attr_name = 'simplified'
            attr_name = "simplified"
    result = ''.join([getattr(s, attr_name) for s in result_symbols])
    result = "".join([getattr(s, attr_name) for s in result_symbols])
    # if not use_zeros:
    #     result = result.strip(getattr(system.digits[0], attr_name))
    if alt_zero:
        result = result.replace(
            getattr(system.digits[0], attr_name), system.digits[0].alt_s)
        result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)
    if alt_one:
        result = result.replace(
            getattr(system.digits[1], attr_name), system.digits[1].alt_s)
        result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)
    for i, p in enumerate(POINT):
        if result.startswith(p):
            return CHINESE_DIGIS[0] + result
    # ^10, 11, .., 19
    if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
                                          SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
            result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
    if (
        len(result) >= 2
        and result[1]
        in [
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
        ]
        and result[0]
        in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
    ):
        result = result[1:]
    return result
@@ -439,6 +498,7 @@
    def cardinal2chntext(self):
        return num2chn(self.cardinal)
class Digit:
    """
@@ -476,17 +536,17 @@
    def telephone2chntext(self, fixed=False):
        if fixed:
            sil_parts = self.telephone.split('-')
            self.raw_chntext = '<SIL>'.join([
                num2chn(part, alt_two=False, use_units=False) for part in sil_parts
            ])
            self.chntext = self.raw_chntext.replace('<SIL>', '')
            sil_parts = self.telephone.split("-")
            self.raw_chntext = "<SIL>".join(
                [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
            )
            self.chntext = self.raw_chntext.replace("<SIL>", "")
        else:
            sp_parts = self.telephone.strip('+').split()
            self.raw_chntext = '<SP>'.join([
                num2chn(part, alt_two=False, use_units=False) for part in sp_parts
            ])
            self.chntext = self.raw_chntext.replace('<SP>', '')
            sp_parts = self.telephone.strip("+").split()
            self.raw_chntext = "<SP>".join(
                [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
            )
            self.chntext = self.raw_chntext.replace("<SP>", "")
        return self.chntext
@@ -500,12 +560,12 @@
        self.chntext = chntext
    def chntext2fraction(self):
        denominator, numerator = self.chntext.split('分之')
        return chn2num(numerator) + '/' + chn2num(denominator)
        denominator, numerator = self.chntext.split("分之")
        return chn2num(numerator) + "/" + chn2num(denominator)
    def fraction2chntext(self):
        numerator, denominator = self.fraction.split('/')
        return num2chn(denominator) + '分之' + num2chn(numerator)
        numerator, denominator = self.fraction.split("/")
        return num2chn(denominator) + "分之" + num2chn(numerator)
class Date:
@@ -544,23 +604,23 @@
    def date2chntext(self):
        date = self.date
        try:
            year, other = date.strip().split('年', 1)
            year = Digit(digit=year).digit2chntext() + '年'
            year, other = date.strip().split("年", 1)
            year = Digit(digit=year).digit2chntext() + "年"
        except ValueError:
            other = date
            year = ''
            year = ""
        if other:
            try:
                month, day = other.strip().split('月', 1)
                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
                month, day = other.strip().split("月", 1)
                month = Cardinal(cardinal=month).cardinal2chntext() + "月"
            except ValueError:
                day = date
                month = ''
                month = ""
            if day:
                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
        else:
            month = ''
            day = ''
            month = ""
            day = ""
        chntext = year + month + day
        self.chntext = chntext
        return self.chntext
@@ -580,7 +640,7 @@
    def money2chntext(self):
        money = self.money
        pattern = re.compile(r'(\d+(\.\d+)?)')
        pattern = re.compile(r"(\d+(\.\d+)?)")
        matchers = pattern.findall(money)
        if matchers:
            for matcher in matchers:
@@ -599,10 +659,10 @@
        self.chntext = chntext
    def chntext2percentage(self):
        return chn2num(self.chntext.strip().strip('百分之')) + '%'
        return chn2num(self.chntext.strip().strip("百分之")) + "%"
    def percentage2chntext(self):
        return '百分之' + num2chn(self.percentage.strip().strip('%'))
        return "百分之" + num2chn(self.percentage.strip().strip("%"))
def remove_erhua(text, er_whitelist):
@@ -612,9 +672,9 @@
    """
    er_pattern = re.compile(er_whitelist)
    new_str=''
    while re.search('儿',text):
        a = re.search('儿',text).span()
    new_str = ""
    while re.search("儿", text):
        a = re.search("儿", text).span()
        remove_er_flag = 0
        if er_pattern.search(text):
@@ -622,23 +682,24 @@
            if b[0] <= a[0]:
                remove_er_flag = 1
        if remove_er_flag == 0 :
            new_str = new_str + text[0:a[0]]
            text = text[a[1]:]
        if remove_er_flag == 0:
            new_str = new_str + text[0 : a[0]]
            text = text[a[1] :]
        else:
            new_str = new_str + text[0:b[1]]
            text = text[b[1]:]
            new_str = new_str + text[0 : b[1]]
            text = text[b[1] :]
    text = new_str + text
    return text
# ================================================================================ #
#                            NSW Normalizer
# ================================================================================ #
class NSWNormalizer:
    def __init__(self, raw_text):
        self.raw_text = '^' + raw_text + '$'
        self.norm_text = ''
        self.raw_text = "^" + raw_text + "$"
        self.norm_text = ""
    def _particular(self):
        text = self.norm_text
@@ -647,7 +708,7 @@
        if matchers:
            # print('particular')
            for matcher in matchers:
                text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
                text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
        self.norm_text = text
        return self.norm_text
@@ -658,15 +719,17 @@
        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
        matchers = pattern.findall(text)
        if matchers:
            #print('date')
            # print('date')
            for matcher in matchers:
                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
        # 规范化金钱
        pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
        pattern = re.compile(
            r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)"
        )
        matchers = pattern.findall(text)
        if matchers:
            #print('money')
            # print('money')
            for matcher in matchers:
                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
@@ -679,39 +742,45 @@
        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
        matchers = pattern.findall(text)
        if matchers:
            #print('telephone')
            # print('telephone')
            for matcher in matchers:
                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
                text = text.replace(
                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
                )
        # 固话
        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
        matchers = pattern.findall(text)
        if matchers:
            # print('fixed telephone')
            for matcher in matchers:
                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
                text = text.replace(
                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1
                )
        # 规范化分数
        pattern = re.compile(r"(\d+/\d+)")
        matchers = pattern.findall(text)
        if matchers:
            #print('fraction')
            # print('fraction')
            for matcher in matchers:
                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
        # 规范化百分数
        text = text.replace('%', '%')
        text = text.replace("%", "%")
        pattern = re.compile(r"(\d+(\.\d+)?%)")
        matchers = pattern.findall(text)
        if matchers:
            #print('percentage')
            # print('percentage')
            for matcher in matchers:
                text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
                text = text.replace(
                    matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1
                )
        # 规范化纯数+量词
        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
        matchers = pattern.findall(text)
        if matchers:
            #print('cardinal+quantifier')
            # print('cardinal+quantifier')
            for matcher in matchers:
                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
@@ -719,7 +788,7 @@
        pattern = re.compile(r"(\d{4,32})")
        matchers = pattern.findall(text)
        if matchers:
            #print('digit')
            # print('digit')
            for matcher in matchers:
                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
@@ -727,74 +796,82 @@
        pattern = re.compile(r"(\d+(\.\d+)?)")
        matchers = pattern.findall(text)
        if matchers:
            #print('cardinal')
            # print('cardinal')
            for matcher in matchers:
                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
        self.norm_text = text
        self._particular()
        return self.norm_text.lstrip('^').rstrip('$')
        return self.norm_text.lstrip("^").rstrip("$")
def nsw_test_case(raw_text):
    print('I:' + raw_text)
    print('O:' + NSWNormalizer(raw_text).normalize())
    print('')
    print("I:" + raw_text)
    print("O:" + NSWNormalizer(raw_text).normalize())
    print("")
def nsw_test():
    nsw_test_case('固话:0595-23865596或23880880。')
    nsw_test_case('固话:0595-23865596或23880880。')
    nsw_test_case('手机:+86 19859213959或15659451527。')
    nsw_test_case('分数:32477/76391。')
    nsw_test_case('百分数:80.03%。')
    nsw_test_case('编号:31520181154418。')
    nsw_test_case('纯数:2983.07克或12345.60米。')
    nsw_test_case('日期:1999年2月20日或09年3月15号。')
    nsw_test_case('金钱:12块5,34.5元,20.1万')
    nsw_test_case('特殊:O2O或B2C。')
    nsw_test_case('3456万吨')
    nsw_test_case('2938个')
    nsw_test_case('938')
    nsw_test_case('今天吃了115个小笼包231个馒头')
    nsw_test_case('有62%的概率')
    nsw_test_case("固话:0595-23865596或23880880。")
    nsw_test_case("固话:0595-23865596或23880880。")
    nsw_test_case("手机:+86 19859213959或15659451527。")
    nsw_test_case("分数:32477/76391。")
    nsw_test_case("百分数:80.03%。")
    nsw_test_case("编号:31520181154418。")
    nsw_test_case("纯数:2983.07克或12345.60米。")
    nsw_test_case("日期:1999年2月20日或09年3月15号。")
    nsw_test_case("金钱:12块5,34.5元,20.1万")
    nsw_test_case("特殊:O2O或B2C。")
    nsw_test_case("3456万吨")
    nsw_test_case("2938个")
    nsw_test_case("938")
    nsw_test_case("今天吃了115个小笼包231个馒头")
    nsw_test_case("有62%的概率")
if __name__ == '__main__':
    #nsw_test()
if __name__ == "__main__":
    # nsw_test()
    p = argparse.ArgumentParser()
    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
    p.add_argument('ofile', help='output filename')
    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
    p.add_argument('--remove_fillers', type=bool, default=True, help='remove filler chars such as "呃, 啊"')
    p.add_argument('--remove_erhua', type=bool, default=True, help='remove erhua chars such as "这儿"')
    p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
    p.add_argument("ifile", help="input filename, assume utf-8 encoding")
    p.add_argument("ofile", help="output filename")
    p.add_argument("--to_upper", action="store_true", help="convert to upper case")
    p.add_argument("--to_lower", action="store_true", help="convert to lower case")
    p.add_argument(
        "--has_key", action="store_true", help="input text has Kaldi's key as first field."
    )
    p.add_argument(
        "--remove_fillers", type=bool, default=True, help='remove filler chars such as "呃, 啊"'
    )
    p.add_argument(
        "--remove_erhua", type=bool, default=True, help='remove erhua chars such as "这儿"'
    )
    p.add_argument(
        "--log_interval", type=int, default=10000, help="log interval in number of processed lines"
    )
    args = p.parse_args()
    ifile = codecs.open(args.ifile, 'r', 'utf8')
    ofile = codecs.open(args.ofile, 'w+', 'utf8')
    ifile = codecs.open(args.ifile, "r", "utf8")
    ofile = codecs.open(args.ofile, "w+", "utf8")
    n = 0
    for l in ifile:
        key = ''
        text = ''
        key = ""
        text = ""
        if args.has_key:
            cols = l.split(maxsplit=1)
            key = cols[0]
            if len(cols) == 2:
                text = cols[1].strip()
            else:
                text = ''
                text = ""
        else:
            text = l.strip()
        # cases
        if args.to_upper and args.to_lower:
            sys.stderr.write('text norm: to_upper OR to_lower?')
            sys.stderr.write("text norm: to_upper OR to_lower?")
            exit(1)
        if args.to_upper:
            text = text.upper()
@@ -804,7 +881,7 @@
        # Filler chars removal
        if args.remove_fillers:
            for ch in FILLER_CHARS:
                text = text.replace(ch, '')
                text = text.replace(ch, "")
        if args.remove_erhua:
            text = remove_erhua(text, ER_WHITELIST)
@@ -813,16 +890,16 @@
        text = NSWNormalizer(text).normalize()
        # Punctuations removal
        old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations
        new_chars = ' ' * len(old_chars)
        del_chars = ''
        old_chars = CHINESE_PUNC_LIST + string.punctuation  # includes all CN and EN punctuations
        new_chars = " " * len(old_chars)
        del_chars = ""
        text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
        #
        if args.has_key:
            ofile.write(key + '\t' + text + '\n')
            ofile.write(key + "\t" + text + "\n")
        else:
            ofile.write(text + '\n')
            ofile.write(text + "\n")
        n += 1
        if n % args.log_interval == 0: