From 40eefbe376bd2e846bd9c451636f4e42ba1bf8bf Mon Sep 17 00:00:00 2001
From: zhuzizyf <42790740+zhuzizyf@users.noreply.github.com>
Date: 星期三, 12 四月 2023 17:18:19 +0800
Subject: [PATCH] Merge pull request #1 from zhuzizyf/fix-dataset-bug

---
 funasr/runtime/python/libtorch/funasr_torch/utils/utils.py |   13 +++++--------
 1 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py b/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
index cafc43b..86e78bc 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -23,9 +23,11 @@
                  ):
         check_argument_types()
 
-        # self.token_list = self.load_token(token_path)
         self.token_list = token_list
         self.unk_symbol = token_list[-1]
+        self.token2id = {v: i for i, v in enumerate(self.token_list)}
+        self.unk_id = self.token2id[self.unk_symbol]
+
 
     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
@@ -38,13 +40,8 @@
         return [self.token_list[i] for i in integers]
 
     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
-        token2id = {v: i for i, v in enumerate(self.token_list)}
-        if self.unk_symbol not in token2id:
-            raise TokenIDConverterError(
-                f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list"
-            )
-        unk_id = token2id[self.unk_symbol]
-        return [token2id.get(i, unk_id) for i in tokens]
+
+        return [self.token2id.get(i, self.unk_id) for i in tokens]
 
 
 class CharTokenizer():

--
Gitblit v1.9.1