游雁
2024-01-14 8912e0696af069de47646fdb8a9d9c4e086e88b3
funasr/tokenizer/abs_tokenizer.py
@@ -11,89 +11,90 @@
import numpy as np
class AbsTokenizer(ABC):
    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        raise NotImplementedError
    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        raise NotImplementedError
class AbsTokenizer(ABC):
   @abstractmethod
   def text2tokens(self, line: str) -> List[str]:
      raise NotImplementedError
   @abstractmethod
   def tokens2text(self, tokens: Iterable[str]) -> str:
      raise NotImplementedError
class BaseTokenizer(ABC):
    def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
                 unk_symbol: str = "<unk>",
                 **kwargs,
                 ):
        if token_list is not None:
            if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                self.token_list: List[str] = []
                with token_list.open("r", encoding="utf-8") as f:
                    for idx, line in enumerate(f):
                        line = line.rstrip()
                        self.token_list.append(line)
            elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
                token_list = Path(token_list)
                self.token_list_repr = str(token_list)
                self.token_list: List[str] = []
                with open(token_list, 'r', encoding='utf-8') as f:
                    self.token_list = json.load(f)
            else:
                self.token_list: List[str] = list(token_list)
                self.token_list_repr = ""
                for i, t in enumerate(self.token_list):
                    if i == 3:
                        break
                    self.token_list_repr += f"{t}, "
                self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
            self.token2id: Dict[str, int] = {}
            for i, t in enumerate(self.token_list):
                if t in self.token2id:
                    raise RuntimeError(f'Symbol "{t}" is duplicated')
                self.token2id[t] = i
            self.unk_symbol = unk_symbol
            if self.unk_symbol not in self.token2id:
                raise RuntimeError(
                    f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
                )
            self.unk_id = self.token2id[self.unk_symbol]
    def encode(self, text):
        tokens = self.text2tokens(text)
        text_ints = self.tokens2ids(tokens)
        return text_ints
    def decode(self, text_ints):
        token = self.ids2tokens(text_ints)
        text = self.tokens2text(token)
        return text
    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)
    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]
    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        return [self.token2id.get(i, self.unk_id) for i in tokens]
    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        raise NotImplementedError
    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        raise NotImplementedError
   def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
                unk_symbol: str = "<unk>",
                **kwargs,
                ):
      if token_list is not None:
         if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
            token_list = Path(token_list)
            self.token_list_repr = str(token_list)
            self.token_list: List[str] = []
            with token_list.open("r", encoding="utf-8") as f:
               for idx, line in enumerate(f):
                  line = line.rstrip()
                  self.token_list.append(line)
         elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
            token_list = Path(token_list)
            self.token_list_repr = str(token_list)
            self.token_list: List[str] = []
            with open(token_list, 'r', encoding='utf-8') as f:
               self.token_list = json.load(f)
         else:
            self.token_list: List[str] = list(token_list)
            self.token_list_repr = ""
            for i, t in enumerate(self.token_list):
               if i == 3:
                  break
               self.token_list_repr += f"{t}, "
            self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
         self.token2id: Dict[str, int] = {}
         for i, t in enumerate(self.token_list):
            if t in self.token2id:
               raise RuntimeError(f'Symbol "{t}" is duplicated')
            self.token2id[t] = i
         self.unk_symbol = unk_symbol
         if self.unk_symbol not in self.token2id:
            raise RuntimeError(
               f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
            )
         self.unk_id = self.token2id[self.unk_symbol]
   def encode(self, text):
      tokens = self.text2tokens(text)
      text_ints = self.tokens2ids(tokens)
      return text_ints
   def decode(self, text_ints):
      token = self.ids2tokens(text_ints)
      text = self.tokens2text(token)
      return text
   def get_num_vocabulary_size(self) -> int:
      return len(self.token_list)
   def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
      if isinstance(integers, np.ndarray) and integers.ndim != 1:
         raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
      return [self.token_list[i] for i in integers]
   def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
      return [self.token2id.get(i, self.unk_id) for i in tokens]
   @abstractmethod
   def text2tokens(self, line: str) -> List[str]:
      raise NotImplementedError
   @abstractmethod
   def tokens2text(self, tokens: Iterable[str]) -> str:
      raise NotImplementedError