From ade08818b7a579aac75182b906a5bd3b8126411c Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期一, 27 五月 2024 15:46:26 +0800
Subject: [PATCH] Merge branch 'dev_batch' into main

---
 runtime/onnxruntime/src/precomp.h                                     |    3 
 runtime/onnxruntime/src/paraformer.h                                  |    4 
 runtime/websocket/bin/websocket-server.cpp                            |    4 
 runtime/websocket/bin/websocket-server.h                              |    2 
 runtime/onnxruntime/include/com-define.h                              |    9 
 runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp |    4 
 runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp                   |   15 
 runtime/run_server.sh                                                 |    7 
 runtime/onnxruntime/src/audio.cpp                                     |  106 +++++
 runtime/onnxruntime/src/offline-stream.cpp                            |   29 +
 runtime/onnxruntime/bin/CMakeLists.txt                                |   10 
 runtime/onnxruntime/src/paraformer-torch.h                            |   96 +++++
 runtime/onnxruntime/src/paraformer.cpp                                |   24 
 runtime/onnxruntime/bin/funasr-onnx-offline.cpp                       |   40 ++
 runtime/onnxruntime/CMakeLists.txt                                    |   12 
 runtime/onnxruntime/src/funasrruntime.cpp                             |  178 ++++++--
 runtime/onnxruntime/include/funasrruntime.h                           |    2 
 runtime/websocket/bin/funasr-wss-server.cpp                           |   14 
 runtime/onnxruntime/src/CMakeLists.txt                                |   11 
 runtime/websocket/CMakeLists.txt                                      |   20 
 runtime/onnxruntime/include/model.h                                   |   12 
 runtime/onnxruntime/include/audio.h                                   |    4 
 runtime/onnxruntime/include/offline-stream.h                          |    4 
 runtime/onnxruntime/src/paraformer-torch.cpp                          |  415 +++++++++++++++++++++++
 runtime/websocket/bin/CMakeLists.txt                                  |    8 
 25 files changed, 936 insertions(+), 97 deletions(-)

diff --git a/runtime/onnxruntime/CMakeLists.txt b/runtime/onnxruntime/CMakeLists.txt
index d8e623e..cf10f54 100644
--- a/runtime/onnxruntime/CMakeLists.txt
+++ b/runtime/onnxruntime/CMakeLists.txt
@@ -4,6 +4,7 @@
 
 option(ENABLE_GLOG "Whether to build glog" ON)
 option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
+option(GPU "Whether to build with GPU" OFF)
 
 # set(CMAKE_CXX_STANDARD 11)
 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
@@ -49,6 +50,17 @@
 include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi)
 include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
 
+if(GPU)
+    add_definitions(-DUSE_GPU)
+    set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
+    set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
+    include_directories(${TORCH_DIR}/include)
+    include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
+    link_directories(${TORCH_DIR}/lib)
+    link_directories(${TORCH_BLADE_DIR})
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
+endif()
+
 if(ENABLE_GLOG)
     include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/src)
     set(BUILD_TESTING OFF)
diff --git a/runtime/onnxruntime/bin/CMakeLists.txt b/runtime/onnxruntime/bin/CMakeLists.txt
index c91fbc4..0ca7f1e 100644
--- a/runtime/onnxruntime/bin/CMakeLists.txt
+++ b/runtime/onnxruntime/bin/CMakeLists.txt
@@ -10,33 +10,43 @@
 endif()
 
 add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-offline PRIVATE "-Wl,--no-as-needed")
 target_link_libraries(funasr-onnx-offline PUBLIC funasr)
 
 add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-offline-vad PRIVATE "-Wl,--no-as-needed")
 target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
 
 add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-online-vad PRIVATE "-Wl,--no-as-needed")
 target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
 
 add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-online-asr PRIVATE "-Wl,--no-as-needed")
 target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)
 
 add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-offline-punc PRIVATE "-Wl,--no-as-needed")
 target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
 
 add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-online-punc PRIVATE "-Wl,--no-as-needed")
 target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
 
 add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-offline-rtf PRIVATE "-Wl,--no-as-needed")
 target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
 
 add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-2pass PRIVATE "-Wl,--no-as-needed")
 target_link_libraries(funasr-onnx-2pass PUBLIC funasr)
 
 add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-2pass-rtf PRIVATE "-Wl,--no-as-needed")
 target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)
 
 add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-online-rtf PRIVATE "-Wl,--no-as-needed")
 target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)
 
 # include_directories(${FFMPEG_DIR}/include)
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index 5722693..c252bc7 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -52,7 +52,7 @@
     std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_handle, nn_hotwords_);
     
     // warm up
-    for (size_t i = 0; i < 1; i++)
+    for (size_t i = 0; i < 10; i++)
     {
         FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs, true, decoder_handle);
         if(result){
@@ -127,6 +127,7 @@
     TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
     TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
     TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
+    TCLAP::ValueArg<std::string>    bladedisc("", BLADEDISC, "true (Default), load the model of bladedisc in model_dir.", false, "true", "string");
     TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
     TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
     TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@@ -140,11 +141,14 @@
 
     TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
     TCLAP::ValueArg<std::int32_t>   audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
-    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
+    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", false, 1, "int32_t");
     TCLAP::ValueArg<std::string>    hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 闃块噷宸村反 20)", false, "", "string");
+    TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU for inference, default is false", false);
+    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
 
     cmd.add(model_dir);
     cmd.add(quantize);
+    cmd.add(bladedisc);
     cmd.add(vad_dir);
     cmd.add(vad_quant);
     cmd.add(punc_dir);
@@ -159,11 +163,14 @@
     cmd.add(wav_path);
     cmd.add(audio_fs);
     cmd.add(thread_num);
+    cmd.add(use_gpu);
+    cmd.add(batch_size);
     cmd.parse(argc, argv);
 
     std::map<std::string, std::string> model_path;
     GetValue(model_dir, MODEL_DIR, model_path);
     GetValue(quantize, QUANTIZE, model_path);
+    GetValue(bladedisc, BLADEDISC, model_path);
     GetValue(vad_dir, VAD_DIR, model_path);
     GetValue(vad_quant, VAD_QUANT, model_path);
     GetValue(punc_dir, PUNC_DIR, model_path);
@@ -175,7 +182,9 @@
 
     struct timeval start, end;
     gettimeofday(&start, nullptr);
-    FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1);
+    bool use_gpu_ = use_gpu.getValue();
+    int batch_size_ = batch_size.getValue();
+    FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1, use_gpu_, batch_size_);
 
     if (!asr_handle)
     {
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
index edb83bd..764c581 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -19,6 +19,7 @@
 #include "com-define.h"
 #include <unordered_map>
 #include "util.h"
+#include "audio.h"
 using namespace std;
 
 bool is_target_file(const std::string& filename, const std::string target) {
@@ -44,6 +45,7 @@
     TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
     TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
     TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
+    TCLAP::ValueArg<std::string>    bladedisc("", BLADEDISC, "true (Default), load the model of bladedisc in model_dir.", false, "true", "string");
     TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
     TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
     TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@@ -57,9 +59,12 @@
     TCLAP::ValueArg<std::string>    wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
     TCLAP::ValueArg<std::int32_t>   audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
     TCLAP::ValueArg<std::string>    hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 闃块噷宸村反 20)", false, "", "string");
+    TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU for inference, default is false", false);
+    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
 
     cmd.add(model_dir);
     cmd.add(quantize);
+    cmd.add(bladedisc);
     cmd.add(vad_dir);
     cmd.add(vad_quant);
     cmd.add(punc_dir);
@@ -73,11 +78,14 @@
     cmd.add(wav_path);
     cmd.add(audio_fs);
     cmd.add(hotword);
+    cmd.add(use_gpu);
+    cmd.add(batch_size);
     cmd.parse(argc, argv);
 
     std::map<std::string, std::string> model_path;
     GetValue(model_dir, MODEL_DIR, model_path);
     GetValue(quantize, QUANTIZE, model_path);
+    GetValue(bladedisc, BLADEDISC, model_path);
     GetValue(vad_dir, VAD_DIR, model_path);
     GetValue(vad_quant, VAD_QUANT, model_path);
     GetValue(punc_dir, PUNC_DIR, model_path);
@@ -89,7 +97,9 @@
     struct timeval start, end;
     gettimeofday(&start, nullptr);
     int thread_num = 1;
-    FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num);
+    bool use_gpu_ = use_gpu.getValue();
+    int batch_size_ = batch_size.getValue();
+    FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num, use_gpu_, batch_size_);
 
     if (!asr_hanlde)
     {
@@ -156,8 +166,34 @@
     for (int i = 0; i < wav_list.size(); i++) {
         auto& wav_file = wav_list[i];
         auto& wav_id = wav_ids[i];
+
+        // For debug:begin
+        int32_t sampling_rate_ = audio_fs.getValue();
+        funasr::Audio audio(1);
+		if(is_target_file(wav_file.c_str(), "wav")){
+			if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
+				LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+		}else if(is_target_file(wav_file.c_str(), "pcm")){
+			if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
+				LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+		}else{
+			if (!audio.FfmpegLoad(wav_file.c_str(), true)){
+				LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+		}
+        char* speech_buff = audio.GetSpeechChar();
+        int buff_len = audio.GetSpeechLen()*2;
+
         gettimeofday(&start, nullptr);
-        FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
+        FUNASR_RESULT result=FunOfflineInferBuffer(asr_hanlde, speech_buff, buff_len, RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), "pcm", true, decoder_handle);
+        // For debug:end
+
+        // FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
         gettimeofday(&end, nullptr);
         seconds = (end.tv_sec - start.tv_sec);
         taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
diff --git a/runtime/onnxruntime/include/audio.h b/runtime/onnxruntime/include/audio.h
index 9edd9c9..3011050 100644
--- a/runtime/onnxruntime/include/audio.h
+++ b/runtime/onnxruntime/include/audio.h
@@ -83,9 +83,11 @@
     int FetchTpass(AudioFrame *&frame);
     int Fetch(float *&dout, int &len, int &flag);
     int Fetch(float *&dout, int &len, int &flag, float &start_time);
+    int Fetch(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
+    int FetchDynamic(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
     void Padding();
     void Split(OfflineStream* offline_streamj);
-    void CutSplit(OfflineStream* offline_streamj);
+    void CutSplit(OfflineStream* offline_streamj, std::vector<int> &index_vector);
     void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
     void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
     float GetTimeLen();
diff --git a/runtime/onnxruntime/include/com-define.h b/runtime/onnxruntime/include/com-define.h
index d4edd5b..77a7b02 100644
--- a/runtime/onnxruntime/include/com-define.h
+++ b/runtime/onnxruntime/include/com-define.h
@@ -51,6 +51,15 @@
 #define QUANT_MODEL_NAME "model_quant.onnx"
 #define VAD_CMVN_NAME "am.mvn"
 #define VAD_CONFIG_NAME "config.yaml"
+
+// gpu models
+#define INFER_GPU "gpu"
+#define BATCHSIZE "batch-size"
+#define TORCH_MODEL_NAME "model.torchscripts"
+#define TORCH_QUANT_MODEL_NAME "model_quant.torchscripts"
+#define BLADE_MODEL_NAME "model.blade.fp16.pt"
+#define BLADEDISC "bladedisc"
+
 #define AM_CMVN_NAME "am.mvn"
 #define AM_CONFIG_NAME "config.yaml"
 #define LM_CONFIG_NAME "config.yaml"
diff --git a/runtime/onnxruntime/include/funasrruntime.h b/runtime/onnxruntime/include/funasrruntime.h
index cff617f..cc9ba3d 100644
--- a/runtime/onnxruntime/include/funasrruntime.h
+++ b/runtime/onnxruntime/include/funasrruntime.h
@@ -96,7 +96,7 @@
 _FUNASRAPI void					CTTransformerUninit(FUNASR_HANDLE handle);
 
 //OfflineStream
-_FUNASRAPI FUNASR_HANDLE  	FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI FUNASR_HANDLE  	FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
 _FUNASRAPI void         	FunOfflineReset(FUNASR_HANDLE handle, FUNASR_DEC_HANDLE dec_handle=nullptr);
 // buffer
 _FUNASRAPI FUNASR_RESULT	FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, 
diff --git a/runtime/onnxruntime/include/model.h b/runtime/onnxruntime/include/model.h
index f5c4027..1064c4c 100644
--- a/runtime/onnxruntime/include/model.h
+++ b/runtime/onnxruntime/include/model.h
@@ -5,6 +5,10 @@
 #include <string>
 #include <map>
 #include "funasrruntime.h"
+#include "vocab.h"
+#include "phone-set.h"
+#include "fst/fstlib.h"
+#include "fst/symbol-table.h"
 namespace funasr {
 class Model {
   public:
@@ -18,13 +22,19 @@
     virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
     virtual void InitFstDecoder(){};
     virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
+    virtual std::vector<std::string> Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1)
+      {return std::vector<string>();};
     virtual std::string Rescoring() = 0;
     virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
     virtual void InitSegDict(const std::string &seg_dict_model){};
     virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
     virtual std::string GetLang(){return "";};
     virtual int GetAsrSampleRate() = 0;
-
+    virtual void SetBatchSize(int batch_size) {};
+    virtual int GetBatchSize() {return 0;};
+    virtual Vocab* GetVocab() {return nullptr;};
+    virtual Vocab* GetLmVocab() {return nullptr;};
+    virtual PhoneSet* GetPhoneSet() {return nullptr;};
 };
 
 Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
diff --git a/runtime/onnxruntime/include/offline-stream.h b/runtime/onnxruntime/include/offline-stream.h
index f63de74..cc0f1c4 100644
--- a/runtime/onnxruntime/include/offline-stream.h
+++ b/runtime/onnxruntime/include/offline-stream.h
@@ -14,7 +14,7 @@
 namespace funasr {
 class OfflineStream {
   public:
-    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
+    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
     ~OfflineStream(){};
 
     std::unique_ptr<VadModel> vad_handle= nullptr;
@@ -33,6 +33,6 @@
     bool use_itn=false;
 };
 
-OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1);
+OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1, bool use_gpu=false, int batch_size=1);
 } // namespace funasr
 #endif
diff --git a/runtime/onnxruntime/src/CMakeLists.txt b/runtime/onnxruntime/src/CMakeLists.txt
index 9eac2b6..6d6b66b 100644
--- a/runtime/onnxruntime/src/CMakeLists.txt
+++ b/runtime/onnxruntime/src/CMakeLists.txt
@@ -1,6 +1,11 @@
 
 file(GLOB files1 "*.cpp")
+list(REMOVE_ITEM files1 "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
 set(files ${files1})
+
+if(GPU)
+    set(files ${files} "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
+endif()
 
 message("files: "${files})
 
@@ -25,7 +30,11 @@
     include_directories(${FFMPEG_DIR}/include)
 endif()
 
+if(GPU)
+    set(TORCH_DEPS torch torch_cuda torch_cpu c10 c10_cuda torch_blade ral_base_context)
+endif()
+
 #message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
 include_directories(${CMAKE_SOURCE_DIR}/include)
 include_directories(${CMAKE_SOURCE_DIR}/third_party)
-target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
+target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS} ${TORCH_DEPS})
diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp
index 0135ab4..22a9ecd 100644
--- a/runtime/onnxruntime/src/audio.cpp
+++ b/runtime/onnxruntime/src/audio.cpp
@@ -1023,6 +1023,90 @@
     }
 }
 
