游雁
2023-12-06 27f31cd42bb4e20dc19de0034fc0d80b449f1db1
funasr/tokenizer/abs_tokenizer.py
@@ -2,7 +2,13 @@
from abc import abstractmethod
from typing import Iterable
from typing import List
from pathlib import Path
from typing import Dict
from typing import Iterable
from typing import List
from typing import Union
import numpy as np
class AbsTokenizer(ABC):
    @abstractmethod
@@ -12,3 +18,70 @@
    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        raise NotImplementedError
class BaseTokenizer(ABC):
    def __init__(self, token_list: Union[Path, str, Iterable[str]],
                 unk_symbol: str = "<unk>",
                 **kwargs,
                 ):
        if isinstance(token_list, (Path, str)):
            token_list = Path(token_list)
            self.token_list_repr = str(token_list)
            self.token_list: List[str] = []
            with token_list.open("r", encoding="utf-8") as f:
                for idx, line in enumerate(f):
                    line = line.rstrip()
                    self.token_list.append(line)
        else:
            self.token_list: List[str] = list(token_list)
            self.token_list_repr = ""
            for i, t in enumerate(self.token_list):
                if i == 3:
                    break
                self.token_list_repr += f"{t}, "
            self.token_list_repr += f"... (NVocab={(len(self.token_list))})"
        self.token2id: Dict[str, int] = {}
        for i, t in enumerate(self.token_list):
            if t in self.token2id:
                raise RuntimeError(f'Symbol "{t}" is duplicated')
            self.token2id[t] = i
        self.unk_symbol = unk_symbol
        if self.unk_symbol not in self.token2id:
            raise RuntimeError(
                f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
            )
        self.unk_id = self.token2id[self.unk_symbol]
    def encode(self, text):
        tokens = self.text2tokens(text)
        text_ints = self.tokens2ids(tokens)
        return text_ints
    def decode(self, text_ints):
        return self.ids2tokens(text_ints)
    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)
    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]
    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        return [self.token2id.get(i, self.unk_id) for i in tokens]
    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        raise NotImplementedError
    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        raise NotImplementedError