From b7060884fa4b8b85f79462644a5c99062d223da0 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Tue, 25 Jun 2024 17:38:04 +0800
Subject: [PATCH] Merge Dev tclas (#1847)

---
 funasr/download/runtime_sdk_download_tool.py                     |   12 +
 runtime/websocket/bin/funasr-wss-server.cpp                      |   53 +++++---
 examples/industrial_data_pretraining/llm_asr/demo_speech2text.sh |    1 
 runtime/onnxruntime/src/util.cpp                                 |    9 +
 runtime/onnxruntime/include/com-define.h                         |    7 
 runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp              |    2 
 runtime/python/libtorch/funasr_torch/paraformer_bin.py           |   16 +-
 examples/industrial_data_pretraining/paraformer/export.py        |    2 
 examples/industrial_data_pretraining/bicif_paraformer/export.py  |    2 
 runtime/onnxruntime/src/offline-stream.cpp                       |   35 ++---
 runtime/onnxruntime/src/paraformer-torch.h                       |    1 
 funasr/utils/export_utils.py                                     |    6 
 runtime/onnxruntime/src/paraformer-torch.cpp                     |  211 ++++++++++++++++++++++++++++++++---
 runtime/python/libtorch/README.md                                |    2 
 14 files changed, 283 insertions(+), 76 deletions(-)

diff --git a/examples/industrial_data_pretraining/bicif_paraformer/export.py b/examples/industrial_data_pretraining/bicif_paraformer/export.py
index 44849b0..e4eb382 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/export.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/export.py
@@ -12,7 +12,7 @@
     device="cpu",
 )
 
-res = model.export(type="torchscripts", quantize=False)
+res = model.export(type="torchscript", quantize=False)
 print(res)
 
 
diff --git a/examples/industrial_data_pretraining/llm_asr/demo_speech2text.sh b/examples/industrial_data_pretraining/llm_asr/demo_speech2text.sh
index 57299fc..3eba6d3 100644
--- a/examples/industrial_data_pretraining/llm_asr/demo_speech2text.sh
+++ b/examples/industrial_data_pretraining/llm_asr/demo_speech2text.sh
@@ -62,4 +62,3 @@
 
 }&
 done
-wait
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/export.py b/examples/industrial_data_pretraining/paraformer/export.py
index a91e9e4..6334e3b 100644
--- a/examples/industrial_data_pretraining/paraformer/export.py
+++ b/examples/industrial_data_pretraining/paraformer/export.py
@@ -13,7 +13,7 @@
     model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
 )
 
-res = model.export(type="torchscripts", quantize=False)
+res = model.export(type="torchscript", quantize=False)
 # res = model.export(type="bladedisc", input=f"{model.model_path}/example/asr_example.wav")
 print(res)
 
diff --git a/funasr/download/runtime_sdk_download_tool.py b/funasr/download/runtime_sdk_download_tool.py
index 96c6735..0db17c7 100644
--- a/funasr/download/runtime_sdk_download_tool.py
+++ b/funasr/download/runtime_sdk_download_tool.py
@@ -10,7 +10,7 @@
     parser.add_argument("--model-name", type=str, required=True)
     parser.add_argument("--export-dir", type=str, required=True)
     parser.add_argument("--export", type=str2bool, default=True, help="whether to export model")
-    parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torch"]')
+    parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torchscript", "bladedisc"]')
     parser.add_argument("--device", type=str, default="cpu", help='["cpu", "cuda"]')
     parser.add_argument("--quantize", type=str2bool, default=False, help="export quantized model")
     parser.add_argument("--fallback-num", type=int, default=0, help="amp fallback number")
@@ -37,11 +37,17 @@
         model_file = os.path.join(model_dir, "model.onnx")
         if args.quantize:
             model_file = os.path.join(model_dir, "model_quant.onnx")
+        if args.type == "torchscript":
+            model_file = os.path.join(model_dir, "model.torchscript")
+            args.device = "cuda"
+        elif args.type == "bladedisc":
+            model_file = os.path.join(model_dir, "model_blade.torchscript")
+            args.device = "cuda"
         if not os.path.exists(model_file):
-            print(".onnx is not exist, begin to export onnx")
+            print("model is not exist, begin to export " + model_file)
             from funasr import AutoModel
 
