jmwang66
2023-01-16 12a7adfdf3dd4f80b5d3a51cfc4eecc84eaa7c64
funasr/datasets/preprocessor.py
@@ -1,3 +1,4 @@
import re
from abc import ABC
from abc import abstractmethod
from pathlib import Path
@@ -27,6 +28,35 @@
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        raise NotImplementedError
def forward_segment(text, dic):
    word_list = []
    i = 0
    while i < len(text):
        longest_word = text[i]
        for j in range(i + 1, len(text) + 1):
            word = text[i:j]
            if word in dic:
                if len(word) > len(longest_word):
                    longest_word = word
        word_list.append(longest_word)
        i += len(longest_word)
    return word_list
def seg_tokenize(txt, seg_dict):
    out_txt = ""
    pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
    for word in txt:
        if pattern.match(word):
            if word in seg_dict:
                out_txt += seg_dict[word] + " "
            else:
                out_txt += "<unk>" + " "
        else:
            continue
    return out_txt.strip().split()
def framing(
@@ -146,6 +176,7 @@
        speech_name: str = "speech",
        text_name: str = "text",
        split_with_space: bool = False,
            seg_dict_file: str = None,
    ):
        super().__init__(train)
        self.train = train
@@ -155,6 +186,16 @@
        self.rir_apply_prob = rir_apply_prob
        self.noise_apply_prob = noise_apply_prob
        self.split_with_space = split_with_space
        self.seg_dict = None
        if seg_dict_file is not None:
            self.seg_dict = {}
            with open(seg_dict_file) as f:
                lines = f.readlines()
            for line in lines:
                s = line.strip().split()
                key = s[0]
                value = s[1:]
                self.seg_dict[key] = " ".join(value)
        if token_type is not None:
            if token_list is None:
@@ -312,6 +353,9 @@
            text = self.text_cleaner(text)
            if self.split_with_space:
                tokens = text.strip().split(" ")
                if self.seg_dict is not None:
                    tokens = forward_segment("".join(tokens).lower(), self.seg_dict)
                    tokens = seg_tokenize(tokens, self.seg_dict)
            else:
                tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)