From 8912e0696af069de47646fdb8a9d9c4e086e88b3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 14 一月 2024 23:42:11 +0800
Subject: [PATCH] Resolve merge conflict
---
funasr/tokenizer/abs_tokenizer.py | 165 +++++++++++++++++++++++++++---------------------------
1 files changed, 83 insertions(+), 82 deletions(-)
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index d43d7b2..548bf06 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -11,89 +11,90 @@
import numpy as np
-class AbsTokenizer(ABC):
- @abstractmethod
- def text2tokens(self, line: str) -> List[str]:
- raise NotImplementedError
- @abstractmethod
- def tokens2text(self, tokens: Iterable[str]) -> str:
- raise NotImplementedError
+class AbsTokenizer(ABC):
+ @abstractmethod
+ def text2tokens(self, line: str) -> List[str]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def tokens2text(self, tokens: Iterable[str]) -> str:
+ raise NotImplementedError
class BaseTokenizer(ABC):
- def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
- unk_symbol: str = "<unk>",
- **kwargs,
- ):
-
- if token_list is not None:
- if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
- token_list = Path(token_list)
- self.token_list_repr = str(token_list)
- self.token_list: List[str] = []
-
- with token_list.open("r", encoding="utf-8") as f:
- for idx, line in enumerate(f):
- line = line.rstrip()
- self.token_list.append(line)
- elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
- token_list = Path(token_list)
- self.token_list_repr = str(token_list)
- self.token_list: List[str] = []
-
- with open(token_list, 'r', encoding='utf-8') as f:
- self.token_list = json.load(f)
-
-
- else:
- self.token_list: List[str] = list(token_list)
- self.token_list_repr = ""
- for i, t in enumerate(self.token_list):
- if i == 3:
- break
- self.token_list_repr += f"{t}, "
- self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
-
- self.token2id: Dict[str, int] = {}
- for i, t in enumerate(self.token_list):
- if t in self.token2id:
- raise RuntimeError(f'Symbol "{t}" is duplicated')
- self.token2id[t] = i
-
- self.unk_symbol = unk_symbol
- if self.unk_symbol not in self.token2id:
- raise RuntimeError(
- f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
- )
- self.unk_id = self.token2id[self.unk_symbol]
-
- def encode(self, text):
- tokens = self.text2tokens(text)
- text_ints = self.tokens2ids(tokens)
-
- return text_ints
-
- def decode(self, text_ints):
- token = self.ids2tokens(text_ints)
- text = self.tokens2text(token)
- return text
-
- def get_num_vocabulary_size(self) -> int:
- return len(self.token_list)
-
- def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
- if isinstance(integers, np.ndarray) and integers.ndim != 1:
- raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
- return [self.token_list[i] for i in integers]
-
- def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
- return [self.token2id.get(i, self.unk_id) for i in tokens]
-
- @abstractmethod
- def text2tokens(self, line: str) -> List[str]:
- raise NotImplementedError
-
- @abstractmethod
- def tokens2text(self, tokens: Iterable[str]) -> str:
- raise NotImplementedError
+ def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
+ unk_symbol: str = "<unk>",
+ **kwargs,
+ ):
+
+ if token_list is not None:
+ if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
+ token_list = Path(token_list)
+ self.token_list_repr = str(token_list)
+ self.token_list: List[str] = []
+
+ with token_list.open("r", encoding="utf-8") as f:
+ for idx, line in enumerate(f):
+ line = line.rstrip()
+ self.token_list.append(line)
+ elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
+ token_list = Path(token_list)
+ self.token_list_repr = str(token_list)
+ self.token_list: List[str] = []
+
+ with open(token_list, 'r', encoding='utf-8') as f:
+ self.token_list = json.load(f)
+
+
+ else:
+ self.token_list: List[str] = list(token_list)
+ self.token_list_repr = ""
+ for i, t in enumerate(self.token_list):
+ if i == 3:
+ break
+ self.token_list_repr += f"{t}, "
+ self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
+
+ self.token2id: Dict[str, int] = {}
+ for i, t in enumerate(self.token_list):
+ if t in self.token2id:
+ raise RuntimeError(f'Symbol "{t}" is duplicated')
+ self.token2id[t] = i
+
+ self.unk_symbol = unk_symbol
+ if self.unk_symbol not in self.token2id:
+ raise RuntimeError(
+ f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
+ )
+ self.unk_id = self.token2id[self.unk_symbol]
+
+ def encode(self, text):
+ tokens = self.text2tokens(text)
+ text_ints = self.tokens2ids(tokens)
+
+ return text_ints
+
+ def decode(self, text_ints):
+ token = self.ids2tokens(text_ints)
+ text = self.tokens2text(token)
+ return text
+
+ def get_num_vocabulary_size(self) -> int:
+ return len(self.token_list)
+
+ def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+ if isinstance(integers, np.ndarray) and integers.ndim != 1:
+ raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
+ return [self.token_list[i] for i in integers]
+
+ def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
+ return [self.token2id.get(i, self.unk_id) for i in tokens]
+
+ @abstractmethod
+ def text2tokens(self, line: str) -> List[str]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def tokens2text(self, tokens: Iterable[str]) -> str:
+ raise NotImplementedError
\ No newline at end of file
--
Gitblit v1.9.1