import pynini
|
from fun_text_processing.text_normalization.en.graph_utils import (
|
DAMO_ALPHA,
|
DAMO_DIGIT,
|
DAMO_SIGMA,
|
GraphFst,
|
delete_extra_space,
|
delete_space,
|
insert_space,
|
plurals,
|
)
|
from fun_text_processing.text_normalization.en.utils import get_abs_path
|
from pynini.lib import pynutil
|
|
|
class TelephoneFst(GraphFst):
    """
    Finite state transducer for classifying telephone numbers, IP addresses and
    social security numbers (SSN).

    A telephone number is an optional country code, a number part and an optional
    extension:
        country code (optional): an optional prompt from telephone_prompt.tsv, an
            optional "+" (read as "plus"), then 1-3 digits followed by an inserted
            ","; a bare prompt is also accepted as the country code
        number part: a three-digit area code written as "123-", "123.", "(123)",
            "(123) " or "(123)-" (the area code "800" is read as "eight hundred"),
            then seven digits/letters optionally separated by "-" or "."
        extension (optional): 1-4 digits
    An IP address is four groups of 1-3 digits joined by "." (read as "dot"); an
    SSN has the form ***-**-****. Both may be preceded by a prompt from
    ip_prompt.tsv / ssn_prompt.tsv, which is tagged as country_code.

    E.g. (tagged output as traced through the rules below; spelled digits come from
    data/number/digit.tsv and assume insert_space inserts a single " "):
        1-800-GO-U-HAUL -> telephone { country_code: "one," number_part: "eight hundred, GO U HAUL" }
        +1 123-123-5678-1 -> telephone { country_code: "plus one," number_part: "one two three, one two three, five six seven eight, " extension: "one" }

    Args:
        deterministic: if True will provide a single transduction option,
            for False multiple transductions are generated (used for audio-based normalization)
    """

    def __init__(self, deterministic: bool = True):
        super().__init__(name="telephone", kind="classify", deterministic=deterministic)

        def optional(fst):
            # Accept `fst` zero or one time.
            return pynini.closure(fst, 0, 1)

        def quoted(open_tag, body):
            # Emit `<open_tag>` + body + closing quote, e.g. country_code: "...".
            return pynutil.insert(open_tag) + body + pynutil.insert('"')

        def prompted(prompts, body):
            # Optional prompt tagged as country_code, then `body` tagged as number_part.
            return (
                optional(quoted('country_code: "', prompts) + delete_extra_space)
                + pynutil.insert('number_part: "')
                + body
                + pynutil.insert('"')
            )

        # --- single digits ------------------------------------------------------
        zero_fst = pynini.cross("0", "zero")
        if not deterministic:
            # additional spoken variants of "0" when several outputs are allowed
            zero_fst |= pynini.cross("0", pynini.union("o", "oh"))
        digit_names = pynini.string_file(get_abs_path("data/number/digit.tsv"))
        # inverted so the written digit is the input side (see DAMO_DIGIT @ digit below)
        digit = pynini.invert(digit_names).optimize() | zero_fst

        # --- country code -------------------------------------------------------
        telephone_prompts = pynini.string_file(get_abs_path("data/telephone/telephone_prompt.tsv"))
        code_body = (
            optional(telephone_prompts + delete_extra_space)
            + optional(pynini.cross("+", "plus "))
            + pynini.closure(digit + insert_space, 0, 2)
            + digit
            + pynutil.insert(",")
        )
        # a prompt on its own also counts as a country code
        code_body |= telephone_prompts
        country_code = (
            quoted('country_code: "', code_body)
            + optional(pynutil.delete("-"))
            + delete_space
            + insert_space
        )

        # --- area code ----------------------------------------------------------
        comma = pynutil.insert(", ")  # between the area code and the local number
        three_digits = pynini.closure(digit + insert_space, 2, 2) + digit
        area_digits = pynini.cross("800", "eight hundred") | pynini.compose(
            pynini.difference(DAMO_SIGMA, "800"), three_digits
        )
        dashed_area = area_digits + (pynutil.delete("-") | pynutil.delete("."))
        paren_close = (pynutil.delete(")") + optional(pynutil.delete(" "))) | pynutil.delete(")-")
        paren_area = pynutil.delete("(") + area_digits + paren_close
        area_part = (dashed_area | paren_area) + comma

        # --- seven-symbol local number ------------------------------------------
        separator = optional(pynini.union("-", " ", "."))
        seven_symbols = ((DAMO_DIGIT + separator) | (DAMO_ALPHA + separator)) ** 7

        def spell_symbols(sep):
            # Digits are spelled out and followed by insert_space, or by ", " in place
            # of `sep`; letters are accepted as written, with a following `sep`
            # turned into a space.
            return pynini.closure(
                (DAMO_DIGIT @ digit) + (insert_space | pynini.cross(sep, ", "))
                | DAMO_ALPHA
                | (DAMO_ALPHA + pynini.cross(sep, " "))
            )

        # "-" and "." are handled by separate closures, so the local number cannot
        # mix the two separators.
        # NOTE(review): seven_symbols also allows " " as a separator, but neither
        # closure consumes " " (assuming DAMO_ALPHA excludes whitespace), so
        # space-separated local numbers are dropped by the compose — confirm intended.
        # NOTE(review): every digit must be followed by a separator, so number_part
        # ends with insert_space or ", " after its last digit.
        local_number = spell_symbols("-")
        local_number |= spell_symbols(".")
        local_number = pynini.compose(seven_symbols, local_number)

        number_part = quoted('number_part: "', area_part + local_number)

        extension = (
            pynutil.insert('extension: "')
            + pynini.closure(digit + insert_space, 0, 3)
            + digit
            + pynutil.insert('"')
        )
        extension = optional(insert_space + extension)

        # --- telephone graph ----------------------------------------------------
        # Presumably _priority_union(q, r, sigma) prefers q wherever q accepts the
        # input, so each later call overrides the previous graph — confirm in plurals.
        with_code = country_code + number_part
        graph = plurals._priority_union(with_code, number_part, DAMO_SIGMA).optimize()
        with_code_and_ext = country_code + number_part + extension
        graph = plurals._priority_union(with_code_and_ext, graph, DAMO_SIGMA).optimize()
        with_ext = number_part + extension
        graph = plurals._priority_union(with_ext, graph, DAMO_SIGMA).optimize()

        # --- IP address: four groups of 1-3 digits joined by "." ----------------
        ip_prompts = pynini.string_file(get_abs_path("data/telephone/ip_prompt.tsv"))
        octet = digit + pynini.closure(pynutil.insert(" ") + digit, 0, 2)
        ip_address = octet + (pynini.cross(".", " dot ") + octet) ** 3
        graph |= prompted(ip_prompts, ip_address.optimize())

        # --- SSN: ***-**-**** ---------------------------------------------------
        ssn_prompts = pynini.string_file(get_abs_path("data/telephone/ssn_prompt.tsv"))
        dash = pynini.cross("-", ", ")
        ssn_head = digit + (pynutil.insert(" ") + digit) ** 2
        ssn_middle = digit + pynutil.insert(" ") + digit
        ssn_tail = digit + (pynutil.insert(" ") + digit) ** 3
        ssn = ssn_head + dash + ssn_middle + dash + ssn_tail
        graph |= prompted(ssn_prompts, ssn.optimize())

        self.fst = self.add_tokens(graph).optimize()