+int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+    batch_in = std::min((int)frame_queue.size(), batch_size);
+    if (batch_in == 0){
+        return 0;
+    } else{
+        // init
+        dout = new float*[batch_in];
+        len = new int[batch_in];
+        flag = new int[batch_in];
+        start_time = new float[batch_in];
+
+        for(int idx=0; idx < batch_in; idx++){
+            AudioFrame *frame = frame_queue.front();
+            frame_queue.pop();
+
+            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+            dout[idx] = speech_data + frame->GetStart();
+            len[idx] = frame->GetLen();
+            delete frame;
+            flag[idx] = S_END;
+        }
+        return 1;
+    }
+}
+
+int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+    //compute batch size
+    queue<AudioFrame *> frame_batch;
+    int max_acc = 300*1000*seg_sample;
+    int max_sent = 60*1000*seg_sample;
+    int bs_acc = 0;
+    int max_len = 0;
+    int max_batch = 1;
+    #ifdef USE_GPU
+        max_batch = batch_size;
+    #endif
+    max_batch = std::min(max_batch, (int)frame_queue.size());
+
+    for(int idx=0; idx < max_batch; idx++){
+        AudioFrame *frame = frame_queue.front();
+        int length = frame->GetLen();
+        if(length >= max_sent){
+            if(bs_acc == 0){
+                bs_acc++;
+                frame_batch.push(frame);
+                frame_queue.pop();                
+            }
+            break;
+        }
+        max_len = std::max(max_len, frame->GetLen());
+        if(max_len*(bs_acc+1) > max_acc){
+            break;
+        }
+        bs_acc++;
+        frame_batch.push(frame);
+        frame_queue.pop();
+    }
+
+    batch_in = (int)frame_batch.size();
+    if (batch_in == 0){
+        return 0;
+    } else{
+        // init
+        dout = new float*[batch_in];
+        len = new int[batch_in];
+        flag = new int[batch_in];
+        start_time = new float[batch_in];
+
+        for(int idx=0; idx < batch_in; idx++){
+            AudioFrame *frame = frame_batch.front();
+            frame_batch.pop();
+
+            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+            dout[idx] = speech_data + frame->GetStart();
+            len[idx] = frame->GetLen();
+            delete frame;
+            flag[idx] = S_END;
+        }
+        return 1;
+    }
+}
+
 void Audio::Padding()
 {
     float num_samples = speech_len;
@@ -1085,7 +1169,7 @@
     }
 }
 
-void Audio::CutSplit(OfflineStream* offline_stream)
+void Audio::CutSplit(OfflineStream* offline_stream, std::vector<int> &index_vector)
 {
     std::unique_ptr<VadModel> vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(offline_stream->vad_handle).get());
     AudioFrame *frame;
@@ -1112,6 +1196,7 @@
     }    
 
     int speech_start_i = -1, speech_end_i =-1;
+    std::vector<AudioFrame*> vad_frames;
     for(vector<int> vad_segment:vad_segments)
     {
         if(vad_segment.size() != 2){
@@ -1126,16 +1211,31 @@
         }
 
         if(speech_start_i!=-1 && speech_end_i!=-1){
-            frame = new AudioFrame();
             int start = speech_start_i*seg_sample;
             int end = speech_end_i*seg_sample;
+            frame = new AudioFrame(end-start);
             frame->SetStart(start);
             frame->SetEnd(end);
-            frame_queue.push(frame);
+            vad_frames.push_back(frame);
             frame = nullptr;
             speech_start_i=-1;
             speech_end_i=-1;
         }
+
+    }
+    // sort
+    {
+        index_vector.clear();
+        index_vector.resize(vad_frames.size());
+        for (int i = 0; i < index_vector.size(); ++i) {
+            index_vector[i] = i;
+        }
+        std::sort(index_vector.begin(), index_vector.end(), [&vad_frames](const int a, const int b) {
+            return vad_frames[a]->len < vad_frames[b]->len;
+        });
+        for (int idx : index_vector) {
+            frame_queue.push(vad_frames[idx]);
+        }
     }
 }
 
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 4bc64af..d235e6f 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -33,9 +33,9 @@
 		return mm;
 	}
 
-	_FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
+	_FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
 	{
-		funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num);
+		funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu, batch_size);
 		return mm;
 	}
 
@@ -74,16 +74,11 @@
 		if(p_result->snippet_time == 0){
 			return p_result;
 		}
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
+
 		while (audio.Fetch(buff, len, flag) > 0) {
 			string msg = recog_obj->Forward(buff, len, input_finished);
 			p_result->msg += msg;
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
-
 		return p_result;
 	}
 
