游雁
2023-12-07 fc246ab820cf57ba08afbe3cbeb4d471036eb83c
funasr2
2个文件已修改
63 ■■■■ 已修改文件
funasr/tokenizer/abs_tokenizer.py 61 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/build_tokenizer.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tokenizer/abs_tokenizer.py
@@ -21,42 +21,43 @@
class BaseTokenizer(ABC):
    """Base class for tokenizers backed by an explicit token (vocabulary) list.

    NOTE(review): this block was reconstructed from a garbled diff view in
    which the pre- and post-patch versions of ``__init__`` were interleaved;
    the code below is the post-patch version, where ``token_list`` became
    optional (default ``None``).
    """

    def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
                 unk_symbol: str = "<unk>",
                 **kwargs,
                 ):
        """Build the token<->id maps from a vocabulary file or iterable.

        Args:
            token_list: Path to a vocabulary file (one token per line) or an
                iterable of tokens. When ``None``, no vocabulary is loaded and
                none of the mapping attributes (``token_list``, ``token2id``,
                ``unk_id`` …) are set — presumably a subclass initializes them
                later; TODO confirm against callers.
            unk_symbol: Symbol for unknown tokens; must exist in the
                vocabulary when one is given.
            **kwargs: Accepted for config compatibility; ignored here.

        Raises:
            RuntimeError: If the vocabulary contains a duplicated symbol, or
                does not contain ``unk_symbol`` (only when ``token_list`` is
                not ``None``).
        """
        if token_list is not None:
            if isinstance(token_list, (Path, str)):
                # Vocabulary given as a file: one token per line.
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                self.token_list: List[str] = []
                with token_list.open("r", encoding="utf-8") as f:
                    for line in f:
                        # rstrip() drops the newline (and trailing whitespace).
                        self.token_list.append(line.rstrip())
            else:
                # Vocabulary given directly as an iterable of tokens.
                self.token_list: List[str] = list(token_list)
                # Compact repr: first three tokens plus the vocabulary size.
                self.token_list_repr = ""
                for i, t in enumerate(self.token_list):
                    if i == 3:
                        break
                    self.token_list_repr += f"{t}, "
                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"

            # token -> id, rejecting duplicate symbols outright.
            self.token2id: Dict[str, int] = {}
            for i, t in enumerate(self.token_list):
                if t in self.token2id:
                    raise RuntimeError(f'Symbol "{t}" is duplicated')
                self.token2id[t] = i

            self.unk_symbol = unk_symbol
            if self.unk_symbol not in self.token2id:
                raise RuntimeError(
                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
                )
            self.unk_id = self.token2id[self.unk_symbol]
    
    def encode(self, text):
        tokens = self.text2tokens(text)
funasr/tokenizer/build_tokenizer.py
@@ -29,7 +29,7 @@
    delimiter: str = None,
    g2p_type: str = None,
    **kwargs,
) -> AbsTokenizer:
):
    """A helper function to instantiate Tokenizer"""
    if token_type == "bpe":
        if bpemodel is None: