| funasr/runtime/python/libtorch/funasr_torch/utils/utils.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 | |
| funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -23,9 +23,11 @@ ): check_argument_types() # self.token_list = self.load_token(token_path) self.token_list = token_list self.unk_symbol = token_list[-1] self.token2id = {v: i for i, v in enumerate(self.token_list)} self.unk_id = self.token2id[self.unk_symbol] def get_num_vocabulary_size(self) -> int: return len(self.token_list) @@ -38,13 +40,8 @@ return [self.token_list[i] for i in integers] def tokens2ids(self, tokens: Iterable[str]) -> List[int]: token2id = {v: i for i, v in enumerate(self.token_list)} if self.unk_symbol not in token2id: raise TokenIDConverterError( f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list" ) unk_id = token2id[self.unk_symbol] return [token2id.get(i, unk_id) for i in tokens] return [self.token2id.get(i, self.unk_id) for i in tokens] class CharTokenizer(): funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -24,21 +24,11 @@ ): check_argument_types() # self.token_list = self.load_token(token_path) self.token_list = token_list self.unk_symbol = token_list[-1] self.token2id = {v: i for i, v in enumerate(self.token_list)} self.unk_id = self.token2id[self.unk_symbol] # @staticmethod # def load_token(file_path: Union[Path, str]) -> List: # if not Path(file_path).exists(): # raise TokenIDConverterError(f'The {file_path} does not exist.') # # with open(str(file_path), 'rb') as f: # token_list = pickle.load(f) # # if len(token_list) != len(set(token_list)): # raise TokenIDConverterError('The Token exists duplicated symbol.') # return token_list def get_num_vocabulary_size(self) -> int: return len(self.token_list) @@ -51,13 +41,8 @@ return [self.token_list[i] for i in integers] def tokens2ids(self, tokens: Iterable[str]) -> List[int]: token2id = {v: i for i, v in enumerate(self.token_list)} if self.unk_symbol not in token2id: raise TokenIDConverterError( f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list" ) unk_id = token2id[self.unk_symbol] return [token2id.get(i, unk_id) for i in tokens] return [self.token2id.get(i, self.unk_id) for i in tokens] class CharTokenizer():