@@ -109,8 +104,6 @@
 		float* buff;
 		int len;
 		int flag = 0;
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
 		funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
 		p_result->snippet_time = audio.GetTimeLen();
 		if(p_result->snippet_time == 0){
@@ -119,11 +112,7 @@
 		while (audio.Fetch(buff, len, flag) > 0) {
 			string msg = recog_obj->Forward(buff, len, true);
 			p_result->msg += msg;
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
-
 		return p_result;
 	}
 
@@ -244,26 +233,53 @@
 		if(p_result->snippet_time == 0){
             return p_result;
         }
+		std::vector<int> index_vector={0};
+		int msg_idx = 0;
 		if(offline_stream->UseVad()){
-			audio.CutSplit(offline_stream);
+			audio.CutSplit(offline_stream, index_vector);
 		}
+		std::vector<string> msgs(index_vector.size());
+		std::vector<float> msg_stimes(index_vector.size());
 
-		float* buff;
-		int len;
-		int flag = 0;
+		float** buff;
+		int* len;
+		int* flag;
+		float* start_time;
+		int batch_size = offline_stream->asr_handle->GetBatchSize();
+		int batch_in = 0;
 
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
-		float start_time = 0.0;
 		std::string cur_stamp = "[";
 		std::string lang = (offline_stream->asr_handle)->GetLang();
-		while (audio.Fetch(buff, len, flag, start_time) > 0) {
+		while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
 			// dec reset
 			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
+			vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
+			for(int idx=0; idx<batch_in; idx++){
+				string msg = msg_batch[idx];
+				if(msg_idx < index_vector.size()){
+					msgs[index_vector[msg_idx]] = msg;
+					msg_stimes[index_vector[msg_idx]] = start_time[idx];
+					msg_idx++;
+				}else{
+					LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
+				}				
+			}
+
+			// release
+			delete[] buff;
+			buff = nullptr;
+			delete[] len;
+			len = nullptr;
+			delete[] flag;
+			flag = nullptr;
+			delete[] start_time;
+			start_time = nullptr;
+		}
+		for(int idx=0; idx<msgs.size(); idx++){
+			string msg = msgs[idx];
 			std::vector<std::string> msg_vec = funasr::split(msg, '|');
 			if(msg_vec.size()==0){
 				continue;
@@ -276,14 +292,11 @@
 			if(msg_vec.size() > 1){
 				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
 				for(int i=0; i<msg_stamp.size()-1; i+=2){
-					float begin = std::stof(msg_stamp[i])+start_time;
-					float end = std::stof(msg_stamp[i+1])+start_time;
+					float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
+					float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
 					cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
 				}
 			}
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
 		if(cur_stamp != "["){
 			cur_stamp.erase(cur_stamp.length() - 1);
@@ -342,25 +355,53 @@
 		if(p_result->snippet_time == 0){
             return p_result;
         }
+		std::vector<int> index_vector={0};
+		int msg_idx = 0;
 		if(offline_stream->UseVad()){
-			audio.CutSplit(offline_stream);
+			audio.CutSplit(offline_stream, index_vector);
 		}
+		std::vector<string> msgs(index_vector.size());
+		std::vector<float> msg_stimes(index_vector.size());
 
-		float* buff;
-		int len;
-		int flag = 0;
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
-		float start_time = 0.0;
+		float** buff;
+		int* len;
+		int* flag;
+		float* start_time;
+		int batch_size = offline_stream->asr_handle->GetBatchSize();
+		int batch_in = 0;
+
 		std::string cur_stamp = "[";
 		std::string lang = (offline_stream->asr_handle)->GetLang();
-		while (audio.Fetch(buff, len, flag, start_time) > 0) {
+		while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
 			// dec reset
 			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
+			vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
+			for(int idx=0; idx<batch_in; idx++){
+				string msg = msg_batch[idx];
+				if(msg_idx < index_vector.size()){
+					msgs[index_vector[msg_idx]] = msg;
+					msg_stimes[index_vector[msg_idx]] = start_time[idx];
+					msg_idx++;
+				}else{
+					LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
+				}				
+			}
+
+			// release
+			delete[] buff;
+			buff = nullptr;
+			delete[] len;
+			len = nullptr;
+			delete[] flag;
+			flag = nullptr;
+			delete[] start_time;
+			start_time = nullptr;
+		}
+		for(int idx=0; idx<msgs.size(); idx++){
+			string msg = msgs[idx];
 			std::vector<std::string> msg_vec = funasr::split(msg, '|');
 			if(msg_vec.size()==0){
 				continue;
@@ -373,15 +414,11 @@
 			if(msg_vec.size() > 1){
 				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
 				for(int i=0; i<msg_stamp.size()-1; i+=2){
-					float begin = std::stof(msg_stamp[i])+start_time;
-					float end = std::stof(msg_stamp[i+1])+start_time;
+					float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
+					float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
 					cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
 				}
 			}
-
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
 		if(cur_stamp != "["){
 			cur_stamp.erase(cur_stamp.length() - 1);
@@ -518,8 +555,14 @@
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
-
+			float** buff;
+			int* len;
+			buff = new float*[1];
+        	len = new int[1];
+			buff[0] = frame->data;
+			len[0] = frame->len;
+			vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
+			string msg = msgs.size()>0?msgs[0]:"";
 			std::vector<std::string> msg_vec = funasr::split(msg, '|');  // split with timestamp
 			if(msg_vec.size()==0){
 				continue;
@@ -767,16 +810,45 @@
 		funasr::WfstDecoder* mm = nullptr;
 		if (asr_type == ASR_OFFLINE) {
 			funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
-			funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
-			if (paraformer->lm_)
-				mm = new funasr::WfstDecoder(paraformer->lm_.get(),
-					paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+			auto paraformer = dynamic_cast<funasr::Paraformer*>(offline_stream->asr_handle.get());
+			if(paraformer !=nullptr){
+				if (paraformer->lm_){
+					mm = new funasr::WfstDecoder(paraformer->lm_.get(),
+						paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+				}
+				return mm;
+			}
+			#ifdef USE_GPU
+			auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
+			if(paraformer_torch !=nullptr){
+				if (paraformer_torch->lm_){
+					mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
+						paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
+				}
+				return mm;
+			}
+			#endif
+
 		} else if (asr_type == ASR_TWO_PASS){
 			funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
-			funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
-			if (paraformer->lm_)
-				mm = new funasr::WfstDecoder(paraformer->lm_.get(), 
-					paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+			auto paraformer = dynamic_cast<funasr::Paraformer*>(tpass_stream->asr_handle.get());
+			if(paraformer !=nullptr){
+				if (paraformer->lm_){
+					mm = new funasr::WfstDecoder(paraformer->lm_.get(),
+						paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+				}
+				return mm;
+			}
+			#ifdef USE_GPU
+			auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(tpass_stream->asr_handle.get());
+			if(paraformer_torch !=nullptr){
+				if (paraformer_torch->lm_){
+					mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
+						paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
+				}
+				return mm;
+			}
+			#endif
 		}
 		return mm;
 	}
diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp
index 7d86f9b..35eb1ba 100644
--- a/runtime/onnxruntime/src/offline-stream.cpp
+++ b/runtime/onnxruntime/src/offline-stream.cpp
@@ -1,7 +1,7 @@
 #include "precomp.h"
 
 namespace funasr {
-OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
+OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
 {
     // VAD model
     if(model_path.find(VAD_DIR) != model_path.end()){
@@ -36,7 +36,19 @@
         string hw_compile_model_path;
         string seg_dict_path;
     
-        asr_handle = make_unique<Paraformer>();
+        if(use_gpu){
+            #ifdef USE_GPU
+            asr_handle = make_unique<ParaformerTorch>();
+            asr_handle->SetBatchSize(batch_size);
+            #else
+            LOG(ERROR) <<"GPU is not supported! CPU will be used! If you want to use GPU, please add -DGPU=ON when cmake";
+            asr_handle = make_unique<Paraformer>();
+            use_gpu = false;
+            #endif
+        }else{
+            asr_handle = make_unique<Paraformer>();
+        }
+
         bool enable_hotword = false;
         hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
         seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT);
@@ -54,6 +66,15 @@
           am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
           if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
             am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
+          }
+          if(use_gpu){
+            am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_MODEL_NAME);
+            if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
+                am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_QUANT_MODEL_NAME);
+            }
+            if(model_path.find(BLADEDISC) != model_path.end() && model_path.at(BLADEDISC) == "true"){
+                am_model_path = PathAppend(model_path.at(MODEL_DIR), BLADE_MODEL_NAME);
+            }
           }
         }
         am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
@@ -120,10 +141,10 @@
 #endif
 }
 
-OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
+OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
 {
     OfflineStream *mm;
-    mm = new OfflineStream(model_path, thread_num);
+    mm = new OfflineStream(model_path, thread_num, use_gpu, batch_size);
     return mm;
 }
 
diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
new file mode 100644
index 0000000..113d43f
--- /dev/null
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -0,0 +1,415 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
+
+#include "precomp.h"
+#include "paraformer-torch.h"
+#include "encode_converter.h"
+#include <cstddef>
+
+using namespace std;
+namespace funasr {
+
+ParaformerTorch::ParaformerTorch()
+:use_hotword(false){
+}
+
+// offline
+void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+    LoadConfigFromYaml(am_config.c_str());
+    // knf options
+    fbank_opts_.frame_opts.dither = 0;
+    fbank_opts_.mel_opts.num_bins = n_mels;
+    fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
+    fbank_opts_.frame_opts.window_type = window_type;
+    fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
+    fbank_opts_.frame_opts.frame_length_ms = frame_length;
+    fbank_opts_.energy_floor = 0;
+    fbank_opts_.mel_opts.debug_mel = false;
+
+    vocab = new Vocab(am_config.c_str());
+	phone_set_ = new PhoneSet(am_config.c_str());
+    LoadCmvn(am_cmvn.c_str());
+
+    torch::DeviceType device = at::kCPU;
+    #ifdef USE_GPU
+    if (!torch::cuda::is_available()) {
+        LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
+        exit(-1);
+    } else {
+        LOG(INFO) << "CUDA is available, running on GPU";
+        device = at::kCUDA;
+    }
+    #endif
+    #ifdef USE_IPEX
+    torch::jit::setTensorExprFuserEnabled(false);
+    #endif
+
+    try {
+        torch::jit::script::Module model = torch::jit::load(am_model, device);
+        model_ = std::make_shared<TorchModule>(std::move(model)); 
+        LOG(INFO) << "Successfully load model from " << am_model;
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load am model: " << am_model << e.what();
+        exit(-1);
+    }
+}
+
+void ParaformerTorch::InitLm(const std::string &lm_file, 
+                        const std::string &lm_cfg_file, 
+                        const std::string &lex_file) {
+    try {
+        lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
+            fst::Fst<fst::StdArc>::Read(lm_file));
+        if (lm_){
+            lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
+            LOG(INFO) << "Successfully load lm file " << lm_file;
+        }else{
+            LOG(ERROR) << "Failed to load lm file " << lm_file;
+        }
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load lm file: " << e.what();
+        exit(0);
+    }
+}
+
+void ParaformerTorch::LoadConfigFromYaml(const char* filename){
+
+    YAML::Node config;
+    try{
+        config = YAML::LoadFile(filename);
+    }catch(exception const &e){
+        LOG(ERROR) << "Error loading file, yaml file error or not exist.";
+        exit(-1);
+    }
+
+    try{
+        YAML::Node frontend_conf = config["frontend_conf"];
+        this->asr_sample_rate = frontend_conf["fs"].as<int>();
+
+        YAML::Node lang_conf = config["lang"];
+        if (lang_conf.IsDefined()){
+            language = lang_conf.as<string>();
+        }
+    }catch(exception const &e){
+        LOG(ERROR) << "Error when load argument from vad config YAML.";
+        exit(-1);
+    }
+}
+
+void ParaformerTorch::InitHwCompiler(const std::string &hw_model, int thread_num) {
+    // TODO
+    use_hotword = true;
+}
+
+void ParaformerTorch::InitSegDict(const std::string &seg_dict_model) {
+    seg_dict = new SegDict(seg_dict_model.c_str());
+}
+
+ParaformerTorch::~ParaformerTorch()
+{
+    if(vocab){
+        delete vocab;
+    }
+    if(lm_vocab){
+        delete lm_vocab;
+    }
+    if(seg_dict){
+        delete seg_dict;
+    }
+    if(phone_set_){
+        delete phone_set_;
+    }
+}
+
+void ParaformerTorch::StartUtterance()
+{
+}
+
+void ParaformerTorch::EndUtterance()
+{
+}
+
+void ParaformerTorch::Reset()
+{
+}
+
+void ParaformerTorch::FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats) {
+    knf::OnlineFbank fbank_(fbank_opts_);
+    std::vector<float> buf(len);
+    for (int32_t i = 0; i != len; ++i) {
+        buf[i] = waves[i] * 32768;
+    }
+    fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
+
+    int32_t frames = fbank_.NumFramesReady();
+    for (int32_t i = 0; i != frames; ++i) {
+        const float *frame = fbank_.GetFrame(i);
+        std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
+        asr_feats.emplace_back(frame_vector);
+    }
+}
+
+void ParaformerTorch::LoadCmvn(const char *filename)
+{
+    ifstream cmvn_stream(filename);
+    if (!cmvn_stream.is_open()) {
+        LOG(ERROR) << "Failed to open file: " << filename;
+        exit(-1);
+    }
+    string line;
+
+    while (getline(cmvn_stream, line)) {
+        istringstream iss(line);
+        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
+        if (line_item[0] == "<AddShift>") {
+            getline(cmvn_stream, line);
+            istringstream means_lines_stream(line);
+            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
+            if (means_lines[0] == "<LearnRateCoef>") {
+                for (int j = 3; j < means_lines.size() - 1; j++) {
+                    means_list_.push_back(stof(means_lines[j]));
+                }
+                continue;
+            }
+        }
+        else if (line_item[0] == "<Rescale>") {
+            getline(cmvn_stream, line);
+            istringstream vars_lines_stream(line);
+            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
+            if (vars_lines[0] == "<LearnRateCoef>") {
+                for (int j = 3; j < vars_lines.size() - 1; j++) {
+                    vars_list_.push_back(stof(vars_lines[j])*scale);
+                }
+                continue;
+            }
+        }
+    }
+}
+
+string ParaformerTorch::GreedySearch(float * in, int n_len,  int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
+{
+    vector<int> hyps;
+    int Tmax = n_len;
+    for (int i = 0; i < Tmax; i++) {
+        int max_idx;
+        float max_val;
+        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
+        hyps.push_back(max_idx);
+    }
+    if(!is_stamp){
+        return vocab->Vector2StringV2(hyps, language);
+    }else{
+        std::vector<string> char_list;
+        std::vector<std::vector<float>> timestamp_list;
+        std::string res_str;
+        vocab->Vector2String(hyps, char_list);
+        std::vector<string> raw_char(char_list);
+        TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
+
+        return PostProcess(raw_char, timestamp_list);
+    }
+}
+
+string ParaformerTorch::BeamSearch(WfstDecoder* &wfst_decoder, float *in, int len, int64_t token_nums)
+{
+  return wfst_decoder->Search(in, len, token_nums);
+}
+
+string ParaformerTorch::FinalizeDecode(WfstDecoder* &wfst_decoder,
+                                  bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
+{
+  return wfst_decoder->FinalizeDecode(is_stamp, us_alphas, us_cif_peak);
+}
+
+void ParaformerTorch::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
+
+    std::vector<std::vector<float>> out_feats;
+    int T = asr_feats.size();
+    int T_lrf = ceil(1.0 * T / lfr_n);
+
+    // Pad frames at start(copy first frame)
+    for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+        asr_feats.insert(asr_feats.begin(), asr_feats[0]);
+    }
+    // Merge lfr_m frames as one,lfr_n frames per window
+    T = T + (lfr_m - 1) / 2;
+    std::vector<float> p;
+    for (int i = 0; i < T_lrf; i++) {
+        if (lfr_m <= T - i * lfr_n) {
+            for (int j = 0; j < lfr_m; j++) {
+                p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
+            }
+            out_feats.emplace_back(p);
+            p.clear();
+        } else {
+            // Fill to lfr_m frames at last window if less than lfr_m frames  (copy last frame)
+            int num_padding = lfr_m - (T - i * lfr_n);
+            for (int j = 0; j < (asr_feats.size() - i * lfr_n); j++) {
+                p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
+            }
+            for (int j = 0; j < num_padding; j++) {
+                p.insert(p.end(), asr_feats[asr_feats.size() - 1].begin(), asr_feats[asr_feats.size() - 1].end());
+            }
+            out_feats.emplace_back(p);
+            p.clear();
+        }
+    }
+    // Apply cmvn
+    for (auto &out_feat: out_feats) {
+        for (int j = 0; j < means_list_.size(); j++) {
+            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
+        }
+    }
+    asr_feats = out_feats;
+}
+
+std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
+{
+    WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
+    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
+    int32_t feature_dim = lfr_m*in_feat_dim;
+
+    std::vector<vector<float>> feats_batch;
+    std::vector<int32_t> paraformer_length;
+    int max_size = 0;
+    int max_frames = 0;
+    for(int index=0; index<batch_in; index++){
+        std::vector<std::vector<float>> asr_feats;
+        FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
+        if(asr_feats.size() != 0){
+            LfrCmvn(asr_feats);
+        }
+        int32_t num_frames  = asr_feats.size();
+        paraformer_length.emplace_back(num_frames);
+        if(max_size < asr_feats.size()*feature_dim){
+            max_size = asr_feats.size()*feature_dim;
+            max_frames = num_frames;
+        }
+
+        std::vector<float> flattened;
+        for (const auto& sub_vector : asr_feats) {
+            flattened.insert(flattened.end(), sub_vector.begin(), sub_vector.end());
+        }
+        feats_batch.emplace_back(flattened);
+    }
+
+    torch::NoGradGuard no_grad;
+    model_->eval();
+    // padding
+    std::vector<float> all_feats(batch_in * max_frames * feature_dim);
+    for(int index=0; index<batch_in; index++){
+        feats_batch[index].resize(max_size);
+        std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
+                        max_frames * feature_dim * sizeof(float));
+    }
+    torch::Tensor feats =
+        torch::from_blob(all_feats.data(),
+                {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
+    torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
+                        {batch_in}, torch::kInt32);
+
+    // 2. forward
+    #ifdef USE_GPU
+    feats = feats.to(at::kCUDA);
+    feat_lens = feat_lens.to(at::kCUDA);
+    #endif
+    std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
+
+    vector<std::string> results;
+    try {
+        auto outputs = model_->forward(inputs).toTuple()->elements();
+        torch::Tensor am_scores;
+        torch::Tensor valid_token_lens;
+        #ifdef USE_GPU
+        am_scores = outputs[0].toTensor().to(at::kCPU);
+        valid_token_lens = outputs[1].toTensor().to(at::kCPU);
+        #else
+        am_scores = outputs[0].toTensor();
+        valid_token_lens = outputs[1].toTensor();
+        #endif
+        // timestamp
+        for(int index=0; index<batch_in; index++){
+            string result="";
+            if(outputs.size() == 4){
+                torch::Tensor us_alphas_tensor;
+                torch::Tensor us_peaks_tensor;
+                #ifdef USE_GPU
+                us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
+                us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
+                #else
+                us_alphas_tensor = outputs[2].toTensor();
+                us_peaks_tensor = outputs[3].toTensor();
+                #endif
+
+                float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
+                std::vector<float> us_alphas(paraformer_length[index]);
+                for (int i = 0; i < us_alphas.size(); i++) {
+                    us_alphas[i] = us_alphas_data[i];
+                }
+
+                float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
+                std::vector<float> us_peaks(paraformer_length[index]);
+                for (int i = 0; i < us_peaks.size(); i++) {
+                    us_peaks[i] = us_peaks_data[i];
+                }
+                if (lm_ == nullptr) {
+                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
+                } else {
+                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+                    if (input_finished) {
+                        result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
+                    }
+                }
+            }else{
+                if (lm_ == nullptr) {
+                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+                } else {
+                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+                    if (input_finished) {
+                        result = FinalizeDecode(wfst_decoder);
+                    }
+                }
+            }
+            results.push_back(result);
+			if (wfst_decoder){
+				wfst_decoder->StartUtterance();
+			}
+        }
+    }
+    catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+    }
+
+    return results;
+}
+
+std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
+    // TODO
+    std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
+    return result;
+}
+
+Vocab* ParaformerTorch::GetVocab()
+{
+    return vocab;
+}
+
+Vocab* ParaformerTorch::GetLmVocab()
+{
+    return lm_vocab;
+}
+
+PhoneSet* ParaformerTorch::GetPhoneSet()
+{
+    return phone_set_;
+}
+
+string ParaformerTorch::Rescoring()
+{
+    LOG(ERROR)<<"Not Imp!!!!!!";
+    return "";
+}
+} // namespace funasr
diff --git a/runtime/onnxruntime/src/paraformer-torch.h b/runtime/onnxruntime/src/paraformer-torch.h
new file mode 100644
index 0000000..e49094d
--- /dev/null
+++ b/runtime/onnxruntime/src/paraformer-torch.h
@@ -0,0 +1,96 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
+#pragma once
+#define C10_USE_GLOG
+#include <torch/serialize.h>
+#include <torch/script.h>
+#include <torch/torch.h>
+#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
+#include "precomp.h"
+#include "fst/fstlib.h"
+#include "fst/symbol-table.h"
+#include "bias-lm.h"
+#include "phone-set.h"
+
+namespace funasr {
+
+    class ParaformerTorch : public Model {
+    /**
+     * Author: Speech Lab of DAMO Academy, Alibaba Group
+     * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+     * https://arxiv.org/pdf/2206.08317.pdf
+    */
+    private:
+        Vocab* vocab = nullptr;
+        Vocab* lm_vocab = nullptr;
+        SegDict* seg_dict = nullptr;
+        PhoneSet* phone_set_ = nullptr;
+        //const float scale = 22.6274169979695;
+        const float scale = 1.0;
+
+        void LoadConfigFromYaml(const char* filename);
+        void LoadCmvn(const char *filename);
+        void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
+
+        using TorchModule = torch::jit::script::Module;
+        std::shared_ptr<TorchModule> model_ = nullptr;
+        std::vector<torch::Tensor> encoder_outs_;
+        bool use_hotword;
+
+    public:
+        ParaformerTorch();
+        ~ParaformerTorch();
+        void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+        void InitHwCompiler(const std::string &hw_model, int thread_num);
+        void InitSegDict(const std::string &seg_dict_model);
+        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
+        void Reset();
+        void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
+        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
+        string GreedySearch( float* in, int n_len, int64_t token_nums,
+                             bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
+
+        string Rescoring();
+        string GetLang(){return language;};
+        int GetAsrSampleRate() { return asr_sample_rate; };
+        void SetBatchSize(int batch_size) {batch_size_ = batch_size;};
+        int GetBatchSize() {return batch_size_;};
+        void StartUtterance();
+        void EndUtterance();
+        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
+        string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
+        string FinalizeDecode(WfstDecoder* &wfst_decoder,
+                          bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
+        Vocab* GetVocab();
+        Vocab* GetLmVocab();
+        PhoneSet* GetPhoneSet();
+		
+        knf::FbankOptions fbank_opts_;
+        vector<float> means_list_;
+        vector<float> vars_list_;
+        int lfr_m = PARA_LFR_M;
+        int lfr_n = PARA_LFR_N;
+
+        // paraformer-offline
+        std::string language="zh-cn";
+
+        // lm
+        std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
+
+        string window_type = "hamming";
+        int frame_length = 25;
+        int frame_shift = 10;
+        int n_mels = 80;
+        int encoder_size = 512;
+        int fsmn_layers = 16;
+        int fsmn_lorder = 10;
+        int fsmn_dims = 512;
+        float cif_threshold = 1.0;
+        float tail_alphas = 0.45;
+        int asr_sample_rate = MODEL_SAMPLE_RATE;
+        int batch_size_ = 1;
+    };
+
+} // namespace funasr
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index a57fb9b..1f1d48f 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -462,15 +462,23 @@
     asr_feats = out_feats;
 }
 
-string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
+std::vector<std::string> Paraformer::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
 {
+    std::vector<std::string> results;
+    string result="";
     WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
     int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
 
+    if(batch_in != 1){
+        results.push_back(result);
+        return results;
+    }
+
     std::vector<std::vector<float>> asr_feats;
-    FbankKaldi(asr_sample_rate, din, len, asr_feats);
+    FbankKaldi(asr_sample_rate, din[0], len[0], asr_feats);
     if(asr_feats.size() == 0){
-      return "";
+        results.push_back(result);
+        return results;
     }
     LfrCmvn(asr_feats);
     int32_t feat_dim = lfr_m*in_feat_dim;
@@ -509,7 +517,8 @@
         if (use_hotword) {
             if(hw_emb.size()<=0){
                 LOG(ERROR) << "hw_emb is null";
-                return "";
+                results.push_back(result);
+                return results;
             }
             //PrintMat(hw_emb, "input_clas_emb");
             const int64_t hotword_shape[3] = {1, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())};
@@ -526,10 +535,10 @@
     }catch (std::exception const &e)
     {
         LOG(ERROR)<<e.what();
-        return "";
+        results.push_back(result);
+        return results;
     }
 
-    string result="";
     try {
         auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
         std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
@@ -577,7 +586,8 @@
         LOG(ERROR)<<e.what();
     }
 
-    return result;
+    results.push_back(result);
+    return results;
 }
 
 
diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h
index 417c2d7..571b2ba 100644
--- a/runtime/onnxruntime/src/paraformer.h
+++ b/runtime/onnxruntime/src/paraformer.h
@@ -52,13 +52,14 @@
         std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
         void Reset();
         void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
-        string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
+        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
         string GreedySearch( float* in, int n_len, int64_t token_nums,
                              bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
 
         string Rescoring();
         string GetLang(){return language;};
         int GetAsrSampleRate() { return asr_sample_rate; };
+        int GetBatchSize() {return batch_size_;};
         void StartUtterance();
         void EndUtterance();
         void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@@ -110,6 +111,7 @@
         float cif_threshold = 1.0;
         float tail_alphas = 0.45;
         int asr_sample_rate = MODEL_SAMPLE_RATE;
+        int batch_size_ = 1;
     };
 
 } // namespace funasr
diff --git a/runtime/onnxruntime/src/precomp.h b/runtime/onnxruntime/src/precomp.h
index 776de8e..1a98852 100644
--- a/runtime/onnxruntime/src/precomp.h
+++ b/runtime/onnxruntime/src/precomp.h
@@ -64,6 +64,9 @@
 #include "seg_dict.h"
 #include "resample.h"
 #include "paraformer.h"
+#ifdef USE_GPU
+#include "paraformer-torch.h"
+#endif
 #include "paraformer-online.h"
 #include "offline-stream.h"
 #include "tpass-stream.h"
diff --git a/runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp b/runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp
index cf00e94..e1133d6 100644
--- a/runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp
+++ b/runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp
@@ -70,13 +70,13 @@
   return os;
 }
 
