From e30a17cf4e715b3d139fa1e0ba01cda1bcf0f884 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Wed, 10 Jan 2024 11:23:41 +0800
Subject: [PATCH] update funasr-onnx
---
funasr/tokenizer/abs_tokenizer.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index d2fc3f0..d43d7b2 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -7,6 +7,7 @@
from typing import Iterable
from typing import List
from typing import Union
+import json
import numpy as np
@@ -27,7 +28,7 @@
):
if token_list is not None:
- if isinstance(token_list, (Path, str)):
+ if isinstance(token_list, (Path, str)) and token_list.endswith(".txt"):
token_list = Path(token_list)
self.token_list_repr = str(token_list)
self.token_list: List[str] = []
@@ -36,7 +37,15 @@
for idx, line in enumerate(f):
line = line.rstrip()
self.token_list.append(line)
-
+ elif isinstance(token_list, (Path, str)) and token_list.endswith(".json"):
+ token_list = Path(token_list)
+ self.token_list_repr = str(token_list)
+ self.token_list: List[str] = []
+
+ with open(token_list, 'r', encoding='utf-8') as f:
+ self.token_list = json.load(f)
+
+
else:
self.token_list: List[str] = list(token_list)
self.token_list_repr = ""
@@ -66,7 +75,9 @@
return text_ints
def decode(self, text_ints):
- return self.ids2tokens(text_ints)
+ token = self.ids2tokens(text_ints)
+ text = self.tokens2text(token)
+ return text
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
--
Gitblit v1.9.1