From fc246ab820cf57ba08afbe3cbeb4d471036eb83c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 7 Dec 2023 00:27:00 +0800
Subject: [PATCH] funasr2

---
 funasr/tokenizer/abs_tokenizer.py   |   61 +++++++++++++++---------------
 funasr/tokenizer/build_tokenizer.py |    2 
 2 files changed, 32 insertions(+), 31 deletions(-)

diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index ffb6b76..d2fc3f0 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -21,42 +21,43 @@
 
 
 class BaseTokenizer(ABC):
-    def __init__(self, token_list: Union[Path, str, Iterable[str]],
+    def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
                  unk_symbol: str = "<unk>",
                  **kwargs,
                  ):
         
-        if isinstance(token_list, (Path, str)):
-            token_list = Path(token_list)
-            self.token_list_repr = str(token_list)
-            self.token_list: List[str] = []
+        if token_list is not None:
+            if isinstance(token_list, (Path, str)):
+                token_list = Path(token_list)
+                self.token_list_repr = str(token_list)
+                self.token_list: List[str] = []
+                
+                with token_list.open("r", encoding="utf-8") as f:
+                    for idx, line in enumerate(f):
+                        line = line.rstrip()
+                        self.token_list.append(line)
             
-            with token_list.open("r", encoding="utf-8") as f:
-                for idx, line in enumerate(f):
-                    line = line.rstrip()
-                    self.token_list.append(line)
-        
-        else:
-            self.token_list: List[str] = list(token_list)
-            self.token_list_repr = ""
+            else:
+                self.token_list: List[str] = list(token_list)
+                self.token_list_repr = ""
+                for i, t in enumerate(self.token_list):
+                    if i == 3:
+                        break
+                    self.token_list_repr += f"{t}, "
+                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
+            
+            self.token2id: Dict[str, int] = {}
             for i, t in enumerate(self.token_list):
-                if i == 3:
-                    break
-                self.token_list_repr += f"{t}, "
-            self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
-        
-        self.token2id: Dict[str, int] = {}
-        for i, t in enumerate(self.token_list):
-            if t in self.token2id:
-                raise RuntimeError(f'Symbol "{t}" is duplicated')
-            self.token2id[t] = i
-        
-        self.unk_symbol = unk_symbol
-        if self.unk_symbol not in self.token2id:
-            raise RuntimeError(
-                f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
-            )
-        self.unk_id = self.token2id[self.unk_symbol]
+                if t in self.token2id:
+                    raise RuntimeError(f'Symbol "{t}" is duplicated')
+                self.token2id[t] = i
+            
+            self.unk_symbol = unk_symbol
+            if self.unk_symbol not in self.token2id:
+                raise RuntimeError(
+                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
+                )
+            self.unk_id = self.token2id[self.unk_symbol]
     
     def encode(self, text):
         tokens = self.text2tokens(text)
diff --git a/funasr/tokenizer/build_tokenizer.py b/funasr/tokenizer/build_tokenizer.py
index 1dc17da..05db6a6 100644
--- a/funasr/tokenizer/build_tokenizer.py
+++ b/funasr/tokenizer/build_tokenizer.py
@@ -29,7 +29,7 @@
     delimiter: str = None,
     g2p_type: str = None,
     **kwargs,
-) -> AbsTokenizer:
+):
     """A helper function to instantiate Tokenizer"""
     if token_type == "bpe":
         if bpemodel is None:

--
Gitblit v1.9.1