-            export_model = AutoModel(model=args.model_name, output_dir=output_dir)
+            export_model = AutoModel(model=args.model_name, output_dir=output_dir, device=args.device)
             export_model.export(
                     quantize=args.quantize,
                     type=args.type,
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index 72b150f..a6d0798 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -23,7 +23,7 @@
                 export_dir=export_dir,
                 **kwargs,
             )
-        elif type == "torchscripts":
+        elif type == "torchscript":
             device = "cuda" if torch.cuda.is_available() else "cpu"
             print("Exporting torchscripts on device {}".format(device))
             _torchscripts(m, path=export_dir, device=device)
@@ -100,7 +100,7 @@
             dummy_input = tuple([i.cuda() for i in dummy_input])
 
     model_script = torch.jit.trace(model, dummy_input)
-    model_script.save(os.path.join(path, f"{model.export_name}.torchscripts"))
+    model_script.save(os.path.join(path, f"{model.export_name}.torchscript"))
 
 
 def _bladedisc_opt(model, model_inputs, enable_fp16=True):
@@ -193,4 +193,4 @@
     model.encoder = _bladedisc_opt(model.encoder, input_data[:2])
     model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs))
     model_script = torch.jit.trace(model, input_data)
-    model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscripts"))
+    model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index c252bc7..d8d9473 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -52,7 +52,7 @@
     std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_handle, nn_hotwords_);
     
     // warm up
-    for (size_t i = 0; i < 10; i++)
+    for (size_t i = 0; i < 1; i++)
     {
         FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs, true, decoder_handle);
         if(result){
diff --git a/runtime/onnxruntime/include/com-define.h b/runtime/onnxruntime/include/com-define.h
index 77a7b02..5f71f7b 100644
--- a/runtime/onnxruntime/include/com-define.h
+++ b/runtime/onnxruntime/include/com-define.h
@@ -48,6 +48,7 @@
 #define MODEL_NAME "model.onnx"
 // hotword embedding compile model
 #define MODEL_EB_NAME "model_eb.onnx"
+#define TORCH_MODEL_EB_NAME "model_eb.torchscript"
 #define QUANT_MODEL_NAME "model_quant.onnx"
 #define VAD_CMVN_NAME "am.mvn"
 #define VAD_CONFIG_NAME "config.yaml"
@@ -55,9 +56,9 @@
 // gpu models
 #define INFER_GPU "gpu"
 #define BATCHSIZE "batch-size"
-#define TORCH_MODEL_NAME "model.torchscripts"
-#define TORCH_QUANT_MODEL_NAME "model_quant.torchscripts"
-#define BLADE_MODEL_NAME "model.blade.fp16.pt"
+#define TORCH_MODEL_NAME "model.torchscript"
+#define TORCH_QUANT_MODEL_NAME "model_quant.torchscript"
+#define BLADE_MODEL_NAME "model_blade.torchscript"
 #define BLADEDISC "bladedisc"
 
 #define AM_CMVN_NAME "am.mvn"
diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp
index 35eb1ba..166d3c9 100644
--- a/runtime/onnxruntime/src/offline-stream.cpp
+++ b/runtime/onnxruntime/src/offline-stream.cpp
@@ -33,7 +33,8 @@
         string am_cmvn_path;
         string am_config_path;
         string token_path;
-        string hw_compile_model_path;
+        string hw_cpu_model_path;
+        string hw_gpu_model_path;
         string seg_dict_path;
     
         if(use_gpu){
@@ -50,33 +51,31 @@
         }
 
         bool enable_hotword = false;
-        hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
+        hw_cpu_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
+        hw_gpu_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_MODEL_EB_NAME);
         seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT);
-        if (access(hw_compile_model_path.c_str(), F_OK) == 0) { // if model_eb.onnx exist, hotword enabled
+        if (access(hw_cpu_model_path.c_str(), F_OK) == 0) { // if model_eb.onnx exist, hotword enabled
           enable_hotword = true;
-          asr_handle->InitHwCompiler(hw_compile_model_path, thread_num);
+          asr_handle->InitHwCompiler(hw_cpu_model_path, thread_num);
           asr_handle->InitSegDict(seg_dict_path);
         }
-        if (enable_hotword) {
-          am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
-          if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
+        if (use_gpu && access(hw_gpu_model_path.c_str(), F_OK) == 0) { // if model_eb.torchscript exists, hotword enabled
+          enable_hotword = true;
+          asr_handle->InitHwCompiler(hw_gpu_model_path, thread_num);
+          asr_handle->InitSegDict(seg_dict_path);
+        }
+
+        am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
+        if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
             am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
-          }
-        } else {
-          am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
-          if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
-            am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
-          }
-          if(use_gpu){
+        }
+        if(use_gpu){
             am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_MODEL_NAME);