-
+#ifndef USE_GPU
 template<class T1, class T2>
 ostream& operator << (ostream& os, const pair<T1, T2>& pr) {
   os << pr.first << ":" << pr.second ;
   return os;
 }
-
+#endif
 
 template<class T>
 string& operator << (string& str, const T& obj) {
diff --git a/runtime/run_server.sh b/runtime/run_server.sh
index 22907ab..2cb577a 100644
--- a/runtime/run_server.sh
+++ b/runtime/run_server.sh
@@ -1,6 +1,10 @@
 
+TORCH_DIR=$(python3 -c "import torch; import os; print(os.path.dirname(torch.__file__))")
+BLADE_DIR=$(python3 -c "import torch_blade; import os; print(os.path.dirname(torch_blade.__file__))")
+export LD_LIBRARY_PATH=/usr/local/lib:${TORCH_DIR}/lib:${BLADE_DIR}:${LD_LIBRARY_PATH}
+
 download_model_dir="/workspace/models"
-model_dir="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx"
+model_dir="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx"
 vad_dir="damo/speech_fsmn_vad_zh-cn-16k-common-onnx"
 punc_dir="damo/punc_ct-transformer_cn-en-common-vocab471067-large-onnx"
 itn_dir="thuduj12/fst_itn_zh"
@@ -38,5 +42,6 @@
   --port ${port} \
   --certfile  "${certfile}" \
   --keyfile "${keyfile}" \
+  --gpu \
   --hotword "${hotword}" &
 
diff --git a/runtime/websocket/CMakeLists.txt b/runtime/websocket/CMakeLists.txt
index ba6497a..da0e6e7 100644
--- a/runtime/websocket/CMakeLists.txt
+++ b/runtime/websocket/CMakeLists.txt
@@ -8,6 +8,10 @@
 
 option(ENABLE_WEBSOCKET "Whether to build websocket server" ON)
 option(ENABLE_PORTAUDIO "Whether to build portaudio" ON)
+option(ENABLE_GLOG "Whether to build glog" ON)
+option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
+option(BUILD_SHARED_LIBS "Build shared libraries" ON)
+option(GPU "Whether to build with GPU" OFF)
 
 if(WIN32)
   file(REMOVE ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/config.h 
@@ -20,12 +24,16 @@
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")
 endif()
 
-
-
-
-option(ENABLE_GLOG "Whether to build glog" ON)
-option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
-option(BUILD_SHARED_LIBS "Build shared libraries" ON)
+if(GPU)
+    add_definitions(-DUSE_GPU)
+    set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
+    set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
+    include_directories(${TORCH_DIR}/include)
+    include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
+    link_directories(${TORCH_DIR}/lib)
+    link_directories(${TORCH_BLADE_DIR})
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
+endif()
  
 if(ENABLE_WEBSOCKET)
   # cmake_policy(SET CMP0135 NEW)
diff --git a/runtime/websocket/bin/CMakeLists.txt b/runtime/websocket/bin/CMakeLists.txt
index 3d2a6cf..8df8ce0 100644
--- a/runtime/websocket/bin/CMakeLists.txt
+++ b/runtime/websocket/bin/CMakeLists.txt
@@ -1,5 +1,4 @@
 
-
 if(WIN32)
   include_directories(${ONNXRUNTIME_DIR}/include)
   include_directories(${FFMPEG_DIR}/include)
@@ -12,15 +11,14 @@
   SET(RELATION_SOURCE "../../onnxruntime/src/resample.cpp" "../../onnxruntime/src/util.cpp" "../../onnxruntime/src/alignedmem.cpp" "../../onnxruntime/src/encode_converter.cpp")
 endif()
 
-
-
-
-
 add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp" ${RELATION_SOURCE})
 add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp" ${RELATION_SOURCE})
 add_executable(funasr-wss-client "funasr-wss-client.cpp" ${RELATION_SOURCE})
 add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp" "microphone.cpp" ${RELATION_SOURCE})
 
+target_link_options(funasr-wss-server PRIVATE "-Wl,--no-as-needed")
+target_link_options(funasr-wss-server-2pass PRIVATE "-Wl,--no-as-needed")
+
 target_link_libraries(funasr-wss-client PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
 target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY} portaudio)
 target_link_libraries(funasr-wss-server PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
diff --git a/runtime/websocket/bin/funasr-wss-server.cpp b/runtime/websocket/bin/funasr-wss-server.cpp
index 5bb7def..100cf35 100644
--- a/runtime/websocket/bin/funasr-wss-server.cpp
+++ b/runtime/websocket/bin/funasr-wss-server.cpp
@@ -56,6 +56,10 @@
         "true (Default), load the model of model_quant.onnx in model_dir. If set "
         "false, load the model of model.onnx in model_dir",
         false, "true", "string");
+    TCLAP::ValueArg<std::string> bladedisc(
+        "", BLADEDISC, 
+        "true (Default), load the model of bladedisc in model_dir.", 
+        false, "true", "string");
     TCLAP::ValueArg<std::string> vad_dir(
         "", VAD_DIR,
         "default: /workspace/models/vad, the vad model path, which contains model_quant.onnx, vad.yaml, vad.mvn",
@@ -121,6 +125,8 @@
         false, "/workspace/resources/hotwords.txt", "string");
     TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, 
         "the fst hotwords incremental bias", false, 20, "int32_t");
