From 86d65112abcaa6b41352c7d2774e82c5c5618b6c Mon Sep 17 00:00:00 2001
From: 九耳 <mengzhe.cmz@alibaba-inc.com>
Date: 星期日, 05 二月 2023 10:48:12 +0800
Subject: [PATCH] fix

---
 funasr/datasets/preprocessor.py |   93 ++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 93 insertions(+), 0 deletions(-)

diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 75bee86..10fbccb 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -538,3 +538,96 @@
                 data[text_name] = np.array(text_ints, dtype=np.int64)
         assert check_return_type(data)
         return data
+
+class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
+    def __init__(
+            self,
+            train: bool,
+            token_type: str = None,
+            token_list: Union[Path, str, Iterable[str]] = None,
+            bpemodel: Union[Path, str, Iterable[str]] = None,
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: str = "text",
+            split_text_name: str = "split_text",
+            split_with_space: bool = False,
+            seg_dict_file: str = None,
+    ):
+        super().__init__(
+            train=train,
+            # Force to use word.
+            token_type="word",
+            token_list=token_list,
+            bpemodel=bpemodel,
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name,
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+            split_with_space=split_with_space,
+            seg_dict_file=seg_dict_file,
+        )
+        # The data field name for split text.
+        self.split_text_name = split_text_name
+
+    @classmethod
+    def split_words(cls, text: str):
+        words = []
+        segs = text.split()
+        for seg in segs:
+            # There is no space in seg.
+            current_word = ""
+            for c in seg:
+                if len(c.encode()) == 1:
+                    # This is an ASCII char.
+                    current_word += c
+                else:
+                    # This is a Chinese char.
+                    if len(current_word) > 0:
+                        words.append(current_word)
+                        current_word = ""
+                    words.append(c)
+            if len(current_word) > 0:
+                words.append(current_word)
+        return words
+
+    def __call__(
+            self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
+    ) -> Dict[str, Union[list, np.ndarray]]:
+        assert check_argument_types()
+        # Split words.
+        if isinstance(data[self.text_name], str):
+            split_text = self.split_words(data[self.text_name])
+        else:
+            split_text = data[self.text_name]
+        data[self.text_name] = " ".join(split_text)
+        data = self._speech_process(data)
+        data = self._text_process(data)
+        data[self.split_text_name] = split_text
+        return data
+
+    def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
+        result = data[self.split_text_name]
+        del data[self.split_text_name]
+        return result
+

--
Gitblit v1.9.1