-            if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
-                am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_QUANT_MODEL_NAME);
-            }
             if(model_path.find(BLADEDISC) != model_path.end() && model_path.at(BLADEDISC) == "true"){
                 am_model_path = PathAppend(model_path.at(MODEL_DIR), BLADE_MODEL_NAME);
             }
-          }
         }
+
         am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
         am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
         token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
index a5f7194..466d80a 100644
--- a/runtime/onnxruntime/src/paraformer-torch.cpp
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -50,6 +50,11 @@
         torch::jit::script::Module model = torch::jit::load(am_model, device);
         model_ = std::make_shared<TorchModule>(std::move(model)); 
         LOG(INFO) << "Successfully load model from " << am_model;
+        torch::NoGradGuard no_grad;
+        model_->eval();
+        torch::jit::setGraphExecutorOptimize(false);
+        torch::jit::FusionStrategy static0 = {{torch::jit::FusionBehavior::STATIC, 0}};
+        torch::jit::setFusionStrategy(static0);
     } catch (std::exception const &e) {
         LOG(ERROR) << "Error when load am model: " << am_model << e.what();
         exit(-1);
@@ -100,6 +105,27 @@
 
 void ParaformerTorch::InitHwCompiler(const std::string &hw_model, int thread_num) {
     // TODO
+    torch::DeviceType device = at::kCPU;
+    #ifdef USE_GPU
+    if (!torch::cuda::is_available()) {
+        // LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
+        exit(-1);
+    } else {
+        // LOG(INFO) << "CUDA is available, running on GPU";
+        device = at::kCUDA;
+    }
+    #endif
+
+    try {
+        torch::jit::script::Module model = torch::jit::load(hw_model, device);
+        hw_model_ = std::make_shared<TorchModule>(std::move(model));
+        LOG(INFO) << "Successfully load model from " << hw_model;
+        torch::NoGradGuard no_grad;
+        hw_model_->eval();
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load hw model: " << hw_model << e.what();
+        exit(-1);
+    }
     use_hotword = true;
 }
 
@@ -111,15 +137,19 @@
 {
     if(vocab){
         delete vocab;
+        vocab = nullptr;
     }
     if(lm_vocab){
         delete lm_vocab;
+        lm_vocab = nullptr;
     }
     if(seg_dict){
         delete seg_dict;
+        seg_dict = nullptr;
     }
     if(phone_set_){
         delete phone_set_;
+        phone_set_ = nullptr;
     }
 }
 
@@ -267,6 +297,9 @@
 
 std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
 {
+    vector<std::string> results;
+    string result="";
+
     WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
     int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
     int32_t feature_dim = lfr_m*in_feat_dim;
@@ -295,8 +328,13 @@
         feats_batch.emplace_back(flattened);
     }
 
-    torch::NoGradGuard no_grad;
-    model_->eval();
+    if(max_frames == 0){
+        for(int index=0; index<batch_in; index++){
+            results.push_back(result);
+        }
+        return results;
+    }
+
     // padding
     std::vector<float> all_feats(batch_in * max_frames * feature_dim);
     for(int index=0; index<batch_in; index++){
@@ -317,8 +355,52 @@
     #endif
     std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
 
-    vector<std::string> results;
+    std::vector<float> batch_embedding;
+    std::vector<float> embedding;
+    try{
+        if (use_hotword) {
+            if(hw_emb.size()<=0){
+                LOG(ERROR) << "hw_emb is null";
+                for(int index=0; index<batch_in; index++){
+                    results.push_back(result);
+                }
+                return results;
+            }
+            
+            embedding.reserve(hw_emb.size() * hw_emb[0].size());
+            for (auto item : hw_emb) {
+                embedding.insert(embedding.end(), item.begin(), item.end());
+            }
+            batch_embedding.reserve(batch_in * embedding.size());
+            for (size_t index = 0; index < batch_in; ++index) {
+                batch_embedding.insert(batch_embedding.end(), embedding.begin(), embedding.end());
+            }
+
+            torch::Tensor tensor_hw_emb =
+                torch::from_blob(batch_embedding.data(),
+                        {batch_in, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())}, torch::kFloat).contiguous();
+            #ifdef USE_GPU
+            tensor_hw_emb = tensor_hw_emb.to(at::kCUDA);
+            #endif
+            inputs.emplace_back(tensor_hw_emb);
+        }
+    }catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+        for(int index=0; index<batch_in; index++){
+            results.push_back(result);
+        }
+        return results;
+    }
+
     try {
+        if(inputs.size() == 0){
+            LOG(ERROR) << "inputs of forward is null";
+            for(int index=0; index<batch_in; index++){
+                results.push_back(result);
+            }
+            return results;
+        }
         auto outputs = model_->forward(inputs).toTuple()->elements();
         torch::Tensor am_scores;
         torch::Tensor valid_token_lens;
@@ -329,28 +411,31 @@
         am_scores = outputs[0].toTensor();
         valid_token_lens = outputs[1].toTensor();
         #endif
+
+        torch::Tensor us_alphas_tensor;
+        torch::Tensor us_peaks_tensor;
+        if(outputs.size() == 4){
+            #ifdef USE_GPU
+            us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
+            us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
+            #else
+            us_alphas_tensor = outputs[2].toTensor();
+            us_peaks_tensor = outputs[3].toTensor();
+            #endif
+        }
+
         // timestamp
         for(int index=0; index<batch_in; index++){
-            string result="";
+            result="";
             if(outputs.size() == 4){
-                torch::Tensor us_alphas_tensor;
-                torch::Tensor us_peaks_tensor;
-                #ifdef USE_GPU
-                us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
-                us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
-                #else
-                us_alphas_tensor = outputs[2].toTensor();
-                us_peaks_tensor = outputs[3].toTensor();
-                #endif
-
                 float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
-                std::vector<float> us_alphas(paraformer_length[index]);
+                std::vector<float> us_alphas(paraformer_length[index]*3);
                 for (int i = 0; i < us_alphas.size(); i++) {
                     us_alphas[i] = us_alphas_data[i];
                 }
 
                 float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
-                std::vector<float> us_peaks(paraformer_length[index]);
+                std::vector<float> us_peaks(paraformer_length[index]*3);
                 for (int i = 0; i < us_peaks.size(); i++) {
                     us_peaks[i] = us_peaks_data[i];
                 }
@@ -387,8 +472,98 @@
 }
 
 std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
-    // TODO
-    std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
+    int embedding_dim = encoder_size;
+    std::vector<std::vector<float>> hw_emb;
+    if (!use_hotword) {
+        std::vector<float> vec(embedding_dim, 0);
+        hw_emb.push_back(vec);
+        return hw_emb;
+    }
+    int max_hotword_len = 10;
+    std::vector<int32_t> hotword_matrix;
+    std::vector<int32_t> lengths;
+    int hotword_size = 1;
+    int real_hw_size = 0;
+    if (!hotwords.empty()) {
+      std::vector<std::string> hotword_array = split(hotwords, ' ');
+      hotword_size = hotword_array.size() + 1;
+      hotword_matrix.reserve(hotword_size * max_hotword_len);
+      for (auto hotword : hotword_array) {
+        std::vector<std::string> chars;
+        if (EncodeConverter::IsAllChineseCharactor((const U8CHAR_T*)hotword.c_str(), hotword.size())) {
+          KeepChineseCharacterAndSplit(hotword, chars);
+        } else {
+          // for english
+          std::vector<std::string> words = split(hotword, ' ');
+          for (auto word : words) {
+            std::vector<string> tokens = seg_dict->GetTokensByWord(word);
+            chars.insert(chars.end(), tokens.begin(), tokens.end());
+          }
+        }
+        if(chars.size()==0){
+            continue;
+        }
+        std::vector<int32_t> hw_vector(max_hotword_len, 0);
+        int vector_len = std::min(max_hotword_len, (int)chars.size());
+        int chs_oov = false;
+        for (int i=0; i<vector_len; i++) {
+          hw_vector[i] = phone_set_->String2Id(chars[i]);
+          if(hw_vector[i] == -1){
+            chs_oov = true;
+            break;
+          }
+        }
+        if(chs_oov){
+          LOG(INFO) << "OOV: " << hotword;
+          continue;
+        }
+        LOG(INFO) << hotword;
+        lengths.push_back(vector_len);
+        real_hw_size += 1;
+        hotword_matrix.insert(hotword_matrix.end(), hw_vector.begin(), hw_vector.end());
+      }
+      hotword_size = real_hw_size + 1;
+    }
+    std::vector<int32_t> blank_vec(max_hotword_len, 0);
+    blank_vec[0] = 1;
+    hotword_matrix.insert(hotword_matrix.end(), blank_vec.begin(), blank_vec.end());
+    lengths.push_back(1);
+
+    torch::Tensor feats =
+        torch::from_blob(hotword_matrix.data(),
+                {hotword_size, max_hotword_len}, torch::kInt32).contiguous();
+
+    // 2. forward
+    #ifdef USE_GPU
+    feats = feats.to(at::kCUDA);
+    #endif
+    std::vector<torch::jit::IValue> inputs = {feats};
+    std::vector<std::vector<float>> result;
+    try {
+        auto output = hw_model_->forward(inputs);
+        torch::Tensor emb_tensor;
+        #ifdef USE_GPU
+        emb_tensor = output.toTensor().to(at::kCPU);
+        #else
+        emb_tensor = output.toTensor();
+        #endif
+        assert(emb_tensor.size(0) == max_hotword_len);
+        assert(emb_tensor.size(1) == hotword_size);
+        embedding_dim = emb_tensor.size(2);
+
+        float* floatData = emb_tensor.data_ptr<float>();
+        for (int j = 0; j < hotword_size; j++)
+        {
+            int start_pos = hotword_size * (lengths[j] - 1) * embedding_dim + j * embedding_dim;
+            std::vector<float> embedding;
+            embedding.insert(embedding.begin(), floatData + start_pos, floatData + start_pos + embedding_dim);
+            result.push_back(embedding);
+        }
+    }
+    catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+    }
     return result;
 }
 