+    TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU, default is false", false);
+    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
 
     // add file
     cmd.add(hotword);
@@ -135,6 +141,7 @@
     cmd.add(model_dir);
     cmd.add(model_revision);
     cmd.add(quantize);
+    cmd.add(bladedisc);
     cmd.add(vad_dir);
     cmd.add(vad_revision);
     cmd.add(vad_quant);
@@ -151,11 +158,14 @@
     cmd.add(io_thread_num);
     cmd.add(decoder_thread_num);
     cmd.add(model_thread_num);
+    cmd.add(use_gpu);
+    cmd.add(batch_size);
     cmd.parse(argc, argv);
 
     std::map<std::string, std::string> model_path;
     GetValue(model_dir, MODEL_DIR, model_path);
     GetValue(quantize, QUANTIZE, model_path);
+    GetValue(bladedisc, BLADEDISC, model_path);
     GetValue(vad_dir, VAD_DIR, model_path);
     GetValue(vad_quant, VAD_QUANT, model_path);
     GetValue(punc_dir, PUNC_DIR, model_path);
@@ -173,6 +183,8 @@
     global_beam_ = global_beam.getValue();
     lattice_beam_ = lattice_beam.getValue();
     am_scale_ = am_scale.getValue();
+    bool use_gpu_ = use_gpu.getValue();
+    int batch_size_ = batch_size.getValue();
 
     // Download model form Modelscope
     try{
@@ -468,7 +480,7 @@
     WebSocketServer websocket_srv(
         io_decoder, is_ssl, server, wss_server, s_certfile,
         s_keyfile);  // websocket server for asr engine
-    websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
+    websocket_srv.initAsr(model_path, s_model_thread_num, use_gpu_, batch_size_);  // init asr model
 
     LOG(INFO) << "decoder-thread-num: " << s_decoder_thread_num;
     LOG(INFO) << "io-thread-num: " << s_io_thread_num;
diff --git a/runtime/websocket/bin/websocket-server.cpp b/runtime/websocket/bin/websocket-server.cpp
index ed25c95..49d8ead 100644
--- a/runtime/websocket/bin/websocket-server.cpp
+++ b/runtime/websocket/bin/websocket-server.cpp
@@ -402,11 +402,11 @@
 
 // init asr model
 void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
-                              int thread_num) {
+                              int thread_num, bool use_gpu, int batch_size) {
   try {
     // init model with api
 
-    asr_handle = FunOfflineInit(model_path, thread_num);
+    asr_handle = FunOfflineInit(model_path, thread_num, use_gpu, batch_size);
     LOG(INFO) << "model successfully inited";
     
     LOG(INFO) << "initAsr run check_and_clean_connection";
diff --git a/runtime/websocket/bin/websocket-server.h b/runtime/websocket/bin/websocket-server.h
index d18bcab..c1389bf 100644
--- a/runtime/websocket/bin/websocket-server.h
+++ b/runtime/websocket/bin/websocket-server.h
@@ -124,7 +124,7 @@
                   std::string wav_format,
                   FUNASR_DEC_HANDLE& decoder_handle);
 
-  void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
+  void initAsr(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
   void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
   void on_open(websocketpp::connection_hdl hdl);
   void on_close(websocketpp::connection_hdl hdl);

--
Gitblit v1.9.1