From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/tokenizer/abs_tokenizer.py |   58 ++++++++++++++++++++++++++--------------------------------
 1 files changed, 26 insertions(+), 32 deletions(-)

diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index d43d7b2..e125d29 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -1,15 +1,10 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Iterable
-from typing import List
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Union
 import json
-
 import numpy as np
+from abc import ABC
+from pathlib import Path
+from abc import abstractmethod
+from typing import Union, Iterable, List, Dict
+
 
 class AbsTokenizer(ABC):
     @abstractmethod
@@ -22,17 +17,19 @@
 
 
 class BaseTokenizer(ABC):
-    def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
-                 unk_symbol: str = "<unk>",
-                 **kwargs,
-                 ):
-        
+    def __init__(
+        self,
+        token_list: Union[Path, str, Iterable[str]] = None,
+        unk_symbol: str = "<unk>",
+        **kwargs,
+    ):
+
         if token_list is not None:
             if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
                 token_list = Path(token_list)
                 self.token_list_repr = str(token_list)
                 self.token_list: List[str] = []
-                
+
                 with token_list.open("r", encoding="utf-8") as f:
                     for idx, line in enumerate(f):
                         line = line.rstrip()
@@ -42,9 +39,8 @@
                 self.token_list_repr = str(token_list)
                 self.token_list: List[str] = []
 
-                with open(token_list, 'r', encoding='utf-8') as f:
+                with open(token_list, "r", encoding="utf-8") as f:
                     self.token_list = json.load(f)
-                    
 
             else:
                 self.token_list: List[str] = list(token_list)
@@ -54,46 +50,44 @@
                         break
                     self.token_list_repr += f"{t}, "
                 self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
-            
+
             self.token2id: Dict[str, int] = {}
             for i, t in enumerate(self.token_list):
                 if t in self.token2id:
                     raise RuntimeError(f'Symbol "{t}" is duplicated')
                 self.token2id[t] = i
-            
+
             self.unk_symbol = unk_symbol
             if self.unk_symbol not in self.token2id:
-                raise RuntimeError(
-                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
-                )
+                raise RuntimeError(f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list")
             self.unk_id = self.token2id[self.unk_symbol]
-    
-    def encode(self, text):
+
+    def encode(self, text, **kwargs):
         tokens = self.text2tokens(text)
         text_ints = self.tokens2ids(tokens)
-        
+
         return text_ints
-    
+
     def decode(self, text_ints):
         token = self.ids2tokens(text_ints)
         text = self.tokens2text(token)
         return text
-    
+
     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
-    
+
     def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
         if isinstance(integers, np.ndarray) and integers.ndim != 1:
             raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
         return [self.token_list[i] for i in integers]
-    
+
     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
         return [self.token2id.get(i, self.unk_id) for i in tokens]
-    
+
     @abstractmethod
     def text2tokens(self, line: str) -> List[str]:
         raise NotImplementedError
-    
+
     @abstractmethod
     def tokens2text(self, tokens: Iterable[str]) -> str:
         raise NotImplementedError

--
Gitblit v1.9.1