| | |
| | | |
| | | |
| | | class BaseTokenizer(ABC): |
| | | def __init__(self, token_list: Union[Path, str, Iterable[str]], |
| | | def __init__(self, token_list: Union[Path, str, Iterable[str]]=None, |
| | | unk_symbol: str = "<unk>", |
| | | **kwargs, |
| | | ): |
| | | |
| | | if isinstance(token_list, (Path, str)): |
| | | token_list = Path(token_list) |
| | | self.token_list_repr = str(token_list) |
| | | self.token_list: List[str] = [] |
| | | if token_list is not None: |
| | | if isinstance(token_list, (Path, str)): |
| | | token_list = Path(token_list) |
| | | self.token_list_repr = str(token_list) |
| | | self.token_list: List[str] = [] |
| | | |
| | | with token_list.open("r", encoding="utf-8") as f: |
| | | for idx, line in enumerate(f): |
| | | line = line.rstrip() |
| | | self.token_list.append(line) |
| | | |
| | | with token_list.open("r", encoding="utf-8") as f: |
| | | for idx, line in enumerate(f): |
| | | line = line.rstrip() |
| | | self.token_list.append(line) |
| | | |
| | | else: |
| | | self.token_list: List[str] = list(token_list) |
| | | self.token_list_repr = "" |
| | | else: |
| | | self.token_list: List[str] = list(token_list) |
| | | self.token_list_repr = "" |
| | | for i, t in enumerate(self.token_list): |
| | | if i == 3: |
| | | break |
| | | self.token_list_repr += f"{t}, " |
| | | self.token_list_repr += f"... (NVocab={(len(self.token_list))})" |
| | | |
| | | self.token2id: Dict[str, int] = {} |
| | | for i, t in enumerate(self.token_list): |
| | | if i == 3: |
| | | break |
| | | self.token_list_repr += f"{t}, " |
| | | self.token_list_repr += f"... (NVocab={(len(self.token_list))})" |
| | | |
| | | self.token2id: Dict[str, int] = {} |
| | | for i, t in enumerate(self.token_list): |
| | | if t in self.token2id: |
| | | raise RuntimeError(f'Symbol "{t}" is duplicated') |
| | | self.token2id[t] = i |
| | | |
| | | self.unk_symbol = unk_symbol |
| | | if self.unk_symbol not in self.token2id: |
| | | raise RuntimeError( |
| | | f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list" |
| | | ) |
| | | self.unk_id = self.token2id[self.unk_symbol] |
| | | if t in self.token2id: |
| | | raise RuntimeError(f'Symbol "{t}" is duplicated') |
| | | self.token2id[t] = i |
| | | |
| | | self.unk_symbol = unk_symbol |
| | | if self.unk_symbol not in self.token2id: |
| | | raise RuntimeError( |
| | | f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list" |
| | | ) |
| | | self.unk_id = self.token2id[self.unk_symbol] |
| | | |
| | | def encode(self, text): |
| | | tokens = self.text2tokens(text) |