From ade08818b7a579aac75182b906a5bd3b8126411c Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Mon, 27 May 2024 15:46:26 +0800
Subject: [PATCH] Merge branch 'dev_batch' into main
---
runtime/onnxruntime/src/precomp.h | 3
runtime/onnxruntime/src/paraformer.h | 4
runtime/websocket/bin/websocket-server.cpp | 4
runtime/websocket/bin/websocket-server.h | 2
runtime/onnxruntime/include/com-define.h | 9
runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp | 4
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 15
runtime/run_server.sh | 7
runtime/onnxruntime/src/audio.cpp | 106 +++++
runtime/onnxruntime/src/offline-stream.cpp | 29 +
runtime/onnxruntime/bin/CMakeLists.txt | 10
runtime/onnxruntime/src/paraformer-torch.h | 96 +++++
runtime/onnxruntime/src/paraformer.cpp | 24
runtime/onnxruntime/bin/funasr-onnx-offline.cpp | 40 ++
runtime/onnxruntime/CMakeLists.txt | 12
runtime/onnxruntime/src/funasrruntime.cpp | 178 ++++++--
runtime/onnxruntime/include/funasrruntime.h | 2
runtime/websocket/bin/funasr-wss-server.cpp | 14
runtime/onnxruntime/src/CMakeLists.txt | 11
runtime/websocket/CMakeLists.txt | 20
runtime/onnxruntime/include/model.h | 12
runtime/onnxruntime/include/audio.h | 4
runtime/onnxruntime/include/offline-stream.h | 4
runtime/onnxruntime/src/paraformer-torch.cpp | 415 +++++++++++++++++++++++
runtime/websocket/bin/CMakeLists.txt | 8
25 files changed, 936 insertions(+), 97 deletions(-)
diff --git a/runtime/onnxruntime/CMakeLists.txt b/runtime/onnxruntime/CMakeLists.txt
index d8e623e..cf10f54 100644
--- a/runtime/onnxruntime/CMakeLists.txt
+++ b/runtime/onnxruntime/CMakeLists.txt
@@ -4,6 +4,7 @@
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
+option(GPU "Whether to build with GPU" OFF)
# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
@@ -49,6 +50,17 @@
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi)
include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
+if(GPU)
+ add_definitions(-DUSE_GPU)
+ set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
+ set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
+ include_directories(${TORCH_DIR}/include)
+ include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
+ link_directories(${TORCH_DIR}/lib)
+ link_directories(${TORCH_BLADE_DIR})
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
+endif()
+
if(ENABLE_GLOG)
include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/src)
set(BUILD_TESTING OFF)
diff --git a/runtime/onnxruntime/bin/CMakeLists.txt b/runtime/onnxruntime/bin/CMakeLists.txt
index c91fbc4..0ca7f1e 100644
--- a/runtime/onnxruntime/bin/CMakeLists.txt
+++ b/runtime/onnxruntime/bin/CMakeLists.txt
@@ -10,33 +10,43 @@
endif()
add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-offline PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline PUBLIC funasr)
add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-offline-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-online-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-online-asr PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-offline-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-online-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-offline-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-2pass PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass PUBLIC funasr)
add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-2pass-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)
add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-onnx-online-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)
# include_directories(${FFMPEG_DIR}/include)
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index 5722693..c252bc7 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -52,7 +52,7 @@
std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_handle, nn_hotwords_);
// warm up
- for (size_t i = 0; i < 1; i++)
+ for (size_t i = 0; i < 10; i++)
{
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs, true, decoder_handle);
if(result){
@@ -127,6 +127,7 @@
TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
+ TCLAP::ValueArg<std::string> bladedisc("", BLADEDISC, "true (Default), load the model of bladedisc in model_dir.", false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@@ -140,11 +141,14 @@
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
- TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
+ TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", false, 1, "int32_t");
TCLAP::ValueArg<std::string> hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
+ TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU for inference, default is false", false);
+ TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
cmd.add(model_dir);
cmd.add(quantize);
+ cmd.add(bladedisc);
cmd.add(vad_dir);
cmd.add(vad_quant);
cmd.add(punc_dir);
@@ -159,11 +163,14 @@
cmd.add(wav_path);
cmd.add(audio_fs);
cmd.add(thread_num);
+ cmd.add(use_gpu);
+ cmd.add(batch_size);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
+ GetValue(bladedisc, BLADEDISC, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
@@ -175,7 +182,9 @@
struct timeval start, end;
gettimeofday(&start, nullptr);
- FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1);
+ bool use_gpu_ = use_gpu.getValue();
+ int batch_size_ = batch_size.getValue();
+ FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1, use_gpu_, batch_size_);
if (!asr_handle)
{
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
index edb83bd..764c581 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -19,6 +19,7 @@
#include "com-define.h"
#include <unordered_map>
#include "util.h"
+#include "audio.h"
using namespace std;
bool is_target_file(const std::string& filename, const std::string target) {
@@ -44,6 +45,7 @@
TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
+ TCLAP::ValueArg<std::string> bladedisc("", BLADEDISC, "true (Default), load the model of bladedisc in model_dir.", false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@@ -57,9 +59,12 @@
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
TCLAP::ValueArg<std::string> hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
+ TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU for inference, default is false", false);
+ TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
cmd.add(model_dir);
cmd.add(quantize);
+ cmd.add(bladedisc);
cmd.add(vad_dir);
cmd.add(vad_quant);
cmd.add(punc_dir);
@@ -73,11 +78,14 @@
cmd.add(wav_path);
cmd.add(audio_fs);
cmd.add(hotword);
+ cmd.add(use_gpu);
+ cmd.add(batch_size);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
+ GetValue(bladedisc, BLADEDISC, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
@@ -89,7 +97,9 @@
struct timeval start, end;
gettimeofday(&start, nullptr);
int thread_num = 1;
- FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num);
+ bool use_gpu_ = use_gpu.getValue();
+ int batch_size_ = batch_size.getValue();
+ FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num, use_gpu_, batch_size_);
if (!asr_hanlde)
{
@@ -156,8 +166,34 @@
for (int i = 0; i < wav_list.size(); i++) {
auto& wav_file = wav_list[i];
auto& wav_id = wav_ids[i];
+
+ // For debug:begin
+ int32_t sampling_rate_ = audio_fs.getValue();
+ funasr::Audio audio(1);
+ if(is_target_file(wav_file.c_str(), "wav")){
+ if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
+ LOG(ERROR)<<"Failed to load "<< wav_file;
+ exit(-1);
+ }
+ }else if(is_target_file(wav_file.c_str(), "pcm")){
+ if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
+ LOG(ERROR)<<"Failed to load "<< wav_file;
+ exit(-1);
+ }
+ }else{
+ if (!audio.FfmpegLoad(wav_file.c_str(), true)){
+ LOG(ERROR)<<"Failed to load "<< wav_file;
+ exit(-1);
+ }
+ }
+ char* speech_buff = audio.GetSpeechChar();
+ int buff_len = audio.GetSpeechLen()*2;
+
gettimeofday(&start, nullptr);
- FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
+ FUNASR_RESULT result=FunOfflineInferBuffer(asr_hanlde, speech_buff, buff_len, RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), "pcm", true, decoder_handle);
+ // For debug:end
+
+ // FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
gettimeofday(&end, nullptr);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
diff --git a/runtime/onnxruntime/include/audio.h b/runtime/onnxruntime/include/audio.h
index 9edd9c9..3011050 100644
--- a/runtime/onnxruntime/include/audio.h
+++ b/runtime/onnxruntime/include/audio.h
@@ -83,9 +83,11 @@
int FetchTpass(AudioFrame *&frame);
int Fetch(float *&dout, int &len, int &flag);
int Fetch(float *&dout, int &len, int &flag, float &start_time);
+ int Fetch(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
+ int FetchDynamic(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
void Padding();
void Split(OfflineStream* offline_streamj);
- void CutSplit(OfflineStream* offline_streamj);
+ void CutSplit(OfflineStream* offline_streamj, std::vector<int> &index_vector);
void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
float GetTimeLen();
diff --git a/runtime/onnxruntime/include/com-define.h b/runtime/onnxruntime/include/com-define.h
index d4edd5b..77a7b02 100644
--- a/runtime/onnxruntime/include/com-define.h
+++ b/runtime/onnxruntime/include/com-define.h
@@ -51,6 +51,15 @@
#define QUANT_MODEL_NAME "model_quant.onnx"
#define VAD_CMVN_NAME "am.mvn"
#define VAD_CONFIG_NAME "config.yaml"
+
+// gpu models
+#define INFER_GPU "gpu"
+#define BATCHSIZE "batch-size"
+#define TORCH_MODEL_NAME "model.torchscripts"
+#define TORCH_QUANT_MODEL_NAME "model_quant.torchscripts"
+#define BLADE_MODEL_NAME "model.blade.fp16.pt"
+#define BLADEDISC "bladedisc"
+
#define AM_CMVN_NAME "am.mvn"
#define AM_CONFIG_NAME "config.yaml"
#define LM_CONFIG_NAME "config.yaml"
diff --git a/runtime/onnxruntime/include/funasrruntime.h b/runtime/onnxruntime/include/funasrruntime.h
index cff617f..cc9ba3d 100644
--- a/runtime/onnxruntime/include/funasrruntime.h
+++ b/runtime/onnxruntime/include/funasrruntime.h
@@ -96,7 +96,7 @@
_FUNASRAPI void CTTransformerUninit(FUNASR_HANDLE handle);
//OfflineStream
-_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
_FUNASRAPI void FunOfflineReset(FUNASR_HANDLE handle, FUNASR_DEC_HANDLE dec_handle=nullptr);
// buffer
_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len,
diff --git a/runtime/onnxruntime/include/model.h b/runtime/onnxruntime/include/model.h
index f5c4027..1064c4c 100644
--- a/runtime/onnxruntime/include/model.h
+++ b/runtime/onnxruntime/include/model.h
@@ -5,6 +5,10 @@
#include <string>
#include <map>
#include "funasrruntime.h"
+#include "vocab.h"
+#include "phone-set.h"
+#include "fst/fstlib.h"
+#include "fst/symbol-table.h"
namespace funasr {
class Model {
public:
@@ -18,13 +22,19 @@
virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
virtual void InitFstDecoder(){};
virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
+ virtual std::vector<std::string> Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1)
+ {return std::vector<string>();};
virtual std::string Rescoring() = 0;
virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
virtual void InitSegDict(const std::string &seg_dict_model){};
virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
virtual std::string GetLang(){return "";};
virtual int GetAsrSampleRate() = 0;
-
+ virtual void SetBatchSize(int batch_size) {};
+ virtual int GetBatchSize() {return 0;};
+ virtual Vocab* GetVocab() {return nullptr;};
+ virtual Vocab* GetLmVocab() {return nullptr;};
+ virtual PhoneSet* GetPhoneSet() {return nullptr;};
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
diff --git a/runtime/onnxruntime/include/offline-stream.h b/runtime/onnxruntime/include/offline-stream.h
index f63de74..cc0f1c4 100644
--- a/runtime/onnxruntime/include/offline-stream.h
+++ b/runtime/onnxruntime/include/offline-stream.h
@@ -14,7 +14,7 @@
namespace funasr {
class OfflineStream {
public:
- OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
+ OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
~OfflineStream(){};
std::unique_ptr<VadModel> vad_handle= nullptr;
@@ -33,6 +33,6 @@
bool use_itn=false;
};
-OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1);
+OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1, bool use_gpu=false, int batch_size=1);
} // namespace funasr
#endif
diff --git a/runtime/onnxruntime/src/CMakeLists.txt b/runtime/onnxruntime/src/CMakeLists.txt
index 9eac2b6..6d6b66b 100644
--- a/runtime/onnxruntime/src/CMakeLists.txt
+++ b/runtime/onnxruntime/src/CMakeLists.txt
@@ -1,6 +1,11 @@
file(GLOB files1 "*.cpp")
+list(REMOVE_ITEM files1 "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
set(files ${files1})
+
+if(GPU)
+ set(files ${files} "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
+endif()
message("files: "${files})
@@ -25,7 +30,11 @@
include_directories(${FFMPEG_DIR}/include)
endif()
+if(GPU)
+ set(TORCH_DEPS torch torch_cuda torch_cpu c10 c10_cuda torch_blade ral_base_context)
+endif()
+
#message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
include_directories(${CMAKE_SOURCE_DIR}/include)
include_directories(${CMAKE_SOURCE_DIR}/third_party)
-target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
+target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS} ${TORCH_DEPS})
diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp
index 0135ab4..22a9ecd 100644
--- a/runtime/onnxruntime/src/audio.cpp
+++ b/runtime/onnxruntime/src/audio.cpp
@@ -1023,6 +1023,90 @@
}
}
+int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+ batch_in = std::min((int)frame_queue.size(), batch_size);
+ if (batch_in == 0){
+ return 0;
+ } else{
+ // init
+ dout = new float*[batch_in];
+ len = new int[batch_in];
+ flag = new int[batch_in];
+ start_time = new float[batch_in];
+
+ for(int idx=0; idx < batch_in; idx++){
+ AudioFrame *frame = frame_queue.front();
+ frame_queue.pop();
+
+ start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+ dout[idx] = speech_data + frame->GetStart();
+ len[idx] = frame->GetLen();
+ delete frame;
+ flag[idx] = S_END;
+ }
+ return 1;
+ }
+}
+
+int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+ //compute batch size
+ queue<AudioFrame *> frame_batch;
+ int max_acc = 300*1000*seg_sample;
+ int max_sent = 60*1000*seg_sample;
+ int bs_acc = 0;
+ int max_len = 0;
+ int max_batch = 1;
+ #ifdef USE_GPU
+ max_batch = batch_size;
+ #endif
+ max_batch = std::min(max_batch, (int)frame_queue.size());
+
+ for(int idx=0; idx < max_batch; idx++){
+ AudioFrame *frame = frame_queue.front();
+ int length = frame->GetLen();
+ if(length >= max_sent){
+ if(bs_acc == 0){
+ bs_acc++;
+ frame_batch.push(frame);
+ frame_queue.pop();
+ }
+ break;
+ }
+ max_len = std::max(max_len, frame->GetLen());
+ if(max_len*(bs_acc+1) > max_acc){
+ break;
+ }
+ bs_acc++;
+ frame_batch.push(frame);
+ frame_queue.pop();
+ }
+
+ batch_in = (int)frame_batch.size();
+ if (batch_in == 0){
+ return 0;
+ } else{
+ // init
+ dout = new float*[batch_in];
+ len = new int[batch_in];
+ flag = new int[batch_in];
+ start_time = new float[batch_in];
+
+ for(int idx=0; idx < batch_in; idx++){
+ AudioFrame *frame = frame_batch.front();
+ frame_batch.pop();
+
+ start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+ dout[idx] = speech_data + frame->GetStart();
+ len[idx] = frame->GetLen();
+ delete frame;
+ flag[idx] = S_END;
+ }
+ return 1;
+ }
+}
+
void Audio::Padding()
{
float num_samples = speech_len;
@@ -1085,7 +1169,7 @@
}
}
-void Audio::CutSplit(OfflineStream* offline_stream)
+void Audio::CutSplit(OfflineStream* offline_stream, std::vector<int> &index_vector)
{
std::unique_ptr<VadModel> vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(offline_stream->vad_handle).get());
AudioFrame *frame;
@@ -1112,6 +1196,7 @@
}
int speech_start_i = -1, speech_end_i =-1;
+ std::vector<AudioFrame*> vad_frames;
for(vector<int> vad_segment:vad_segments)
{
if(vad_segment.size() != 2){
@@ -1126,16 +1211,31 @@
}
if(speech_start_i!=-1 && speech_end_i!=-1){
- frame = new AudioFrame();
int start = speech_start_i*seg_sample;
int end = speech_end_i*seg_sample;
+ frame = new AudioFrame(end-start);
frame->SetStart(start);
frame->SetEnd(end);
- frame_queue.push(frame);
+ vad_frames.push_back(frame);
frame = nullptr;
speech_start_i=-1;
speech_end_i=-1;
}
+
+ }
+ // sort
+ {
+ index_vector.clear();
+ index_vector.resize(vad_frames.size());
+ for (int i = 0; i < index_vector.size(); ++i) {
+ index_vector[i] = i;
+ }
+ std::sort(index_vector.begin(), index_vector.end(), [&vad_frames](const int a, const int b) {
+ return vad_frames[a]->len < vad_frames[b]->len;
+ });
+ for (int idx : index_vector) {
+ frame_queue.push(vad_frames[idx]);
+ }
}
}
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 4bc64af..d235e6f 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -33,9 +33,9 @@
return mm;
}
- _FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
+ _FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
- funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num);
+ funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu, batch_size);
return mm;
}
@@ -74,16 +74,11 @@
if(p_result->snippet_time == 0){
return p_result;
}
- int n_step = 0;
- int n_total = audio.GetQueueSize();
+
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, input_finished);
p_result->msg += msg;
- n_step++;
- if (fn_callback)
- fn_callback(n_step, n_total);
}
-
return p_result;
}
@@ -109,8 +104,6 @@
float* buff;
int len;
int flag = 0;
- int n_step = 0;
- int n_total = audio.GetQueueSize();
funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
if(p_result->snippet_time == 0){
@@ -119,11 +112,7 @@
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, true);
p_result->msg += msg;
- n_step++;
- if (fn_callback)
- fn_callback(n_step, n_total);
}
-
return p_result;
}
@@ -244,26 +233,53 @@
if(p_result->snippet_time == 0){
return p_result;
}
+ std::vector<int> index_vector={0};
+ int msg_idx = 0;
if(offline_stream->UseVad()){
- audio.CutSplit(offline_stream);
+ audio.CutSplit(offline_stream, index_vector);
}
+ std::vector<string> msgs(index_vector.size());
+ std::vector<float> msg_stimes(index_vector.size());
- float* buff;
- int len;
- int flag = 0;
+ float** buff;
+ int* len;
+ int* flag;
+ float* start_time;
+ int batch_size = offline_stream->asr_handle->GetBatchSize();
+ int batch_in = 0;
- int n_step = 0;
- int n_total = audio.GetQueueSize();
- float start_time = 0.0;
std::string cur_stamp = "[";
std::string lang = (offline_stream->asr_handle)->GetLang();
- while (audio.Fetch(buff, len, flag, start_time) > 0) {
+ while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
// dec reset
funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
if (wfst_decoder){
wfst_decoder->StartUtterance();
}
- string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
+ vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
+ for(int idx=0; idx<batch_in; idx++){
+ string msg = msg_batch[idx];
+ if(msg_idx < index_vector.size()){
+ msgs[index_vector[msg_idx]] = msg;
+ msg_stimes[index_vector[msg_idx]] = start_time[idx];
+ msg_idx++;
+ }else{
+ LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
+ }
+ }
+
+ // release
+ delete[] buff;
+ buff = nullptr;
+ delete[] len;
+ len = nullptr;
+ delete[] flag;
+ flag = nullptr;
+ delete[] start_time;
+ start_time = nullptr;
+ }
+ for(int idx=0; idx<msgs.size(); idx++){
+ string msg = msgs[idx];
std::vector<std::string> msg_vec = funasr::split(msg, '|');
if(msg_vec.size()==0){
continue;
@@ -276,14 +292,11 @@
if(msg_vec.size() > 1){
std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
for(int i=0; i<msg_stamp.size()-1; i+=2){
- float begin = std::stof(msg_stamp[i])+start_time;
- float end = std::stof(msg_stamp[i+1])+start_time;
+ float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
+ float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
}
}
- n_step++;
- if (fn_callback)
- fn_callback(n_step, n_total);
}
if(cur_stamp != "["){
cur_stamp.erase(cur_stamp.length() - 1);
@@ -342,25 +355,53 @@
if(p_result->snippet_time == 0){
return p_result;
}
+ std::vector<int> index_vector={0};
+ int msg_idx = 0;
if(offline_stream->UseVad()){
- audio.CutSplit(offline_stream);
+ audio.CutSplit(offline_stream, index_vector);
}
+ std::vector<string> msgs(index_vector.size());
+ std::vector<float> msg_stimes(index_vector.size());
- float* buff;
- int len;
- int flag = 0;
- int n_step = 0;
- int n_total = audio.GetQueueSize();
- float start_time = 0.0;
+ float** buff;
+ int* len;
+ int* flag;
+ float* start_time;
+ int batch_size = offline_stream->asr_handle->GetBatchSize();
+ int batch_in = 0;
+
std::string cur_stamp = "[";
std::string lang = (offline_stream->asr_handle)->GetLang();
- while (audio.Fetch(buff, len, flag, start_time) > 0) {
+ while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
// dec reset
funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
if (wfst_decoder){
wfst_decoder->StartUtterance();
}
- string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
+ vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
+ for(int idx=0; idx<batch_in; idx++){
+ string msg = msg_batch[idx];
+ if(msg_idx < index_vector.size()){
+ msgs[index_vector[msg_idx]] = msg;
+ msg_stimes[index_vector[msg_idx]] = start_time[idx];
+ msg_idx++;
+ }else{
+ LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
+ }
+ }
+
+ // release
+ delete[] buff;
+ buff = nullptr;
+ delete[] len;
+ len = nullptr;
+ delete[] flag;
+ flag = nullptr;
+ delete[] start_time;
+ start_time = nullptr;
+ }
+ for(int idx=0; idx<msgs.size(); idx++){
+ string msg = msgs[idx];
std::vector<std::string> msg_vec = funasr::split(msg, '|');
if(msg_vec.size()==0){
continue;
@@ -373,15 +414,11 @@
if(msg_vec.size() > 1){
std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
for(int i=0; i<msg_stamp.size()-1; i+=2){
- float begin = std::stof(msg_stamp[i])+start_time;
- float end = std::stof(msg_stamp[i+1])+start_time;
+ float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
+ float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
}
}
-
- n_step++;
- if (fn_callback)
- fn_callback(n_step, n_total);
}
if(cur_stamp != "["){
cur_stamp.erase(cur_stamp.length() - 1);
@@ -518,8 +555,14 @@
if (wfst_decoder){
wfst_decoder->StartUtterance();
}
- string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
-
+ float** buff;
+ int* len;
+ buff = new float*[1];
+ len = new int[1];
+ buff[0] = frame->data;
+ len[0] = frame->len;
+ vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
+ string msg = msgs.size()>0?msgs[0]:"";
std::vector<std::string> msg_vec = funasr::split(msg, '|'); // split with timestamp
if(msg_vec.size()==0){
continue;
@@ -767,16 +810,45 @@
funasr::WfstDecoder* mm = nullptr;
if (asr_type == ASR_OFFLINE) {
funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
- funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
- if (paraformer->lm_)
- mm = new funasr::WfstDecoder(paraformer->lm_.get(),
- paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+ auto paraformer = dynamic_cast<funasr::Paraformer*>(offline_stream->asr_handle.get());
+ if(paraformer !=nullptr){
+ if (paraformer->lm_){
+ mm = new funasr::WfstDecoder(paraformer->lm_.get(),
+ paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+ }
+ return mm;
+ }
+ #ifdef USE_GPU
+ auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
+ if(paraformer_torch !=nullptr){
+ if (paraformer_torch->lm_){
+ mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
+ paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
+ }
+ return mm;
+ }
+ #endif
+
} else if (asr_type == ASR_TWO_PASS){
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
- funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
- if (paraformer->lm_)
- mm = new funasr::WfstDecoder(paraformer->lm_.get(),
- paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+ auto paraformer = dynamic_cast<funasr::Paraformer*>(tpass_stream->asr_handle.get());
+ if(paraformer !=nullptr){
+ if (paraformer->lm_){
+ mm = new funasr::WfstDecoder(paraformer->lm_.get(),
+ paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+ }
+ return mm;
+ }
+ #ifdef USE_GPU
+ auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(tpass_stream->asr_handle.get());
+ if(paraformer_torch !=nullptr){
+ if (paraformer_torch->lm_){
+ mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
+ paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
+ }
+ return mm;
+ }
+ #endif
}
return mm;
}
diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp
index 7d86f9b..35eb1ba 100644
--- a/runtime/onnxruntime/src/offline-stream.cpp
+++ b/runtime/onnxruntime/src/offline-stream.cpp
@@ -1,7 +1,7 @@
#include "precomp.h"
namespace funasr {
-OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
+OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
// VAD model
if(model_path.find(VAD_DIR) != model_path.end()){
@@ -36,7 +36,19 @@
string hw_compile_model_path;
string seg_dict_path;
- asr_handle = make_unique<Paraformer>();
+ if(use_gpu){
+ #ifdef USE_GPU
+ asr_handle = make_unique<ParaformerTorch>();
+ asr_handle->SetBatchSize(batch_size);
+ #else
+ LOG(ERROR) <<"GPU is not supported! CPU will be used! If you want to use GPU, please add -DGPU=ON when cmake";
+ asr_handle = make_unique<Paraformer>();
+ use_gpu = false;
+ #endif
+ }else{
+ asr_handle = make_unique<Paraformer>();
+ }
+
bool enable_hotword = false;
hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT);
@@ -54,6 +66,15 @@
am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
+ }
+ if(use_gpu){
+ am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_MODEL_NAME);
+ if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
+ am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_QUANT_MODEL_NAME);
+ }
+ if(model_path.find(BLADEDISC) != model_path.end() && model_path.at(BLADEDISC) == "true"){
+ am_model_path = PathAppend(model_path.at(MODEL_DIR), BLADE_MODEL_NAME);
+ }
}
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
@@ -120,10 +141,10 @@
#endif
}
-OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
+OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
OfflineStream *mm;
- mm = new OfflineStream(model_path, thread_num);
+ mm = new OfflineStream(model_path, thread_num, use_gpu, batch_size);
return mm;
}
diff --git a/runtime/onnxruntime/src/paraformer-torch.cpp b/runtime/onnxruntime/src/paraformer-torch.cpp
new file mode 100644
index 0000000..113d43f
--- /dev/null
+++ b/runtime/onnxruntime/src/paraformer-torch.cpp
@@ -0,0 +1,415 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#include "precomp.h"
+#include "paraformer-torch.h"
+#include "encode_converter.h"
+#include <cstddef>
+
+using namespace std;
+namespace funasr {
+
+// Construct with hotword support disabled; it is enabled later by
+// InitHwCompiler(). All heavy initialization happens in InitAsr().
+ParaformerTorch::ParaformerTorch()
+:use_hotword(false){
+}
+
+// offline
+// Initialize the offline ASR backend: parse the YAML config, configure the
+// kaldi-native-fbank frontend, load vocab / phone set / CMVN stats, then load
+// the TorchScript module onto CPU (or CUDA when built with USE_GPU).
+// Exits the process on any failure. thread_num is unused by this backend.
+void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+ LoadConfigFromYaml(am_config.c_str());
+ // knf options
+ fbank_opts_.frame_opts.dither = 0;
+ fbank_opts_.mel_opts.num_bins = n_mels;
+ fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
+ fbank_opts_.frame_opts.window_type = window_type;
+ fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
+ fbank_opts_.frame_opts.frame_length_ms = frame_length;
+ fbank_opts_.energy_floor = 0;
+ fbank_opts_.mel_opts.debug_mel = false;
+
+ // Vocab and phone set are both read from the same YAML config file.
+ vocab = new Vocab(am_config.c_str());
+ phone_set_ = new PhoneSet(am_config.c_str());
+ LoadCmvn(am_cmvn.c_str());
+
+ torch::DeviceType device = at::kCPU;
+ #ifdef USE_GPU
+ if (!torch::cuda::is_available()) {
+ LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
+ exit(-1);
+ } else {
+ LOG(INFO) << "CUDA is available, running on GPU";
+ device = at::kCUDA;
+ }
+ #endif
+ #ifdef USE_IPEX
+ // Disable the TensorExpr fuser when running under Intel Extension for PyTorch.
+ torch::jit::setTensorExprFuserEnabled(false);
+ #endif
+
+ try {
+ torch::jit::script::Module model = torch::jit::load(am_model, device);
+ model_ = std::make_shared<TorchModule>(std::move(model));
+ LOG(INFO) << "Successfully load model from " << am_model;
+ } catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load am model: " << am_model << e.what();
+ exit(-1);
+ }
+}
+
+// Load a WFST language model (OpenFST StdArc binary) plus its vocab/lexicon.
+// If Fst::Read returns null, the failure is logged and lm_ stays null so the
+// caller falls back to greedy search; an exception during loading is fatal.
+void ParaformerTorch::InitLm(const std::string &lm_file,
+ const std::string &lm_cfg_file,
+ const std::string &lex_file) {
+ try {
+ lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
+ fst::Fst<fst::StdArc>::Read(lm_file));
+ if (lm_){
+ lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
+ LOG(INFO) << "Successfully load lm file " << lm_file;
+ }else{
+ LOG(ERROR) << "Failed to load lm file " << lm_file;
+ }
+ } catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load lm file: " << e.what();
+ // Exit with a non-zero status on a fatal load error (was exit(0), which
+ // reported success); consistent with InitAsr/LoadConfigFromYaml/LoadCmvn.
+ exit(-1);
+ }
+}
+
+// Parse the ASR YAML config: reads frontend_conf.fs into asr_sample_rate and
+// the optional top-level "lang" entry into language. Exits the process when
+// the file is missing/unparsable or a required key is absent.
+void ParaformerTorch::LoadConfigFromYaml(const char* filename){
+
+ YAML::Node config;
+ try{
+ config = YAML::LoadFile(filename);
+ }catch(exception const &e){
+ LOG(ERROR) << "Error loading file, yaml file error or not exist.";
+ exit(-1);
+ }
+
+ try{
+ YAML::Node frontend_conf = config["frontend_conf"];
+ this->asr_sample_rate = frontend_conf["fs"].as<int>();
+
+ YAML::Node lang_conf = config["lang"];
+ if (lang_conf.IsDefined()){
+ language = lang_conf.as<string>();
+ }
+ }catch(exception const &e){
+ // Fixed copy-pasted diagnostic: this is the ASR config, not the VAD config.
+ LOG(ERROR) << "Error when load argument from asr config YAML.";
+ exit(-1);
+ }
+}
+
+// Hotword embedding compiler is not implemented for the torch backend yet;
+// only the use_hotword flag is set (see the CompileHotwordEmbedding stub).
+void ParaformerTorch::InitHwCompiler(const std::string &hw_model, int thread_num) {
+ // TODO
+ use_hotword = true;
+}
+
+// Load the word-segmentation dictionary used for hotword handling.
+void ParaformerTorch::InitSegDict(const std::string &seg_dict_model) {
+ seg_dict = new SegDict(seg_dict_model.c_str());
+}
+
+// Release the raw-owned helper objects; model_ and lm_ are shared_ptrs and
+// release themselves. NOTE(review): vocab/lm_vocab/seg_dict/phone_set_ are
+// raw owning pointers — consider unique_ptr to get the Rule of Zero.
+ParaformerTorch::~ParaformerTorch()
+{
+ if(vocab){
+ delete vocab;
+ }
+ if(lm_vocab){
+ delete lm_vocab;
+ }
+ if(seg_dict){
+ delete seg_dict;
+ }
+ if(phone_set_){
+ delete phone_set_;
+ }
+}
+
+// No-op: this offline backend keeps no per-utterance streaming state.
+// Presumably required by the Model interface — TODO confirm.
+void ParaformerTorch::StartUtterance()
+{
+}
+
+// No-op: see StartUtterance().
+void ParaformerTorch::EndUtterance()
+{
+}
+
+// No-op: nothing to reset in the offline torch backend.
+void ParaformerTorch::Reset()
+{
+}
+
+// Compute kaldi-style fbank features for one waveform; appends one
+// n_mels-dimensional frame vector per ready frame to asr_feats.
+void ParaformerTorch::FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats) {
+ knf::OnlineFbank fbank_(fbank_opts_);
+ std::vector<float> buf(len);
+ for (int32_t i = 0; i != len; ++i) {
+ // Scale samples by 32768 — assumes input is float in [-1, 1] and the
+ // fbank expects int16-range values; TODO confirm caller convention.
+ buf[i] = waves[i] * 32768;
+ }
+ fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
+
+ int32_t frames = fbank_.NumFramesReady();
+ for (int32_t i = 0; i != frames; ++i) {
+ const float *frame = fbank_.GetFrame(i);
+ std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
+ asr_feats.emplace_back(frame_vector);
+ }
+}
+
+// Parse kaldi-nnet-style CMVN stats: the line after "<AddShift>" holds the
+// means, the line after "<Rescale>" the variances (each introduced by
+// "<LearnRateCoef>", values at tokens [3, size-1)). Vars are pre-multiplied
+// by `scale`. Exits if the file cannot be opened.
+void ParaformerTorch::LoadCmvn(const char *filename)
+{
+ ifstream cmvn_stream(filename);
+ if (!cmvn_stream.is_open()) {
+ LOG(ERROR) << "Failed to open file: " << filename;
+ exit(-1);
+ }
+ string line;
+
+ while (getline(cmvn_stream, line)) {
+ istringstream iss(line);
+ vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
+ // Guard: a blank line would make line_item[0] below out-of-bounds (UB).
+ if (line_item.empty()) {
+ continue;
+ }
+ if (line_item[0] == "<AddShift>") {
+ getline(cmvn_stream, line);
+ istringstream means_lines_stream(line);
+ vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
+ if (!means_lines.empty() && means_lines[0] == "<LearnRateCoef>") {
+ for (int j = 3; j < means_lines.size() - 1; j++) {
+ means_list_.push_back(stof(means_lines[j]));
+ }
+ continue;
+ }
+ }
+ else if (line_item[0] == "<Rescale>") {
+ getline(cmvn_stream, line);
+ istringstream vars_lines_stream(line);
+ vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
+ if (!vars_lines.empty() && vars_lines[0] == "<LearnRateCoef>") {
+ for (int j = 3; j < vars_lines.size() - 1; j++) {
+ vars_list_.push_back(stof(vars_lines[j])*scale);
+ }
+ continue;
+ }
+ }
+ }
+}
+
+// Argmax decode of the AM posterior matrix `in` (n_len frames x token_nums
+// scores). Without timestamps, maps token ids straight to a string; with
+// timestamps, also runs CIF-based timestamp estimation over us_alphas /
+// us_cif_peak and returns the post-processed result.
+string ParaformerTorch::GreedySearch(float * in, int n_len, int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
+{
+ vector<int> hyps;
+ int Tmax = n_len;
+ for (int i = 0; i < Tmax; i++) {
+ // Pick the highest-scoring token per frame.
+ int max_idx;
+ float max_val;
+ FindMax(in + i * token_nums, token_nums, max_val, max_idx);
+ hyps.push_back(max_idx);
+ }
+ if(!is_stamp){
+ return vocab->Vector2StringV2(hyps, language);
+ }else{
+ std::vector<string> char_list;
+ std::vector<std::vector<float>> timestamp_list;
+ std::string res_str;
+ vocab->Vector2String(hyps, char_list);
+ // Keep an untouched copy: TimestampOnnx may modify char_list in place.
+ std::vector<string> raw_char(char_list);
+ TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
+
+ return PostProcess(raw_char, timestamp_list);
+ }
+}
+
+// Thin delegation: feed one utterance's AM scores into the WFST decoder.
+string ParaformerTorch::BeamSearch(WfstDecoder* &wfst_decoder, float *in, int len, int64_t token_nums)
+{
+ return wfst_decoder->Search(in, len, token_nums);
+}
+
+// Thin delegation: finalize WFST decoding, optionally producing timestamps.
+string ParaformerTorch::FinalizeDecode(WfstDecoder* &wfst_decoder,
+ bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
+{
+ return wfst_decoder->FinalizeDecode(is_stamp, us_alphas, us_cif_peak);
+}
+
+// Low-frame-rate stacking + CMVN, in place: stacks lfr_m consecutive frames
+// (stride lfr_n) into one vector, padding by repeating the first/last frame
+// at the edges, then applies (x + mean) * var per dimension.
+void ParaformerTorch::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
+
+ std::vector<std::vector<float>> out_feats;
+ int T = asr_feats.size();
+ int T_lrf = ceil(1.0 * T / lfr_n);
+
+ // Pad frames at start(copy first frame)
+ for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+ asr_feats.insert(asr_feats.begin(), asr_feats[0]);
+ }
+ // Merge lfr_m frames as one,lfr_n frames per window
+ T = T + (lfr_m - 1) / 2;
+ std::vector<float> p;
+ for (int i = 0; i < T_lrf; i++) {
+ if (lfr_m <= T - i * lfr_n) {
+ // Full window available: concatenate lfr_m raw frames.
+ for (int j = 0; j < lfr_m; j++) {
+ p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
+ }
+ out_feats.emplace_back(p);
+ p.clear();
+ } else {
+ // Fill to lfr_m frames at last window if less than lfr_m frames (copy last frame)
+ int num_padding = lfr_m - (T - i * lfr_n);
+ for (int j = 0; j < (asr_feats.size() - i * lfr_n); j++) {
+ p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
+ }
+ for (int j = 0; j < num_padding; j++) {
+ p.insert(p.end(), asr_feats[asr_feats.size() - 1].begin(), asr_feats[asr_feats.size() - 1].end());
+ }
+ out_feats.emplace_back(p);
+ p.clear();
+ }
+ }
+ // Apply cmvn
+ for (auto &out_feat: out_feats) {
+ for (int j = 0; j < means_list_.size(); j++) {
+ out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
+ }
+ }
+ asr_feats = out_feats;
+}
+
+// Batched offline inference. For each of the batch_in waveforms: fbank ->
+// LFR+CMVN -> flatten; all utterances are zero-padded to the longest one and
+// stacked into a single (batch, max_frames, feature_dim) tensor. The model is
+// expected to return (am_scores, valid_token_lens[, us_alphas, us_peaks]);
+// each utterance is then decoded greedily, or via the WFST decoder when an
+// LM is loaded. Returns one result string per utterance (empty on error —
+// note an exception in the try-block below leaves `results` short).
+std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
+{
+ WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
+ int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
+ int32_t feature_dim = lfr_m*in_feat_dim;
+
+ std::vector<vector<float>> feats_batch;
+ std::vector<int32_t> paraformer_length;
+ int max_size = 0;
+ int max_frames = 0;
+ // 1. Per-utterance front-end; track the longest utterance for padding.
+ for(int index=0; index<batch_in; index++){
+ std::vector<std::vector<float>> asr_feats;
+ FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
+ if(asr_feats.size() != 0){
+ LfrCmvn(asr_feats);
+ }
+ int32_t num_frames = asr_feats.size();
+ paraformer_length.emplace_back(num_frames);
+ if(max_size < asr_feats.size()*feature_dim){
+ max_size = asr_feats.size()*feature_dim;
+ max_frames = num_frames;
+ }
+
+ std::vector<float> flattened;
+ for (const auto& sub_vector : asr_feats) {
+ flattened.insert(flattened.end(), sub_vector.begin(), sub_vector.end());
+ }
+ feats_batch.emplace_back(flattened);
+ }
+
+ torch::NoGradGuard no_grad;
+ model_->eval();
+ // padding
+ std::vector<float> all_feats(batch_in * max_frames * feature_dim);
+ for(int index=0; index<batch_in; index++){
+ // resize() zero-fills the tail, so shorter utterances are zero-padded.
+ feats_batch[index].resize(max_size);
+ std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
+ max_frames * feature_dim * sizeof(float));
+ }
+ // from_blob does not copy: all_feats/paraformer_length must outlive the
+ // tensors (they do — both live until this function returns).
+ torch::Tensor feats =
+ torch::from_blob(all_feats.data(),
+ {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
+ torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
+ {batch_in}, torch::kInt32);
+
+ // 2. forward
+ #ifdef USE_GPU
+ feats = feats.to(at::kCUDA);
+ feat_lens = feat_lens.to(at::kCUDA);
+ #endif
+ std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
+
+ vector<std::string> results;
+ try {
+ auto outputs = model_->forward(inputs).toTuple()->elements();
+ torch::Tensor am_scores;
+ torch::Tensor valid_token_lens;
+ #ifdef USE_GPU
+ am_scores = outputs[0].toTensor().to(at::kCPU);
+ valid_token_lens = outputs[1].toTensor().to(at::kCPU);
+ #else
+ am_scores = outputs[0].toTensor();
+ valid_token_lens = outputs[1].toTensor();
+ #endif
+ // timestamp
+ for(int index=0; index<batch_in; index++){
+ string result="";
+ // A 4-element output tuple means the model also emits CIF alphas/peaks
+ // for timestamp prediction.
+ if(outputs.size() == 4){
+ torch::Tensor us_alphas_tensor;
+ torch::Tensor us_peaks_tensor;
+ #ifdef USE_GPU
+ us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
+ us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
+ #else
+ us_alphas_tensor = outputs[2].toTensor();
+ us_peaks_tensor = outputs[3].toTensor();
+ #endif
+
+ // NOTE(review): alphas/peaks are copied with length
+ // paraformer_length[index] (the LFR frame count) — confirm the
+ // model's upsampled alpha length matches this, not a multiple of it.
+ float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
+ std::vector<float> us_alphas(paraformer_length[index]);
+ for (int i = 0; i < us_alphas.size(); i++) {
+ us_alphas[i] = us_alphas_data[i];
+ }
+
+ float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
+ std::vector<float> us_peaks(paraformer_length[index]);
+ for (int i = 0; i < us_peaks.size(); i++) {
+ us_peaks[i] = us_peaks_data[i];
+ }
+ if (lm_ == nullptr) {
+ result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
+ } else {
+ result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+ if (input_finished) {
+ result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
+ }
+ }
+ }else{
+ if (lm_ == nullptr) {
+ result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+ } else {
+ result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
+ if (input_finished) {
+ result = FinalizeDecode(wfst_decoder);
+ }
+ }
+ }
+ results.push_back(result);
+ // Reset decoder state between utterances of the batch.
+ if (wfst_decoder){
+ wfst_decoder->StartUtterance();
+ }
+ }
+ }
+ catch (std::exception const &e)
+ {
+ LOG(ERROR)<<e.what();
+ }
+
+ return results;
+}
+
+// Stub: hotword embedding is not implemented for the torch backend; returns
+// a 1x10 zero placeholder so callers expecting a non-empty embedding work.
+std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
+ // TODO
+ std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
+ return result;
+}
+
+// Accessors return non-owning pointers valid for this object's lifetime.
+Vocab* ParaformerTorch::GetVocab()
+{
+ return vocab;
+}
+
+// LM vocab; null until InitLm() succeeds.
+Vocab* ParaformerTorch::GetLmVocab()
+{
+ return lm_vocab;
+}
+
+// Phone set loaded from the ASR config in InitAsr().
+PhoneSet* ParaformerTorch::GetPhoneSet()
+{
+ return phone_set_;
+}
+
+// Not implemented for the torch backend: logs an error, returns "".
+string ParaformerTorch::Rescoring()
+{
+ LOG(ERROR)<<"Not Imp!!!!!!";
+ return "";
+}
+} // namespace funasr
diff --git a/runtime/onnxruntime/src/paraformer-torch.h b/runtime/onnxruntime/src/paraformer-torch.h
new file mode 100644
index 0000000..e49094d
--- /dev/null
+++ b/runtime/onnxruntime/src/paraformer-torch.h
@@ -0,0 +1,96 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+#pragma once
+#define C10_USE_GLOG
+#include <torch/serialize.h>
+#include <torch/script.h>
+#include <torch/torch.h>
+#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
+#include "precomp.h"
+#include "fst/fstlib.h"
+#include "fst/symbol-table.h"
+#include "bias-lm.h"
+#include "phone-set.h"
+
+namespace funasr {
+
+ // GPU/LibTorch (TorchScript) implementation of the Paraformer offline ASR
+ // model; mirrors the ONNX-based funasr::Paraformer interface so the two are
+ // interchangeable behind Model. Built only when USE_GPU is defined.
+ class ParaformerTorch : public Model {
+ /**
+ * Author: Speech Lab of DAMO Academy, Alibaba Group
+ * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ * https://arxiv.org/pdf/2206.08317.pdf
+ */
+ private:
+ // Owning raw pointers, released in the destructor.
+ Vocab* vocab = nullptr;
+ Vocab* lm_vocab = nullptr;
+ SegDict* seg_dict = nullptr;
+ PhoneSet* phone_set_ = nullptr;
+ //const float scale = 22.6274169979695;
+ // Variance scaling applied in LoadCmvn; currently a no-op (1.0).
+ const float scale = 1.0;
+
+ void LoadConfigFromYaml(const char* filename);
+ void LoadCmvn(const char *filename);
+ void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
+
+ using TorchModule = torch::jit::script::Module;
+ std::shared_ptr<TorchModule> model_ = nullptr;
+ std::vector<torch::Tensor> encoder_outs_;
+ // Set by InitHwCompiler(); hotword support itself is still a TODO.
+ bool use_hotword;
+
+ public:
+ ParaformerTorch();
+ ~ParaformerTorch();
+ void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+ void InitHwCompiler(const std::string &hw_model, int thread_num);
+ void InitSegDict(const std::string &seg_dict_model);
+ std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
+ void Reset();
+ void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
+ // Batched inference: din/len hold batch_in waveforms; returns one string
+ // per utterance. See paraformer-torch.cpp for details.
+ std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
+ string GreedySearch( float* in, int n_len, int64_t token_nums,
+ bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
+
+ string Rescoring();
+ string GetLang(){return language;};
+ int GetAsrSampleRate() { return asr_sample_rate; };
+ void SetBatchSize(int batch_size) {batch_size_ = batch_size;};
+ int GetBatchSize() {return batch_size_;};
+ void StartUtterance();
+ void EndUtterance();
+ void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
+ string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
+ string FinalizeDecode(WfstDecoder* &wfst_decoder,
+ bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
+ Vocab* GetVocab();
+ Vocab* GetLmVocab();
+ PhoneSet* GetPhoneSet();
+
+ // Front-end configuration (overwritten from YAML by LoadConfigFromYaml).
+ knf::FbankOptions fbank_opts_;
+ vector<float> means_list_;
+ vector<float> vars_list_;
+ int lfr_m = PARA_LFR_M;
+ int lfr_n = PARA_LFR_N;
+
+ // paraformer-offline
+ std::string language="zh-cn";
+
+ // lm
+ std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
+
+ string window_type = "hamming";
+ int frame_length = 25;
+ int frame_shift = 10;
+ int n_mels = 80;
+ int encoder_size = 512;
+ int fsmn_layers = 16;
+ int fsmn_lorder = 10;
+ int fsmn_dims = 512;
+ float cif_threshold = 1.0;
+ float tail_alphas = 0.45;
+ int asr_sample_rate = MODEL_SAMPLE_RATE;
+ // Default inference batch size; set via SetBatchSize().
+ int batch_size_ = 1;
+ };
+
+} // namespace funasr
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index a57fb9b..1f1d48f 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -462,15 +462,23 @@
asr_feats = out_feats;
}
-string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
+std::vector<std::string> Paraformer::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
+ std::vector<std::string> results;
+ string result="";
WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
+ if(batch_in != 1){
+ results.push_back(result);
+ return results;
+ }
+
std::vector<std::vector<float>> asr_feats;
- FbankKaldi(asr_sample_rate, din, len, asr_feats);
+ FbankKaldi(asr_sample_rate, din[0], len[0], asr_feats);
if(asr_feats.size() == 0){
- return "";
+ results.push_back(result);
+ return results;
}
LfrCmvn(asr_feats);
int32_t feat_dim = lfr_m*in_feat_dim;
@@ -509,7 +517,8 @@
if (use_hotword) {
if(hw_emb.size()<=0){
LOG(ERROR) << "hw_emb is null";
- return "";
+ results.push_back(result);
+ return results;
}
//PrintMat(hw_emb, "input_clas_emb");
const int64_t hotword_shape[3] = {1, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())};
@@ -526,10 +535,10 @@
}catch (std::exception const &e)
{
LOG(ERROR)<<e.what();
- return "";
+ results.push_back(result);
+ return results;
}
- string result="";
try {
auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
@@ -577,7 +586,8 @@
LOG(ERROR)<<e.what();
}
- return result;
+ results.push_back(result);
+ return results;
}
diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h
index 417c2d7..571b2ba 100644
--- a/runtime/onnxruntime/src/paraformer.h
+++ b/runtime/onnxruntime/src/paraformer.h
@@ -52,13 +52,14 @@
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
void Reset();
void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
- string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
+ std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
string GreedySearch( float* in, int n_len, int64_t token_nums,
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
string Rescoring();
string GetLang(){return language;};
int GetAsrSampleRate() { return asr_sample_rate; };
+ int GetBatchSize() {return batch_size_;};
void StartUtterance();
void EndUtterance();
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@@ -110,6 +111,7 @@
float cif_threshold = 1.0;
float tail_alphas = 0.45;
int asr_sample_rate = MODEL_SAMPLE_RATE;
+ int batch_size_ = 1;
};
} // namespace funasr
diff --git a/runtime/onnxruntime/src/precomp.h b/runtime/onnxruntime/src/precomp.h
index 776de8e..1a98852 100644
--- a/runtime/onnxruntime/src/precomp.h
+++ b/runtime/onnxruntime/src/precomp.h
@@ -64,6 +64,9 @@
#include "seg_dict.h"
#include "resample.h"
#include "paraformer.h"
+#ifdef USE_GPU
+#include "paraformer-torch.h"
+#endif
#include "paraformer-online.h"
#include "offline-stream.h"
#include "tpass-stream.h"
diff --git a/runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp b/runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp
index cf00e94..e1133d6 100644
--- a/runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp
+++ b/runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp
@@ -70,13 +70,13 @@
return os;
}
-
+#ifndef USE_GPU
template<class T1, class T2>
ostream& operator << (ostream& os, const pair<T1, T2>& pr) {
os << pr.first << ":" << pr.second ;
return os;
}
-
+#endif
template<class T>
string& operator << (string& str, const T& obj) {
diff --git a/runtime/run_server.sh b/runtime/run_server.sh
index 22907ab..2cb577a 100644
--- a/runtime/run_server.sh
+++ b/runtime/run_server.sh
@@ -1,6 +1,10 @@
+TORCH_DIR=$(python3 -c "import torch; import os; print(os.path.dirname(torch.__file__))")
+BLADE_DIR=$(python3 -c "import torch_blade; import os; print(os.path.dirname(torch_blade.__file__))")
+export LD_LIBRARY_PATH=/usr/local/lib:${TORCH_DIR}/lib:${BLADE_DIR}:${LD_LIBRARY_PATH}
+
download_model_dir="/workspace/models"
-model_dir="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx"
+model_dir="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx"
vad_dir="damo/speech_fsmn_vad_zh-cn-16k-common-onnx"
punc_dir="damo/punc_ct-transformer_cn-en-common-vocab471067-large-onnx"
itn_dir="thuduj12/fst_itn_zh"
@@ -38,5 +42,6 @@
--port ${port} \
--certfile "${certfile}" \
--keyfile "${keyfile}" \
+ --gpu \
--hotword "${hotword}" &
diff --git a/runtime/websocket/CMakeLists.txt b/runtime/websocket/CMakeLists.txt
index ba6497a..da0e6e7 100644
--- a/runtime/websocket/CMakeLists.txt
+++ b/runtime/websocket/CMakeLists.txt
@@ -8,6 +8,10 @@
option(ENABLE_WEBSOCKET "Whether to build websocket server" ON)
option(ENABLE_PORTAUDIO "Whether to build portaudio" ON)
+option(ENABLE_GLOG "Whether to build glog" ON)
+option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
+option(BUILD_SHARED_LIBS "Build shared libraries" ON)
+option(GPU "Whether to build with GPU" OFF)
if(WIN32)
file(REMOVE ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/config.h
@@ -20,12 +24,16 @@
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")
endif()
-
-
-
-option(ENABLE_GLOG "Whether to build glog" ON)
-option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
-option(BUILD_SHARED_LIBS "Build shared libraries" ON)
+if(GPU)
+ add_definitions(-DUSE_GPU)
+ set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
+ set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
+ include_directories(${TORCH_DIR}/include)
+ include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
+ link_directories(${TORCH_DIR}/lib)
+ link_directories(${TORCH_BLADE_DIR})
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
+endif()
if(ENABLE_WEBSOCKET)
# cmake_policy(SET CMP0135 NEW)
diff --git a/runtime/websocket/bin/CMakeLists.txt b/runtime/websocket/bin/CMakeLists.txt
index 3d2a6cf..8df8ce0 100644
--- a/runtime/websocket/bin/CMakeLists.txt
+++ b/runtime/websocket/bin/CMakeLists.txt
@@ -1,5 +1,4 @@
-
if(WIN32)
include_directories(${ONNXRUNTIME_DIR}/include)
include_directories(${FFMPEG_DIR}/include)
@@ -12,15 +11,14 @@
SET(RELATION_SOURCE "../../onnxruntime/src/resample.cpp" "../../onnxruntime/src/util.cpp" "../../onnxruntime/src/alignedmem.cpp" "../../onnxruntime/src/encode_converter.cpp")
endif()
-
-
-
-
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-client "funasr-wss-client.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp" "microphone.cpp" ${RELATION_SOURCE})
+target_link_options(funasr-wss-server PRIVATE "-Wl,--no-as-needed")
+target_link_options(funasr-wss-server-2pass PRIVATE "-Wl,--no-as-needed")
+
target_link_libraries(funasr-wss-client PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY} portaudio)
target_link_libraries(funasr-wss-server PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
diff --git a/runtime/websocket/bin/funasr-wss-server.cpp b/runtime/websocket/bin/funasr-wss-server.cpp
index 5bb7def..100cf35 100644
--- a/runtime/websocket/bin/funasr-wss-server.cpp
+++ b/runtime/websocket/bin/funasr-wss-server.cpp
@@ -56,6 +56,10 @@
"true (Default), load the model of model_quant.onnx in model_dir. If set "
"false, load the model of model.onnx in model_dir",
false, "true", "string");
+ TCLAP::ValueArg<std::string> bladedisc(
+ "", BLADEDISC,
+ "true (Default), load the model of bladedisc in model_dir.",
+ false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir(
"", VAD_DIR,
"default: /workspace/models/vad, the vad model path, which contains model_quant.onnx, vad.yaml, vad.mvn",
@@ -121,6 +125,8 @@
false, "/workspace/resources/hotwords.txt", "string");
TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS,
"the fst hotwords incremental bias", false, 20, "int32_t");
+ TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU, default is false", false);
+ TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
// add file
cmd.add(hotword);
@@ -135,6 +141,7 @@
cmd.add(model_dir);
cmd.add(model_revision);
cmd.add(quantize);
+ cmd.add(bladedisc);
cmd.add(vad_dir);
cmd.add(vad_revision);
cmd.add(vad_quant);
@@ -151,11 +158,14 @@
cmd.add(io_thread_num);
cmd.add(decoder_thread_num);
cmd.add(model_thread_num);
+ cmd.add(use_gpu);
+ cmd.add(batch_size);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
+ GetValue(bladedisc, BLADEDISC, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
@@ -173,6 +183,8 @@
global_beam_ = global_beam.getValue();
lattice_beam_ = lattice_beam.getValue();
am_scale_ = am_scale.getValue();
+ bool use_gpu_ = use_gpu.getValue();
+ int batch_size_ = batch_size.getValue();
// Download model form Modelscope
try{
@@ -468,7 +480,7 @@
WebSocketServer websocket_srv(
io_decoder, is_ssl, server, wss_server, s_certfile,
s_keyfile); // websocket server for asr engine
- websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
+ websocket_srv.initAsr(model_path, s_model_thread_num, use_gpu_, batch_size_); // init asr model
LOG(INFO) << "decoder-thread-num: " << s_decoder_thread_num;
LOG(INFO) << "io-thread-num: " << s_io_thread_num;
diff --git a/runtime/websocket/bin/websocket-server.cpp b/runtime/websocket/bin/websocket-server.cpp
index ed25c95..49d8ead 100644
--- a/runtime/websocket/bin/websocket-server.cpp
+++ b/runtime/websocket/bin/websocket-server.cpp
@@ -402,11 +402,11 @@
// init asr model
void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
- int thread_num) {
+ int thread_num, bool use_gpu, int batch_size) {
try {
// init model with api
- asr_handle = FunOfflineInit(model_path, thread_num);
+ asr_handle = FunOfflineInit(model_path, thread_num, use_gpu, batch_size);
LOG(INFO) << "model successfully inited";
LOG(INFO) << "initAsr run check_and_clean_connection";
diff --git a/runtime/websocket/bin/websocket-server.h b/runtime/websocket/bin/websocket-server.h
index d18bcab..c1389bf 100644
--- a/runtime/websocket/bin/websocket-server.h
+++ b/runtime/websocket/bin/websocket-server.h
@@ -124,7 +124,7 @@
std::string wav_format,
FUNASR_DEC_HANDLE& decoder_handle);
- void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
+ void initAsr(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
void on_open(websocketpp::connection_hdl hdl);
void on_close(websocketpp::connection_hdl hdl);
--
Gitblit v1.9.1