From 27f31cd42bb4e20dc19de0034fc0d80b449f1db1 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 06 Dec 2023 17:01:12 +0800
Subject: [PATCH] funasr2

---
 funasr/tokenizer/abs_tokenizer.py |   73 ++++++++++++++++++++++++++++++++++++
 1 file changed, 73 insertions(+), 0 deletions(-)

diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index fc2ccb3..ffb6b76 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -2,7 +2,11 @@
 from abc import abstractmethod
 from typing import Iterable
 from typing import List
+from pathlib import Path
+from typing import Dict
+from typing import Union
 
+import numpy as np
 
 class AbsTokenizer(ABC):
     @abstractmethod
@@ -12,3 +18,70 @@
     @abstractmethod
     def tokens2text(self, tokens: Iterable[str]) -> str:
         raise NotImplementedError
+
+
class BaseTokenizer(ABC):
    """Abstract base for vocabulary-backed tokenizers.

    Holds a fixed token list and provides the mapping between tokens and
    integer ids.  Subclasses supply the text<->token conversion via
    ``text2tokens`` / ``tokens2text``.
    """

    def __init__(self, token_list: Union[Path, str, Iterable[str]],
                 unk_symbol: str = "<unk>",
                 **kwargs,
                 ):
        """Build the token<->id lookup tables.

        Args:
            token_list: Path to a vocabulary file with one token per line,
                or an in-memory iterable of tokens.
            unk_symbol: Symbol substituted for out-of-vocabulary tokens;
                it must be present in ``token_list``.
            **kwargs: Ignored; accepted for subclass/config compatibility.

        Raises:
            RuntimeError: If a token appears twice in the vocabulary, or
                ``unk_symbol`` is not in the vocabulary.
        """
        if isinstance(token_list, (Path, str)):
            path = Path(token_list)
            self.token_list_repr = str(path)
            with path.open("r", encoding="utf-8") as f:
                # One token per line; rstrip() drops the newline (and any
                # trailing whitespace) from each entry.
                self.token_list: List[str] = [line.rstrip() for line in f]
        else:
            self.token_list = list(token_list)
            # Preview repr: at most the first 3 tokens, then the vocab size.
            preview = "".join(f"{t}, " for t in self.token_list[:3])
            self.token_list_repr = preview + f"... (NVocab={len(self.token_list)})"

        self.token2id: Dict[str, int] = {}
        for i, t in enumerate(self.token_list):
            if t in self.token2id:
                raise RuntimeError(f'Symbol "{t}" is duplicated')
            self.token2id[t] = i

        self.unk_symbol = unk_symbol
        if self.unk_symbol not in self.token2id:
            raise RuntimeError(
                f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list"
            )
        # Id used as the fallback for unknown tokens in tokens2ids().
        self.unk_id = self.token2id[self.unk_symbol]

    def encode(self, text: str) -> List[int]:
        """Convert raw text into a list of token ids."""
        tokens = self.text2tokens(text)
        return self.tokens2ids(tokens)

    def decode(self, text_ints: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Convert a sequence of token ids back into tokens."""
        return self.ids2tokens(text_ints)

    def get_num_vocabulary_size(self) -> int:
        """Return the number of entries in the vocabulary."""
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        """Map integer ids to tokens.

        Raises:
            ValueError: If ``integers`` is an ndarray that is not 1-D.
            IndexError: If an id is outside the vocabulary range.
        """
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        """Map tokens to ids; unknown tokens fall back to ``unk_id``."""
        return [self.token2id.get(i, self.unk_id) for i in tokens]

    @abstractmethod
    def text2tokens(self, line: str) -> List[str]:
        raise NotImplementedError

    @abstractmethod
    def tokens2text(self, tokens: Iterable[str]) -> str:
        raise NotImplementedError

--
Gitblit v1.9.1