From ee06cb9c6870d9e1579015aabfe1a84a61a5c087 Mon Sep 17 00:00:00 2001
From: 九耳 <mengzhe.cmz@alibaba-inc.com>
Date: 星期二, 28 二月 2023 18:11:12 +0800
Subject: [PATCH] punctuation:add training code, support largedataset

---
 funasr/datasets/preprocessor.py |  100 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 100 insertions(+), 0 deletions(-)

diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 8e86794..20a3791 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -704,3 +704,103 @@
         del data[self.split_text_name]
         return result
 
+class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
+    def __init__(
+            self,
+            train: bool,
+            token_type: List[str] = [None],
+            token_list: List[Union[Path, str, Iterable[str]]] = [None],
+            bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: List[str] = ["text"],
+            vad_name: str = "vad_indexes",
+    ):
+        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+        super().__init__(
+            train=train,
+            token_type=token_type[0],
+            token_list=token_list[0],
+            bpemodel=bpemodel[0],
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name[0],
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+        )
+
+        assert (
+                len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
+        self.num_tokenizer = len(token_type)
+        self.tokenizer = []
+        self.token_id_converter = []
+
+        for i in range(self.num_tokenizer):
+            if token_type[i] is not None:
+                if token_list[i] is None:
+                    raise ValueError("token_list is required if token_type is not None")
+
+                self.tokenizer.append(
+                    build_tokenizer(
+                        token_type=token_type[i],
+                        bpemodel=bpemodel[i],
+                        delimiter=delimiter,
+                        space_symbol=space_symbol,
+                        non_linguistic_symbols=non_linguistic_symbols,
+                        g2p_type=g2p_type,
+                    )
+                )
+                self.token_id_converter.append(
+                    TokenIDConverter(
+                        token_list=token_list[i],
+                        unk_symbol=unk_symbol,
+                    )
+                )
+            else:
+                self.tokenizer.append(None)
+                self.token_id_converter.append(None)
+
+        self.text_cleaner = TextCleaner(text_cleaner)
+        self.text_name = text_name  # override the text_name from CommonPreprocessor
+        self.vad_name = vad_name
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        for i in range(self.num_tokenizer):
+            text_name = self.text_name[i]
+            if text_name in data and self.tokenizer[i] is not None:
+                text = data[text_name]
+                text = self.text_cleaner(text)
+                tokens = self.tokenizer[i].text2tokens(text)
+                if "vad:" in tokens[-1]:
+                    vad = tokens[-1][4:]
+                    tokens = tokens[:-1]
+                    if len(vad) == 0:
+                        vad = -1
+                    else:
+                        vad = int(vad)
+                    data[self.vad_name] = np.array([vad], dtype=np.int64)
+                text_ints = self.token_id_converter[i].tokens2ids(tokens)
+                data[text_name] = np.array(text_ints, dtype=np.int64)

--
Gitblit v1.9.1