diff --git a/runtime/onnxruntime/src/paraformer-torch.h b/runtime/onnxruntime/src/paraformer-torch.h
index 74ac315..bea33db 100644
--- a/runtime/onnxruntime/src/paraformer-torch.h
+++ b/runtime/onnxruntime/src/paraformer-torch.h
@@ -36,6 +36,7 @@
 
         using TorchModule = torch::jit::script::Module;
         std::shared_ptr<TorchModule> model_ = nullptr;
+        std::shared_ptr<TorchModule> hw_model_ = nullptr;
         std::vector<torch::Tensor> encoder_outs_;
         bool use_hotword;
 
diff --git a/runtime/onnxruntime/src/util.cpp b/runtime/onnxruntime/src/util.cpp
index a12570b..483795e 100644
--- a/runtime/onnxruntime/src/util.cpp
+++ b/runtime/onnxruntime/src/util.cpp
@@ -870,6 +870,15 @@
                 sum -=(1.0 - 1e-4);
             }            
         }
+        // fix case: sum > 1
+        int cif_idx = cif_peak.size()-1;
+        while(sum>=1.0 - 1e-4 && cif_idx >= 0 ){
+            if(cif_peak[cif_idx] < 1.0 - 1e-4){
+                cif_peak[cif_idx] = sum;
+                sum -=(1.0 - 1e-4);
+            }
+            cif_idx--;
+        }
 
         fire_place.clear();
         for (int i = 0; i < num_frames; i++) {
diff --git a/runtime/python/libtorch/README.md b/runtime/python/libtorch/README.md
index a96846e..1d15d2b 100644
--- a/runtime/python/libtorch/README.md
+++ b/runtime/python/libtorch/README.md
@@ -41,7 +41,7 @@
 
 ## Run the demo
 
-- Model_dir: the model path, which contains `model.torchscripts`, `config.yaml`, `am.mvn`.
+- Model_dir: the model path, which contains `model.torchscript`, `config.yaml`, `am.mvn`.
 - Input: wav formt file, support formats: `str, np.ndarray, List[str]`
 - Output: `List[str]`: recognition result.
 - Example:
diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index 5fa3cc9..16c0406 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -46,11 +46,11 @@
                     model_dir
                 )
 
