From 94de39dde2e616a01683c518023d0fab72b4e103 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 19 Feb 2024 22:21:50 +0800
Subject: [PATCH] aishell example
---
funasr/tokenizer/abs_tokenizer.py | 36 +++++++++++++++++++++---------------
1 file changed, 21 insertions(+), 15 deletions(-)
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index d2fc3f0..136be13 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -1,33 +1,29 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Iterable
-from typing import List
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Union
-
+import json
import numpy as np
+from abc import ABC
+from pathlib import Path
+from abc import abstractmethod
+from typing import Union, Iterable, List, Dict
+
class AbsTokenizer(ABC):
@abstractmethod
def text2tokens(self, line: str) -> List[str]:
raise NotImplementedError
-
+
@abstractmethod
def tokens2text(self, tokens: Iterable[str]) -> str:
raise NotImplementedError
class BaseTokenizer(ABC):
- def __init__(self, token_list: Union[Path, str, Iterable[str]]=None,
+ def __init__(self, token_list: Union[Path, str, Iterable[str]] = None,
unk_symbol: str = "<unk>",
**kwargs,
):
if token_list is not None:
- if isinstance(token_list, (Path, str)):
+ if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
token_list = Path(token_list)
self.token_list_repr = str(token_list)
self.token_list: List[str] = []
@@ -36,6 +32,14 @@
for idx, line in enumerate(f):
line = line.rstrip()
self.token_list.append(line)
+ elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
+ token_list = Path(token_list)
+ self.token_list_repr = str(token_list)
+ self.token_list: List[str] = []
+
+ with open(token_list, 'r', encoding='utf-8') as f:
+ self.token_list = json.load(f)
+
else:
self.token_list: List[str] = list(token_list)
@@ -66,7 +70,9 @@
return text_ints
def decode(self, text_ints):
- return self.ids2tokens(text_ints)
+ token = self.ids2tokens(text_ints)
+ text = self.tokens2text(token)
+ return text
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
@@ -85,4 +91,4 @@
@abstractmethod
def tokens2text(self, tokens: Iterable[str]) -> str:
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
--
Gitblit v1.9.1