From a05e753d11d9c36983ec4e58c421dbcf86d1dcd4 Mon Sep 17 00:00:00 2001
From: Xian Shi <40013335+R1ckShi@users.noreply.github.com>
Date: 星期二, 17 十月 2023 16:47:27 +0800
Subject: [PATCH] Merge branch 'main' into dev_onnx
---
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 26 ++++++++++++++++++++++----
1 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index c994036..71cf434 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -38,7 +38,13 @@
):
if not Path(model_dir).exists():
- from modelscope.hub.snapshot_download import snapshot_download
+ try:
+ from modelscope.hub.snapshot_download import snapshot_download
+ except:
+ raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
+ "\npip3 install -U modelscope\n" \
+ "For the users in China, you could install with the command:\n" \
+ "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
@@ -49,7 +55,13 @@
model_file = os.path.join(model_dir, 'model_quant.onnx')
if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx")
- from funasr.export.export_model import ModelExport
+ try:
+ from funasr.export.export_model import ModelExport
+ except:
+ raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
+ "\npip3 install -U funasr\n" \
+ "For the users in China, you could install with the command:\n" \
+ "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
export_model = ModelExport(
cache_dir=cache_dir,
onnx=True,
@@ -229,7 +241,6 @@
):
if not Path(model_dir).exists():
- from modelscope.hub.snapshot_download import snapshot_download
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
@@ -314,7 +325,14 @@
hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
# hotwords.append('<s>')
def word_map(word):
- return torch.tensor([self.vocab[i] for i in word])
+ hotwords = []
+ for c in word:
+ if c not in self.vocab.keys():
+ hotwords.append(8403)
+ logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
+ else:
+ hotwords.append(self.vocab[c])
+ return torch.tensor(hotwords)
hotword_int = [word_map(i) for i in hotwords]
# import pdb; pdb.set_trace()
hotword_int.append(torch.tensor([1]))
--
Gitblit v1.9.1