-        model_file = os.path.join(model_dir, "model.torchscripts")
+        model_file = os.path.join(model_dir, "model.torchscript")
         if quantize:
-            model_file = os.path.join(model_dir, "model_quant.torchscripts")
+            model_file = os.path.join(model_dir, "model_quant.torchscript")
         if not os.path.exists(model_file):
-            print(".torchscripts does not exist, begin to export torchscripts")
+            print(".torchscripts does not exist, begin to export torchscript")
             try:
                 from funasr import AutoModel
             except:
@@ -268,11 +268,11 @@
                 )
 
         if quantize:
-            model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscripts")
-            model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscripts")
+            model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscript")
+            model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscript")
         else:
-            model_bb_file = os.path.join(model_dir, "model_bb.torchscripts")
-            model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
+            model_bb_file = os.path.join(model_dir, "model_bb.torchscript")
+            model_eb_file = os.path.join(model_dir, "model_eb.torchscript")
 
         if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
             print(".onnx does not exist, begin to export onnx")
@@ -282,7 +282,7 @@
                 raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
             model = AutoModel(model=model_dir)
-            model_dir = model.export(type="torchscripts", quantize=quantize, **kwargs)
+            model_dir = model.export(type="torchscript", quantize=quantize, **kwargs)
 
         config_file = os.path.join(model_dir, "config.yaml")
         cmvn_file = os.path.join(model_dir, "am.mvn")
