From b7060884fa4b8b85f79462644a5c99062d223da0 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Tue, 25 Jun 2024 17:38:04 +0800
Subject: [PATCH] Merge Dev tclas (#1847)
---
funasr/download/runtime_sdk_download_tool.py | 12 +
runtime/websocket/bin/funasr-wss-server.cpp | 53 +++++---
examples/industrial_data_pretraining/llm_asr/demo_speech2text.sh | 1
runtime/onnxruntime/src/util.cpp | 9 +
runtime/onnxruntime/include/com-define.h | 7
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 2
runtime/python/libtorch/funasr_torch/paraformer_bin.py | 16 +-
examples/industrial_data_pretraining/paraformer/export.py | 2
examples/industrial_data_pretraining/bicif_paraformer/export.py | 2
runtime/onnxruntime/src/offline-stream.cpp | 35 ++---
runtime/onnxruntime/src/paraformer-torch.h | 1
funasr/utils/export_utils.py | 6
runtime/onnxruntime/src/paraformer-torch.cpp | 211 ++++++++++++++++++++++++++++++++---
runtime/python/libtorch/README.md | 2
14 files changed, 283 insertions(+), 76 deletions(-)
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.py b/examples/industrial_data_pretraining/bicif_paraformer/export.py
index 44849b0..e4eb382 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/export.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -12,7 +12,7 @@
device="cpu",
)
-res = model.export(type="torchscripts", quantize=False)
+res = model.export(type="torchscript", quantize=False)
print(res)
diff --git a/examples/industrial_data_pretraining/llm_asr/demo_speech2text.sh b/examples/industrial_data_pretraining/llm_asr/demo_speech2text.sh
index 57299fc..3eba6d3 100644
--- a/examples/industrial_data_pretraining/llm_asr/demo_speech2text.sh
+++ b/examples/industrial_data_pretraining/llm_asr/demo_speech2text.sh
@@ -62,4 +62,3 @@
}&
done
-wait
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/export.py b/examples/industrial_data_pretraining/paraformer/export.py
index a91e9e4..6334e3b 100644
--- a/examples/industrial_data_pretraining/paraformer/export.py
+++ b/examples/industrial_data_pretraining/paraformer/export.py
@@ -13,7 +13,7 @@
model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
)
-res = model.export(type="torchscripts", quantize=False)
+res = model.export(type="torchscript", quantize=False)
# res = model.export(type="bladedisc", input=f"{model.model_path}/example/asr_example.wav")
print(res)
diff --git a/funasr/download/runtime_sdk_download_tool.py b/funasr/download/runtime_sdk_download_tool.py
index 96c6735..0db17c7 100644
--- a/funasr/download/runtime_sdk_download_tool.py
+++ b/funasr/download/runtime_sdk_download_tool.py
@@ -10,7 +10,7 @@
parser.add_argument("--model-name", type=str, required=True)
parser.add_argument("--export-dir", type=str, required=True)
parser.add_argument("--export", type=str2bool, default=True, help="whether to export model")
- parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torch"]')
+ parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torchscript", "bladedisc"]')
parser.add_argument("--device", type=str, default="cpu", help='["cpu", "cuda"]')
parser.add_argument("--quantize", type=str2bool, default=False, help="export quantized model")
parser.add_argument("--fallback-num", type=int, default=0, help="amp fallback number")
@@ -37,11 +37,17 @@
model_file = os.path.join(model_dir, "model.onnx")
if args.quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
+ if args.type == "torchscript":
+ model_file = os.path.join(model_dir, "model.torchscript")
+ args.device = "cuda"
+ elif args.type == "bladedisc":
+ model_file = os.path.join(model_dir, "model_blade.torchscript")
+ args.device = "cuda"
if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
+        print("model does not exist, begin to export " + model_file)
from funasr import AutoModel
- export_model = AutoModel(model=args.model_name, output_dir=output_dir)
+ export_model = AutoModel(model=args.model_name, output_dir=output_dir, device=args.device)
export_model.export(
quantize=args.quantize,
type=args.type,
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index 72b150f..a6d0798 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -23,7 +23,7 @@
export_dir=export_dir,
**kwargs,
)
- elif type == "torchscripts":
+ elif type == "torchscript":
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Exporting torchscripts on device {}".format(device))
_torchscripts(m, path=export_dir, device=device)
@@ -100,7 +100,7 @@
dummy_input = tuple([i.cuda() for i in dummy_input])
model_script = torch.jit.trace(model, dummy_input)
- model_script.save(os.path.join(path, f"{model.export_name}.torchscripts"))
+ model_script.save(os.path.join(path, f"{model.export_name}.torchscript"))
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
@@ -193,4 +193,4 @@
model.encoder = _bladedisc_opt(model.encoder, input_data[:2])
model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs))
model_script = torch.jit.trace(model, input_data)
- model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscripts"))
+ model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index c252bc7..d8d9473 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -52,7 +52,7 @@
std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_handle, nn_hotwords_);
// warm up
- for (size_t i = 0; i < 10; i++)
+ for (size_t i = 0; i < 1; i++)
{
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs, true, decoder_handle);
if(result){
diff --git a/runtime/onnxruntime/include/com-define.h b/runtime/onnxruntime/include/com-define.h
index 77a7b02..5f71f7b 100644
--- a/runtime/onnxruntime/include/com-define.h
+++ b/runtime/onnxruntime/include/com-define.h
@@ -48,6 +48,7 @@
#define MODEL_NAME "model.onnx"
// hotword embedding compile model
#define MODEL_EB_NAME "model_eb.onnx"
+#define TORCH_MODEL_EB_NAME "model_eb.torchscript"
#define QUANT_MODEL_NAME "model_quant.onnx"
#define VAD_CMVN_NAME "am.mvn"
#define VAD_CONFIG_NAME "config.yaml"
@@ -55,9 +56,9 @@
// gpu models
#define INFER_GPU "gpu"
#define BATCHSIZE "batch-size"
-#define TORCH_MODEL_NAME "model.torchscripts"
-#define TORCH_QUANT_MODEL_NAME "model_quant.torchscripts"
-#define BLADE_MODEL_NAME "model.blade.fp16.pt"
+#define TORCH_MODEL_NAME "model.torchscript"
+#define TORCH_QUANT_MODEL_NAME "model_quant.torchscript"
+#define BLADE_MODEL_NAME "model_blade.torchscript"
#define BLADEDISC "bladedisc"
#define AM_CMVN_NAME "am.mvn"
diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp
index 35eb1ba..166d3c9 100644
--- a/runtime/onnxruntime/src/offline-stream.cpp
+++ b/runtime/onnxruntime/src/offline-stream.cpp
@@ -33,7 +33,8 @@
string am_cmvn_path;
string am_config_path;
string token_path;
- string hw_compile_model_path;
+ string hw_cpu_model_path;
+ string hw_gpu_model_path;
string seg_dict_path;
if(use_gpu){
@@ -50,33 +51,31 @@
}
bool enable_hotword = false;
- hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
+ hw_cpu_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
+ hw_gpu_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_MODEL_EB_NAME);
seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT);
- if (access(hw_compile_model_path.c_str(), F_OK) == 0) { // if model_eb.onnx exist, hotword enabled
+ if (access(hw_cpu_model_path.c_str(), F_OK) == 0) { // if model_eb.onnx exist, hotword enabled
enable_hotword = true;
- asr_handle->InitHwCompiler(hw_compile_model_path, thread_num);
+ asr_handle->InitHwCompiler(hw_cpu_model_path, thread_num);
asr_handle->InitSegDict(seg_dict_path);
}
- if (enable_hotword) {
- am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
- if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
+ if (use_gpu && access(hw_gpu_model_path.c_str(), F_OK) == 0) { // if model_eb.torchscript exist, hotword enabled
+ enable_hotword = true;
+ asr_handle->InitHwCompiler(hw_gpu_model_path, thread_num);
+ asr_handle->InitSegDict(seg_dict_path);
+ }
+
+ am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
+ if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
- }
- } else {
- am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
- if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
- am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
- }
- if(use_gpu){
+ }
+ if(use_gpu){
am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_MODEL_NAME);
- if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
- am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_QUANT_MODEL_NAME);
- }
if(model_path.find(BLADEDISC) != model_path.end() && model_path.at(BLADEDISC) == "true"){
am_model_path = PathAppend(model_path.at(MODEL_DIR), BLADE_MODEL_NAME);
}
- }
}
+
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
index a5f7194..466d80a 100644
--- a/runtime/onnxruntime/src/paraformer-torch.cpp
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -50,6 +50,11 @@
torch::jit::script::Module model = torch::jit::load(am_model, device);
model_ = std::make_shared<TorchModule>(std::move(model));
LOG(INFO) << "Successfully load model from " << am_model;
+ torch::NoGradGuard no_grad;
+ model_->eval();
+ torch::jit::setGraphExecutorOptimize(false);
+ torch::jit::FusionStrategy static0 = {{torch::jit::FusionBehavior::STATIC, 0}};
+ torch::jit::setFusionStrategy(static0);
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load am model: " << am_model << e.what();
exit(-1);
@@ -100,6 +105,27 @@
void ParaformerTorch::InitHwCompiler(const std::string &hw_model, int thread_num) {
// TODO
+ torch::DeviceType device = at::kCPU;
+ #ifdef USE_GPU
+ if (!torch::cuda::is_available()) {
+ // LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
+ exit(-1);
+ } else {
+ // LOG(INFO) << "CUDA is available, running on GPU";
+ device = at::kCUDA;
+ }
+ #endif
+
+ try {
+ torch::jit::script::Module model = torch::jit::load(hw_model, device);
+ hw_model_ = std::make_shared<TorchModule>(std::move(model));
+ LOG(INFO) << "Successfully load model from " << hw_model;
+ torch::NoGradGuard no_grad;
+ hw_model_->eval();
+ } catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load hw model: " << hw_model << e.what();
+ exit(-1);
+ }
use_hotword = true;
}
@@ -111,15 +137,19 @@
{
if(vocab){
delete vocab;
+ vocab = nullptr;
}
if(lm_vocab){
delete lm_vocab;
+ lm_vocab = nullptr;
}
if(seg_dict){
delete seg_dict;
+ seg_dict = nullptr;
}
if(phone_set_){
delete phone_set_;
+ phone_set_ = nullptr;
}
}
@@ -267,6 +297,9 @@
std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
+ vector<std::string> results;
+ string result="";
+
WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
int32_t feature_dim = lfr_m*in_feat_dim;
@@ -295,8 +328,13 @@
feats_batch.emplace_back(flattened);
}
- torch::NoGradGuard no_grad;
- model_->eval();
+ if(max_frames == 0){
+ for(int index=0; index<batch_in; index++){
+ results.push_back(result);
+ }
+ return results;
+ }
+
// padding
std::vector<float> all_feats(batch_in * max_frames * feature_dim);
for(int index=0; index<batch_in; index++){
@@ -317,8 +355,52 @@
#endif
std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
- vector<std::string> results;
+ std::vector<float> batch_embedding;
+ std::vector<float> embedding;
+ try{
+ if (use_hotword) {
+ if(hw_emb.size()<=0){
+ LOG(ERROR) << "hw_emb is null";
+ for(int index=0; index<batch_in; index++){
+ results.push_back(result);
+ }
+ return results;
+ }
+
+ embedding.reserve(hw_emb.size() * hw_emb[0].size());
+ for (auto item : hw_emb) {
+ embedding.insert(embedding.end(), item.begin(), item.end());
+ }
+ batch_embedding.reserve(batch_in * embedding.size());
+ for (size_t index = 0; index < batch_in; ++index) {
+ batch_embedding.insert(batch_embedding.end(), embedding.begin(), embedding.end());
+ }
+
+ torch::Tensor tensor_hw_emb =
+ torch::from_blob(batch_embedding.data(),
+ {batch_in, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())}, torch::kFloat).contiguous();
+ #ifdef USE_GPU
+ tensor_hw_emb = tensor_hw_emb.to(at::kCUDA);
+ #endif
+ inputs.emplace_back(tensor_hw_emb);
+ }
+ }catch (std::exception const &e)
+ {
+ LOG(ERROR)<<e.what();
+ for(int index=0; index<batch_in; index++){
+ results.push_back(result);
+ }
+ return results;
+ }
+
try {
+ if(inputs.size() == 0){
+ LOG(ERROR) << "inputs of forward is null";
+ for(int index=0; index<batch_in; index++){
+ results.push_back(result);
+ }
+ return results;
+ }
auto outputs = model_->forward(inputs).toTuple()->elements();
torch::Tensor am_scores;
torch::Tensor valid_token_lens;
@@ -329,28 +411,31 @@
am_scores = outputs[0].toTensor();
valid_token_lens = outputs[1].toTensor();
#endif
+
+ torch::Tensor us_alphas_tensor;
+ torch::Tensor us_peaks_tensor;
+ if(outputs.size() == 4){
+ #ifdef USE_GPU
+ us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
+ us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
+ #else
+ us_alphas_tensor = outputs[2].toTensor();
+ us_peaks_tensor = outputs[3].toTensor();
+ #endif
+ }
+
// timestamp
for(int index=0; index<batch_in; index++){
- string result="";
+ result="";
if(outputs.size() == 4){
- torch::Tensor us_alphas_tensor;
- torch::Tensor us_peaks_tensor;
- #ifdef USE_GPU
- us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
- us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
- #else
- us_alphas_tensor = outputs[2].toTensor();
- us_peaks_tensor = outputs[3].toTensor();
- #endif
-
float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
- std::vector<float> us_alphas(paraformer_length[index]);
+ std::vector<float> us_alphas(paraformer_length[index]*3);
for (int i = 0; i < us_alphas.size(); i++) {
us_alphas[i] = us_alphas_data[i];
}
float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
- std::vector<float> us_peaks(paraformer_length[index]);
+ std::vector<float> us_peaks(paraformer_length[index]*3);
for (int i = 0; i < us_peaks.size(); i++) {
us_peaks[i] = us_peaks_data[i];
}
@@ -387,8 +472,98 @@
}
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
- // TODO
- std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
+ int embedding_dim = encoder_size;
+ std::vector<std::vector<float>> hw_emb;
+ if (!use_hotword) {
+ std::vector<float> vec(embedding_dim, 0);
+ hw_emb.push_back(vec);
+ return hw_emb;
+ }
+ int max_hotword_len = 10;
+ std::vector<int32_t> hotword_matrix;
+ std::vector<int32_t> lengths;
+ int hotword_size = 1;
+ int real_hw_size = 0;
+ if (!hotwords.empty()) {
+ std::vector<std::string> hotword_array = split(hotwords, ' ');
+ hotword_size = hotword_array.size() + 1;
+ hotword_matrix.reserve(hotword_size * max_hotword_len);
+ for (auto hotword : hotword_array) {
+ std::vector<std::string> chars;
+ if (EncodeConverter::IsAllChineseCharactor((const U8CHAR_T*)hotword.c_str(), hotword.size())) {
+ KeepChineseCharacterAndSplit(hotword, chars);
+ } else {
+ // for english
+ std::vector<std::string> words = split(hotword, ' ');
+ for (auto word : words) {
+ std::vector<string> tokens = seg_dict->GetTokensByWord(word);
+ chars.insert(chars.end(), tokens.begin(), tokens.end());
+ }
+ }
+ if(chars.size()==0){
+ continue;
+ }
+ std::vector<int32_t> hw_vector(max_hotword_len, 0);
+ int vector_len = std::min(max_hotword_len, (int)chars.size());
+ int chs_oov = false;
+ for (int i=0; i<vector_len; i++) {
+ hw_vector[i] = phone_set_->String2Id(chars[i]);
+ if(hw_vector[i] == -1){
+ chs_oov = true;
+ break;
+ }
+ }
+ if(chs_oov){
+ LOG(INFO) << "OOV: " << hotword;
+ continue;
+ }
+ LOG(INFO) << hotword;
+ lengths.push_back(vector_len);
+ real_hw_size += 1;
+ hotword_matrix.insert(hotword_matrix.end(), hw_vector.begin(), hw_vector.end());
+ }
+ hotword_size = real_hw_size + 1;
+ }
+ std::vector<int32_t> blank_vec(max_hotword_len, 0);
+ blank_vec[0] = 1;
+ hotword_matrix.insert(hotword_matrix.end(), blank_vec.begin(), blank_vec.end());
+ lengths.push_back(1);
+
+ torch::Tensor feats =
+ torch::from_blob(hotword_matrix.data(),
+ {hotword_size, max_hotword_len}, torch::kInt32).contiguous();
+
+ // 2. forward
+ #ifdef USE_GPU
+ feats = feats.to(at::kCUDA);
+ #endif
+ std::vector<torch::jit::IValue> inputs = {feats};
+ std::vector<std::vector<float>> result;
+ try {
+ auto output = hw_model_->forward(inputs);
+ torch::Tensor emb_tensor;
+ #ifdef USE_GPU
+ emb_tensor = output.toTensor().to(at::kCPU);
+ #else
+ emb_tensor = output.toTensor();
+ #endif
+ assert(emb_tensor.size(0) == max_hotword_len);
+ assert(emb_tensor.size(1) == hotword_size);
+ embedding_dim = emb_tensor.size(2);
+
+ float* floatData = emb_tensor.data_ptr<float>();
+ for (int j = 0; j < hotword_size; j++)
+ {
+ int start_pos = hotword_size * (lengths[j] - 1) * embedding_dim + j * embedding_dim;
+ std::vector<float> embedding;
+ embedding.insert(embedding.begin(), floatData + start_pos, floatData + start_pos + embedding_dim);
+ result.push_back(embedding);
+ }
+ }
+ catch (std::exception const &e)
+ {
+ LOG(ERROR)<<e.what();
+ }
return result;
}
diff --git a/runtime/onnxruntime/src/paraformer-torch.h b/runtime/onnxruntime/src/paraformer-torch.h
index 74ac315..bea33db 100644
--- a/runtime/onnxruntime/src/paraformer-torch.h
+++ b/runtime/onnxruntime/src/paraformer-torch.h
@@ -36,6 +36,7 @@
using TorchModule = torch::jit::script::Module;
std::shared_ptr<TorchModule> model_ = nullptr;
+ std::shared_ptr<TorchModule> hw_model_ = nullptr;
std::vector<torch::Tensor> encoder_outs_;
bool use_hotword;
diff --git a/runtime/onnxruntime/src/util.cpp b/runtime/onnxruntime/src/util.cpp
index a12570b..483795e 100644
--- a/runtime/onnxruntime/src/util.cpp
+++ b/runtime/onnxruntime/src/util.cpp
@@ -870,6 +870,15 @@
sum -=(1.0 - 1e-4);
}
}
+ // fix case: sum > 1
+ int cif_idx = cif_peak.size()-1;
+ while(sum>=1.0 - 1e-4 && cif_idx >= 0 ){
+ if(cif_peak[cif_idx] < 1.0 - 1e-4){
+ cif_peak[cif_idx] = sum;
+ sum -=(1.0 - 1e-4);
+ }
+ cif_idx--;
+ }
fire_place.clear();
for (int i = 0; i < num_frames; i++) {
diff --git a/runtime/python/libtorch/README.md b/runtime/python/libtorch/README.md
index a96846e..1d15d2b 100644
--- a/runtime/python/libtorch/README.md
+++ b/runtime/python/libtorch/README.md
@@ -41,7 +41,7 @@
## Run the demo
-- Model_dir: the model path, which contains `model.torchscripts`, `config.yaml`, `am.mvn`.
+- Model_dir: the model path, which contains `model.torchscript`, `config.yaml`, `am.mvn`.
- Input: wav formt file, support formats: `str, np.ndarray, List[str]`
- Output: `List[str]`: recognition result.
- Example:
diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index 5fa3cc9..16c0406 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -46,11 +46,11 @@
model_dir
)
- model_file = os.path.join(model_dir, "model.torchscripts")
+ model_file = os.path.join(model_dir, "model.torchscript")
if quantize:
- model_file = os.path.join(model_dir, "model_quant.torchscripts")
+ model_file = os.path.join(model_dir, "model_quant.torchscript")
if not os.path.exists(model_file):
- print(".torchscripts does not exist, begin to export torchscripts")
+            print(".torchscript does not exist, begin to export torchscript")
try:
from funasr import AutoModel
except:
@@ -268,11 +268,11 @@
)
if quantize:
- model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscripts")
- model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscripts")
+ model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscript")
+ model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscript")
else:
- model_bb_file = os.path.join(model_dir, "model_bb.torchscripts")
- model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
+ model_bb_file = os.path.join(model_dir, "model_bb.torchscript")
+ model_eb_file = os.path.join(model_dir, "model_eb.torchscript")
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
print(".onnx does not exist, begin to export onnx")
@@ -282,7 +282,7 @@
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
- model_dir = model.export(type="torchscripts", quantize=quantize, **kwargs)
+ model_dir = model.export(type="torchscript", quantize=quantize, **kwargs)
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
diff --git a/runtime/websocket/bin/funasr-wss-server.cpp b/runtime/websocket/bin/funasr-wss-server.cpp
index 0d475da..3c5b81c 100644
--- a/runtime/websocket/bin/funasr-wss-server.cpp
+++ b/runtime/websocket/bin/funasr-wss-server.cpp
@@ -45,7 +45,7 @@
false, "/workspace/models", "string");
TCLAP::ValueArg<std::string> model_dir(
"", MODEL_DIR,
- "default: /workspace/models/asr, the asr model path, which contains model_quant.onnx, config.yaml, am.mvn",
+ "default: /workspace/models/asr, the asr model path, which contains *.onnx/*.torchscript, config.yaml, am.mvn",
false, "/workspace/models/asr", "string");
TCLAP::ValueArg<std::string> model_revision(
"", "model-revision",
@@ -67,7 +67,7 @@
TCLAP::ValueArg<std::string> vad_revision(
"", "vad-revision",
"VAD model revision",
- false, "v2.0.4", "string");
+ false, "v2.0.6", "string");
TCLAP::ValueArg<std::string> vad_quant(
"", VAD_QUANT,
"true (Default), load the model of model_quant.onnx in vad_dir. If set "
@@ -198,8 +198,9 @@
std::string s_punc_quant = model_path[PUNC_QUANT];
std::string s_itn_path = model_path[ITN_DIR];
std::string s_lm_path = model_path[LM_DIR];
+ std::string s_blade = model_path[BLADEDISC];
- std::string python_cmd = "python -m funasr.download.runtime_sdk_download_tool --type onnx ";
+ std::string python_cmd = "python -m funasr.download.runtime_sdk_download_tool ";
if(vad_dir.isSet() && !s_vad_path.empty()){
std::string python_cmd_vad;
@@ -208,12 +209,12 @@
if (access(s_vad_path.c_str(), F_OK) == 0){
// local
- python_cmd_vad = python_cmd + " --quantize " + s_vad_quant + " --model-name " + s_vad_path + " --export-dir ./ " + " --model_revision " + model_path["vad-revision"];
+ python_cmd_vad = python_cmd + " --type onnx " + " --quantize " + s_vad_quant + " --model-name " + s_vad_path + " --export-dir ./ " + " --model_revision " + model_path["vad-revision"];
down_vad_path = s_vad_path;
}else{
// modelscope
LOG(INFO) << "Download model: " << s_vad_path << " from modelscope: ";
- python_cmd_vad = python_cmd + " --quantize " + s_vad_quant + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["vad-revision"];
+ python_cmd_vad = python_cmd + " --type onnx " + " --quantize " + s_vad_quant + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["vad-revision"];
down_vad_path = s_download_model_dir+"/"+s_vad_path;
}
@@ -241,6 +242,7 @@
std::string python_cmd_asr;
std::string down_asr_path;
std::string down_asr_model;
+ std::string model_type = "onnx";
// modify model-revision by model name
size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
@@ -260,24 +262,39 @@
s_lm_path="";
}
+ if (use_gpu_){
+ model_type = "torchscript";
+ if (s_blade=="true" || s_blade=="True" || s_blade=="TRUE"){
+ model_type = "bladedisc";
+ }
+ }
+
if (access(s_asr_path.c_str(), F_OK) == 0){
// local
- python_cmd_asr = python_cmd + " --quantize " + s_asr_quant + " --model-name " + s_asr_path + " --export-dir ./ " + " --model_revision " + model_path["model-revision"];
+ python_cmd_asr = python_cmd + " --type " + model_type + " --quantize " + s_asr_quant + " --model-name " + s_asr_path + " --export-dir ./ " + " --model_revision " + model_path["model-revision"];
down_asr_path = s_asr_path;
}else{
// modelscope
LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: ";
- python_cmd_asr = python_cmd + " --quantize " + s_asr_quant + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["model-revision"];
+ python_cmd_asr = python_cmd + " --type " + model_type + " --quantize " + s_asr_quant + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["model-revision"];
down_asr_path = s_download_model_dir+"/"+s_asr_path;
}
-
- int ret = system(python_cmd_asr.c_str());
- if(ret !=0){
- LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
- }
+
down_asr_model = down_asr_path+"/model_quant.onnx";
if(s_asr_quant=="false" || s_asr_quant=="False" || s_asr_quant=="FALSE"){
down_asr_model = down_asr_path+"/model.onnx";
+ }
+
+ if (use_gpu_){
+ down_asr_model = down_asr_path+"/model.torchscript";
+ if (s_blade=="true" || s_blade=="True" || s_blade=="TRUE"){
+ down_asr_model = down_asr_path+"/model_blade.torchscript";
+ }
+ }
+
+ int ret = system(python_cmd_asr.c_str());
+ if(ret !=0){
+ LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
}
if (access(down_asr_model.c_str(), F_OK) != 0){
@@ -298,7 +315,7 @@
if (access(s_itn_path.c_str(), F_OK) == 0) {
// local
- python_cmd_itn = python_cmd + " --model-name " + s_itn_path +
+ python_cmd_itn = python_cmd + " --type onnx " + " --model-name " + s_itn_path +
" --export-dir ./ " + " --model_revision " +
model_path["itn-revision"] + " --export False ";
down_itn_path = s_itn_path;
@@ -306,7 +323,7 @@
// modelscope
LOG(INFO) << "Download model: " << s_itn_path
<< " from modelscope : ";
- python_cmd_itn = python_cmd + " --model-name " +
+ python_cmd_itn = python_cmd + " --type onnx " + " --model-name " +
s_itn_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["itn-revision"]
@@ -340,7 +357,7 @@
if (access(s_lm_path.c_str(), F_OK) == 0) {
// local
- python_cmd_lm = python_cmd + "--quantize " + s_punc_quant + " --model-name " + s_lm_path +
+ python_cmd_lm = python_cmd + " --type onnx " + " --model-name " + s_lm_path +
" --export-dir ./ " + " --model_revision " +
model_path["lm-revision"] + " --export False ";
down_lm_path = s_lm_path;
@@ -348,7 +365,7 @@
// modelscope
LOG(INFO) << "Download model: " << s_lm_path
<< " from modelscope : ";
- python_cmd_lm = python_cmd + " --quantize " + s_punc_quant + " --model-name " +
+ python_cmd_lm = python_cmd + " --type onnx " + " --model-name " +
s_lm_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["lm-revision"]
@@ -383,12 +400,12 @@
if (access(s_punc_path.c_str(), F_OK) == 0){
// local
- python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir ./ " + " --model_revision " + model_path["punc-revision"];
+ python_cmd_punc = python_cmd + " --type onnx " + "--quantize " + s_punc_quant + " --model-name " + s_punc_path + " --export-dir ./ " + " --model_revision " + model_path["punc-revision"];
down_punc_path = s_punc_path;
}else{
// modelscope
LOG(INFO) << "Download model: " << s_punc_path << " from modelscope: ";
- python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["punc-revision"];
+ python_cmd_punc = python_cmd + " --type onnx " + "--quantize " + s_punc_quant + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["punc-revision"];
down_punc_path = s_download_model_dir+"/"+s_punc_path;
}
--
Gitblit v1.9.1