From fc246ab820cf57ba08afbe3cbeb4d471036eb83c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 07 十二月 2023 00:27:00 +0800
Subject: [PATCH] funasr2
---
funasr/tokenizer/abs_tokenizer.py | 61 +++++++++++++++---------------
funasr/tokenizer/build_tokenizer.py | 2
2 files changed, 32 insertions(+), 31 deletions(-)
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index ffb6b76..d2fc3f0 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -21,42 +21,43 @@
class BaseTokenizer(ABC):
- def __init__(self, token_list: Union[Path, str, Iterable[str]],
+ def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
unk_symbol: str = "<unk>",
**kwargs,
):
- if isinstance(token_list, (Path, str)):
- token_list = Path(token_list)
- self.token_list_repr = str(token_list)
- self.token_list: List[str] = []
+ if token_list is not None:
+ if isinstance(token_list, (Path, str)):
+ token_list = Path(token_list)
+ self.token_list_repr = str(token_list)
+ self.token_list: List[str] = []
+
+ with token_list.open("r", encoding="utf-8") as f:
+ for idx, line in enumerate(f):
+ line = line.rstrip()
+ self.token_list.append(line)
- with token_list.open("r", encoding="utf-8") as f:
- for idx, line in enumerate(f):
- line = line.rstrip()
- self.token_list.append(line)
-
- else:
- self.token_list: List[str] = list(token_list)
- self.token_list_repr = ""
+ else:
+ self.token_list: List[str] = list(token_list)
+ self.token_list_repr = ""
+ for i, t in enumerate(self.token_list):
+ if i == 3:
+ break
+ self.token_list_repr += f"{t}, "
+ self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
+
+ self.token2id: Dict[str, int] = {}
for i, t in enumerate(self.token_list):
- if i == 3:
- break
- self.token_list_repr += f"{t}, "
- self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
-
- self.token2id: Dict[str, int] = {}
- for i, t in enumerate(self.token_list):
- if t in self.token2id:
- raise RuntimeError(f'Symbol "{t}" is duplicated')
- self.token2id[t] = i
-
- self.unk_symbol = unk_symbol
- if self.unk_symbol not in self.token2id:
- raise RuntimeError(
- f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
- )
- self.unk_id = self.token2id[self.unk_symbol]
+ if t in self.token2id:
+ raise RuntimeError(f'Symbol "{t}" is duplicated')
+ self.token2id[t] = i
+
+ self.unk_symbol = unk_symbol
+ if self.unk_symbol not in self.token2id:
+ raise RuntimeError(
+ f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
+ )
+ self.unk_id = self.token2id[self.unk_symbol]
def encode(self, text):
tokens = self.text2tokens(text)
diff --git a/funasr/tokenizer/build_tokenizer.py b/funasr/tokenizer/build_tokenizer.py
index 1dc17da..05db6a6 100644
--- a/funasr/tokenizer/build_tokenizer.py
+++ b/funasr/tokenizer/build_tokenizer.py
@@ -29,7 +29,7 @@
delimiter: str = None,
g2p_type: str = None,
**kwargs,
-) -> AbsTokenizer:
+):
"""A helper function to instantiate Tokenizer"""
if token_type == "bpe":
if bpemodel is None:
--
Gitblit v1.9.1