diff --git a/runtime/websocket/bin/funasr-wss-server.cpp b/runtime/websocket/bin/funasr-wss-server.cpp
index 0d475da..3c5b81c 100644
--- a/runtime/websocket/bin/funasr-wss-server.cpp
+++ b/runtime/websocket/bin/funasr-wss-server.cpp
@@ -45,7 +45,7 @@
         false, "/workspace/models", "string");
     TCLAP::ValueArg<std::string> model_dir(
         "", MODEL_DIR,
-        "default: /workspace/models/asr, the asr model path, which contains model_quant.onnx, config.yaml, am.mvn",
+        "default: /workspace/models/asr, the asr model path, which contains *.onnx/*.torchscript, config.yaml, am.mvn",
         false, "/workspace/models/asr", "string");
     TCLAP::ValueArg<std::string> model_revision(
         "", "model-revision",
@@ -67,7 +67,7 @@
     TCLAP::ValueArg<std::string> vad_revision(
         "", "vad-revision",
         "VAD model revision",
-        false, "v2.0.4", "string");
+        false, "v2.0.6", "string");
     TCLAP::ValueArg<std::string> vad_quant(
         "", VAD_QUANT,
         "true (Default), load the model of model_quant.onnx in vad_dir. If set "
@@ -198,8 +198,9 @@
         std::string s_punc_quant = model_path[PUNC_QUANT];
         std::string s_itn_path = model_path[ITN_DIR];
         std::string s_lm_path = model_path[LM_DIR];
+        std::string s_blade = model_path[BLADEDISC];
 
-        std::string python_cmd = "python -m funasr.download.runtime_sdk_download_tool --type onnx ";
+        std::string python_cmd = "python -m funasr.download.runtime_sdk_download_tool ";
 
         if(vad_dir.isSet() && !s_vad_path.empty()){
             std::string python_cmd_vad;
@@ -208,12 +209,12 @@
 
             if (access(s_vad_path.c_str(), F_OK) == 0){
                 // local
-                python_cmd_vad = python_cmd + " --quantize " + s_vad_quant + " --model-name " + s_vad_path + " --export-dir ./ " + " --model_revision " + model_path["vad-revision"];
+                python_cmd_vad = python_cmd + " --type onnx " + " --quantize " + s_vad_quant + " --model-name " + s_vad_path + " --export-dir ./ " + " --model_revision " + model_path["vad-revision"];
                 down_vad_path  = s_vad_path;
             }else{
                 // modelscope
                 LOG(INFO) << "Download model: " <<  s_vad_path << " from modelscope: ";
-                python_cmd_vad = python_cmd + " --quantize " + s_vad_quant + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["vad-revision"];
+                python_cmd_vad = python_cmd + " --type onnx " + " --quantize " + s_vad_quant + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["vad-revision"];
                 down_vad_path  = s_download_model_dir+"/"+s_vad_path;
             }
                 
@@ -241,6 +242,7 @@
             std::string python_cmd_asr;
             std::string down_asr_path;
             std::string down_asr_model;
+            std::string model_type = "onnx";
 
             // modify model-revision by model name
             size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
@@ -260,24 +262,39 @@
                 s_lm_path="";
             }
 
+            if (use_gpu_){
+                model_type = "torchscript";
+                if (s_blade=="true" || s_blade=="True" || s_blade=="TRUE"){
+                    model_type = "bladedisc";
+                }
+            }
+
             if (access(s_asr_path.c_str(), F_OK) == 0){
                 // local
-                python_cmd_asr = python_cmd + " --quantize " + s_asr_quant + " --model-name " + s_asr_path + " --export-dir ./ " + " --model_revision " + model_path["model-revision"];
+                python_cmd_asr = python_cmd + " --type " + model_type + " --quantize " + s_asr_quant + " --model-name " + s_asr_path + " --export-dir ./ " + " --model_revision " + model_path["model-revision"];
                 down_asr_path  = s_asr_path;
             }else{
                 // modelscope
                 LOG(INFO) << "Download model: " <<  s_asr_path << " from modelscope: ";
-                python_cmd_asr = python_cmd + " --quantize " + s_asr_quant + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["model-revision"];
+                python_cmd_asr = python_cmd + " --type " + model_type + " --quantize " + s_asr_quant + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["model-revision"];
                 down_asr_path  = s_download_model_dir+"/"+s_asr_path;
             }
-                
-            int ret = system(python_cmd_asr.c_str());
-            if(ret !=0){
-                LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
-            }
+            
             down_asr_model = down_asr_path+"/model_quant.onnx";
             if(s_asr_quant=="false" || s_asr_quant=="False" || s_asr_quant=="FALSE"){
                 down_asr_model = down_asr_path+"/model.onnx";
+            }
+
+            if (use_gpu_){
+                down_asr_model = down_asr_path+"/model.torchscript";
+                if (s_blade=="true" || s_blade=="True" || s_blade=="TRUE"){
+                    down_asr_model = down_asr_path+"/model_blade.torchscript";
+                }
+            }
+
+            int ret = system(python_cmd_asr.c_str());
+            if(ret !=0){
+                LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
             }
 
             if (access(down_asr_model.c_str(), F_OK) != 0){
@@ -298,7 +315,7 @@
 
             if (access(s_itn_path.c_str(), F_OK) == 0) {
                 // local
-                python_cmd_itn = python_cmd + " --model-name " + s_itn_path +
+                python_cmd_itn = python_cmd + " --type onnx " + " --model-name " + s_itn_path +
                                     " --export-dir ./ " + " --model_revision " +
                                     model_path["itn-revision"] + " --export False ";
                 down_itn_path = s_itn_path;
@@ -306,7 +323,7 @@
                 // modelscope
                 LOG(INFO) << "Download model: " << s_itn_path
                             << " from modelscope : "; 
-                python_cmd_itn = python_cmd + " --model-name " +
+                python_cmd_itn = python_cmd + " --type onnx " + " --model-name " +
                         s_itn_path +
                         " --export-dir " + s_download_model_dir +
                         " --model_revision " + model_path["itn-revision"]
@@ -340,7 +357,7 @@
 
             if (access(s_lm_path.c_str(), F_OK) == 0) {
                 // local
-                python_cmd_lm = python_cmd + "--quantize " + s_punc_quant + " --model-name " + s_lm_path +
+                python_cmd_lm = python_cmd + " --type onnx " + " --model-name " + s_lm_path +
                                     " --export-dir ./ " + " --model_revision " +
                                     model_path["lm-revision"] + " --export False ";
                 down_lm_path = s_lm_path;
@@ -348,7 +365,7 @@
                 // modelscope
                 LOG(INFO) << "Download model: " << s_lm_path
                             << " from modelscope : "; 
-                python_cmd_lm = python_cmd + " --quantize " + s_punc_quant + " --model-name " +
+                python_cmd_lm = python_cmd + " --type onnx " + " --model-name " +
                         s_lm_path +
                         " --export-dir " + s_download_model_dir +
                         " --model_revision " + model_path["lm-revision"]
@@ -383,12 +400,12 @@
 
             if (access(s_punc_path.c_str(), F_OK) == 0){
                 // local
-                python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir ./ " + " --model_revision " + model_path["punc-revision"];
+                python_cmd_punc = python_cmd + " --type onnx " + "--quantize " + s_punc_quant + " --model-name " + s_punc_path + " --export-dir ./ " + " --model_revision " + model_path["punc-revision"];
                 down_punc_path  = s_punc_path;
             }else{
                 // modelscope
                 LOG(INFO) << "Download model: " <<  s_punc_path << " from modelscope: ";
-                python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["punc-revision"];
+                python_cmd_punc = python_cmd + " --type onnx " + "--quantize " + s_punc_quant + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["punc-revision"];
                 down_punc_path  = s_download_model_dir+"/"+s_punc_path;
             }
                 

--
Gitblit v1.9.1