From 8912e0696af069de47646fdb8a9d9c4e086e88b3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Sun, 14 Jan 2024 23:42:11 +0800
Subject: [PATCH] Resolve merge conflict

---
 funasr/tokenizer/abs_tokenizer.py |  165 +++++++++++++++++++++++++++---------------------------
 1 file changed, 83 insertions(+), 82 deletions(-)

diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index d43d7b2..548bf06 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -11,89 +11,90 @@
 
 import numpy as np
 
-class AbsTokenizer(ABC):
-    @abstractmethod
-    def text2tokens(self, line: str) -> List[str]:
-        raise NotImplementedError
 
-    @abstractmethod
-    def tokens2text(self, tokens: Iterable[str]) -> str:
-        raise NotImplementedError
+class AbsTokenizer(ABC):
+	@abstractmethod
+	def text2tokens(self, line: str) -> List[str]:
+		raise NotImplementedError
+	
+	@abstractmethod
+	def tokens2text(self, tokens: Iterable[str]) -> str:
+		raise NotImplementedError
 
 
 class BaseTokenizer(ABC):
-    def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
-                 unk_symbol: str = "<unk>",
-                 **kwargs,
-                 ):
-        
-        if token_list is not None:
-            if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
-                token_list = Path(token_list)
-                self.token_list_repr = str(token_list)
-                self.token_list: List[str] = []
-                
-                with token_list.open("r", encoding="utf-8") as f:
-                    for idx, line in enumerate(f):
-                        line = line.rstrip()
-                        self.token_list.append(line)
-            elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
-                token_list = Path(token_list)
-                self.token_list_repr = str(token_list)
-                self.token_list: List[str] = []
-
-                with open(token_list, 'r', encoding='utf-8') as f:
-                    self.token_list = json.load(f)
-                    
-
-            else:
-                self.token_list: List[str] = list(token_list)
-                self.token_list_repr = ""
-                for i, t in enumerate(self.token_list):
-                    if i == 3:
-                        break
-                    self.token_list_repr += f"{t}, "
-                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
-            
-            self.token2id: Dict[str, int] = {}
-            for i, t in enumerate(self.token_list):
-                if t in self.token2id:
-                    raise RuntimeError(f'Symbol "{t}" is duplicated')
-                self.token2id[t] = i
-            
-            self.unk_symbol = unk_symbol
-            if self.unk_symbol not in self.token2id:
-                raise RuntimeError(
-                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
-                )
-            self.unk_id = self.token2id[self.unk_symbol]
-    
-    def encode(self, text):
-        tokens = self.text2tokens(text)
-        text_ints = self.tokens2ids(tokens)
-        
-        return text_ints
-    
-    def decode(self, text_ints):
-        token = self.ids2tokens(text_ints)
-        text = self.tokens2text(token)
-        return text
-    
-    def get_num_vocabulary_size(self) -> int:
-        return len(self.token_list)
-    
-    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
-        if isinstance(integers, np.ndarray) and integers.ndim != 1:
-            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
-        return [self.token_list[i] for i in integers]
-    
-    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
-        return [self.token2id.get(i, self.unk_id) for i in tokens]
-    
-    @abstractmethod
-    def text2tokens(self, line: str) -> List[str]:
-        raise NotImplementedError
-    
-    @abstractmethod
-    def tokens2text(self, tokens: Iterable[str]) -> str:
-        raise NotImplementedError
+	def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
+	             unk_symbol: str = "<unk>",
+	             **kwargs,
+	             ):
+		
+		if token_list is not None:
+			if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
+				token_list = Path(token_list)
+				self.token_list_repr = str(token_list)
+				self.token_list: List[str] = []
+				
+				with token_list.open("r", encoding="utf-8") as f:
+					for idx, line in enumerate(f):
+						line = line.rstrip()
+						self.token_list.append(line)
+			elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
+				token_list = Path(token_list)
+				self.token_list_repr = str(token_list)
+				self.token_list: List[str] = []
+				
+				with open(token_list, 'r', encoding='utf-8') as f:
+					self.token_list = json.load(f)
+			
+			
+			else:
+				self.token_list: List[str] = list(token_list)
+				self.token_list_repr = ""
+				for i, t in enumerate(self.token_list):
+					if i == 3:
+						break
+					self.token_list_repr += f"{t}, "
+				self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
+			
+			self.token2id: Dict[str, int] = {}
+			for i, t in enumerate(self.token_list):
+				if t in self.token2id:
+					raise RuntimeError(f'Symbol "{t}" is duplicated')
+				self.token2id[t] = i
+			
+			self.unk_symbol = unk_symbol
+			if self.unk_symbol not in self.token2id:
+				raise RuntimeError(
+					f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
+				)
+			self.unk_id = self.token2id[self.unk_symbol]
+	
+	def encode(self, text):
+		tokens = self.text2tokens(text)
+		text_ints = self.tokens2ids(tokens)
+		
+		return text_ints
+	
+	def decode(self, text_ints):
+		token = self.ids2tokens(text_ints)
+		text = self.tokens2text(token)
+		return text
+	
+	def get_num_vocabulary_size(self) -> int:
+		return len(self.token_list)
+	
+	def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+		if isinstance(integers, np.ndarray) and integers.ndim != 1:
+			raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
+		return [self.token_list[i] for i in integers]
+	
+	def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
+		return [self.token2id.get(i, self.unk_id) for i in tokens]
+	
+	@abstractmethod
+	def text2tokens(self, line: str) -> List[str]:
+		raise NotImplementedError
+	
+	@abstractmethod
+	def tokens2text(self, tokens: Iterable[str]) -> str:
+		raise NotImplementedError

--
Gitblit v1.9.1