Yabin Li
2024-05-27 ade08818b7a579aac75182b906a5bd3b8126411c
Merge branch 'dev_batch' into main
23个文件已修改
2个文件已添加
1033 ■■■■■ 已修改文件
runtime/onnxruntime/CMakeLists.txt 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/bin/CMakeLists.txt 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp 15 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/bin/funasr-onnx-offline.cpp 40 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/audio.h 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/com-define.h 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/funasrruntime.h 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/model.h 12 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/offline-stream.h 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/CMakeLists.txt 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/audio.cpp 106 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/funasrruntime.cpp 178 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/offline-stream.cpp 29 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer-torch.cpp 415 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer-torch.h 96 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer.cpp 24 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/paraformer.h 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/precomp.h 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/run_server.sh 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/CMakeLists.txt 20 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/CMakeLists.txt 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/funasr-wss-server.cpp 14 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/websocket-server.cpp 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/websocket-server.h 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/CMakeLists.txt
@@ -4,6 +4,7 @@
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(GPU "Whether to build with GPU" OFF)
# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
@@ -49,6 +50,17 @@
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi)
include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
if(GPU)
    add_definitions(-DUSE_GPU)
    set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
    set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
    include_directories(${TORCH_DIR}/include)
    include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
    link_directories(${TORCH_DIR}/lib)
    link_directories(${TORCH_BLADE_DIR})
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
if(ENABLE_GLOG)
    include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/src)
    set(BUILD_TESTING OFF)
runtime/onnxruntime/bin/CMakeLists.txt
@@ -10,33 +10,43 @@
endif()
add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline PUBLIC funasr)
add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-asr PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-2pass PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass PUBLIC funasr)
add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-2pass-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)
add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)
# include_directories(${FFMPEG_DIR}/include)
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -52,7 +52,7 @@
    std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_handle, nn_hotwords_);
    
    // warm up
    for (size_t i = 0; i < 1; i++)
    for (size_t i = 0; i < 10; i++)
    {
        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs, true, decoder_handle);
        if(result){
@@ -127,6 +127,7 @@
    TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    bladedisc("", BLADEDISC, "true (Default), load the model of bladedisc in model_dir.", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@@ -140,11 +141,14 @@
    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    TCLAP::ValueArg<std::int32_t>   audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", false, 1, "int32_t");
    TCLAP::ValueArg<std::string>    hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
    TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU for inference, default is false", false);
    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(bladedisc);
    cmd.add(vad_dir);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
@@ -159,11 +163,14 @@
    cmd.add(wav_path);
    cmd.add(audio_fs);
    cmd.add(thread_num);
    cmd.add(use_gpu);
    cmd.add(batch_size);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(bladedisc, BLADEDISC, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
@@ -175,7 +182,9 @@
    struct timeval start, end;
    gettimeofday(&start, nullptr);
    FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1);
    bool use_gpu_ = use_gpu.getValue();
    int batch_size_ = batch_size.getValue();
    FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1, use_gpu_, batch_size_);
    if (!asr_handle)
    {
runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -19,6 +19,7 @@
#include "com-define.h"
#include <unordered_map>
#include "util.h"
#include "audio.h"
using namespace std;
bool is_target_file(const std::string& filename, const std::string target) {
@@ -44,6 +45,7 @@
    TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    bladedisc("", BLADEDISC, "true (Default), load the model of bladedisc in model_dir.", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@@ -57,9 +59,12 @@
    TCLAP::ValueArg<std::string>    wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    TCLAP::ValueArg<std::int32_t>   audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
    TCLAP::ValueArg<std::string>    hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
    TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU for inference, default is false", false);
    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(bladedisc);
    cmd.add(vad_dir);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
@@ -73,11 +78,14 @@
    cmd.add(wav_path);
    cmd.add(audio_fs);
    cmd.add(hotword);
    cmd.add(use_gpu);
    cmd.add(batch_size);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(bladedisc, BLADEDISC, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
@@ -89,7 +97,9 @@
    struct timeval start, end;
    gettimeofday(&start, nullptr);
    int thread_num = 1;
    FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num);
    bool use_gpu_ = use_gpu.getValue();
    int batch_size_ = batch_size.getValue();
    FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num, use_gpu_, batch_size_);
    if (!asr_hanlde)
    {
@@ -156,8 +166,34 @@
    for (int i = 0; i < wav_list.size(); i++) {
        auto& wav_file = wav_list[i];
        auto& wav_id = wav_ids[i];
        // For debug:begin
        int32_t sampling_rate_ = audio_fs.getValue();
        funasr::Audio audio(1);
        if(is_target_file(wav_file.c_str(), "wav")){
            if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_file;
                exit(-1);
            }
        }else if(is_target_file(wav_file.c_str(), "pcm")){
            if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_file;
                exit(-1);
            }
        }else{
            if (!audio.FfmpegLoad(wav_file.c_str(), true)){
                LOG(ERROR)<<"Failed to load "<< wav_file;
                exit(-1);
            }
        }
        char* speech_buff = audio.GetSpeechChar();
        int buff_len = audio.GetSpeechLen()*2;
        gettimeofday(&start, nullptr);
        FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
        FUNASR_RESULT result=FunOfflineInferBuffer(asr_hanlde, speech_buff, buff_len, RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), "pcm", true, decoder_handle);
        // For debug:end
        // FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
        gettimeofday(&end, nullptr);
        seconds = (end.tv_sec - start.tv_sec);
        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
runtime/onnxruntime/include/audio.h
@@ -83,9 +83,11 @@
    int FetchTpass(AudioFrame *&frame);
    int Fetch(float *&dout, int &len, int &flag);
    int Fetch(float *&dout, int &len, int &flag, float &start_time);
    int Fetch(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
    int FetchDynamic(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
    void Padding();
    void Split(OfflineStream* offline_streamj);
    void CutSplit(OfflineStream* offline_streamj);
    void CutSplit(OfflineStream* offline_streamj, std::vector<int> &index_vector);
    void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
    void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
    float GetTimeLen();
runtime/onnxruntime/include/com-define.h
@@ -51,6 +51,15 @@
#define QUANT_MODEL_NAME "model_quant.onnx"
#define VAD_CMVN_NAME "am.mvn"
#define VAD_CONFIG_NAME "config.yaml"
// gpu models
#define INFER_GPU "gpu"
#define BATCHSIZE "batch-size"
#define TORCH_MODEL_NAME "model.torchscripts"
#define TORCH_QUANT_MODEL_NAME "model_quant.torchscripts"
#define BLADE_MODEL_NAME "model.blade.fp16.pt"
#define BLADEDISC "bladedisc"
#define AM_CMVN_NAME "am.mvn"
#define AM_CONFIG_NAME "config.yaml"
#define LM_CONFIG_NAME "config.yaml"
runtime/onnxruntime/include/funasrruntime.h
@@ -96,7 +96,7 @@
_FUNASRAPI void                    CTTransformerUninit(FUNASR_HANDLE handle);
//OfflineStream
_FUNASRAPI FUNASR_HANDLE      FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_HANDLE      FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
_FUNASRAPI void             FunOfflineReset(FUNASR_HANDLE handle, FUNASR_DEC_HANDLE dec_handle=nullptr);
// buffer
_FUNASRAPI FUNASR_RESULT    FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, 
runtime/onnxruntime/include/model.h
@@ -5,6 +5,10 @@
#include <string>
#include <map>
#include "funasrruntime.h"
#include "vocab.h"
#include "phone-set.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
namespace funasr {
class Model {
  public:
@@ -18,13 +22,19 @@
    virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
    virtual void InitFstDecoder(){};
    virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
    virtual std::vector<std::string> Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1)
      {return std::vector<string>();};
    virtual std::string Rescoring() = 0;
    virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
    virtual void InitSegDict(const std::string &seg_dict_model){};
    virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
    virtual std::string GetLang(){return "";};
    virtual int GetAsrSampleRate() = 0;
    virtual void SetBatchSize(int batch_size) {};
    virtual int GetBatchSize() {return 0;};
    virtual Vocab* GetVocab() {return nullptr;};
    virtual Vocab* GetLmVocab() {return nullptr;};
    virtual PhoneSet* GetPhoneSet() {return nullptr;};
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
runtime/onnxruntime/include/offline-stream.h
@@ -14,7 +14,7 @@
namespace funasr {
class OfflineStream {
  public:
    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
    ~OfflineStream(){};
    std::unique_ptr<VadModel> vad_handle= nullptr;
@@ -33,6 +33,6 @@
    bool use_itn=false;
};
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1);
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1, bool use_gpu=false, int batch_size=1);
} // namespace funasr
#endif
runtime/onnxruntime/src/CMakeLists.txt
@@ -1,6 +1,11 @@
file(GLOB files1 "*.cpp")
list(REMOVE_ITEM files1 "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
set(files ${files1})
if(GPU)
    set(files ${files} "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
endif()
message("files: "${files})
@@ -25,7 +30,11 @@
    include_directories(${FFMPEG_DIR}/include)
endif()
if(GPU)
    set(TORCH_DEPS torch torch_cuda torch_cpu c10 c10_cuda torch_blade ral_base_context)
endif()
#message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
include_directories(${CMAKE_SOURCE_DIR}/include)
include_directories(${CMAKE_SOURCE_DIR}/third_party)
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS} ${TORCH_DEPS})
runtime/onnxruntime/src/audio.cpp
@@ -1023,6 +1023,90 @@
    }
}
int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
    batch_in = std::min((int)frame_queue.size(), batch_size);
    if (batch_in == 0){
        return 0;
    } else{
        // init
        dout = new float*[batch_in];
        len = new int[batch_in];
        flag = new int[batch_in];
        start_time = new float[batch_in];
        for(int idx=0; idx < batch_in; idx++){
            AudioFrame *frame = frame_queue.front();
            frame_queue.pop();
            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
            dout[idx] = speech_data + frame->GetStart();
            len[idx] = frame->GetLen();
            delete frame;
            flag[idx] = S_END;
        }
        return 1;
    }
}
int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
    //compute batch size
    queue<AudioFrame *> frame_batch;
    int max_acc = 300*1000*seg_sample;
    int max_sent = 60*1000*seg_sample;
    int bs_acc = 0;
    int max_len = 0;
    int max_batch = 1;
    #ifdef USE_GPU
        max_batch = batch_size;
    #endif
    max_batch = std::min(max_batch, (int)frame_queue.size());
    for(int idx=0; idx < max_batch; idx++){
        AudioFrame *frame = frame_queue.front();
        int length = frame->GetLen();
        if(length >= max_sent){
            if(bs_acc == 0){
                bs_acc++;
                frame_batch.push(frame);
                frame_queue.pop();
            }
            break;
        }
        max_len = std::max(max_len, frame->GetLen());
        if(max_len*(bs_acc+1) > max_acc){
            break;
        }
        bs_acc++;
        frame_batch.push(frame);
        frame_queue.pop();
    }
    batch_in = (int)frame_batch.size();
    if (batch_in == 0){
        return 0;
    } else{
        // init
        dout = new float*[batch_in];
        len = new int[batch_in];
        flag = new int[batch_in];
        start_time = new float[batch_in];
        for(int idx=0; idx < batch_in; idx++){
            AudioFrame *frame = frame_batch.front();
            frame_batch.pop();
            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
            dout[idx] = speech_data + frame->GetStart();
            len[idx] = frame->GetLen();
            delete frame;
            flag[idx] = S_END;
        }
        return 1;
    }
}
void Audio::Padding()
{
    float num_samples = speech_len;
@@ -1085,7 +1169,7 @@
    }
}
void Audio::CutSplit(OfflineStream* offline_stream)
void Audio::CutSplit(OfflineStream* offline_stream, std::vector<int> &index_vector)
{
    std::unique_ptr<VadModel> vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(offline_stream->vad_handle).get());
    AudioFrame *frame;
@@ -1112,6 +1196,7 @@
    }    
    int speech_start_i = -1, speech_end_i =-1;
    std::vector<AudioFrame*> vad_frames;
    for(vector<int> vad_segment:vad_segments)
    {
        if(vad_segment.size() != 2){
@@ -1126,16 +1211,31 @@
        }
        if(speech_start_i!=-1 && speech_end_i!=-1){
            frame = new AudioFrame();
            int start = speech_start_i*seg_sample;
            int end = speech_end_i*seg_sample;
            frame = new AudioFrame(end-start);
            frame->SetStart(start);
            frame->SetEnd(end);
            frame_queue.push(frame);
            vad_frames.push_back(frame);
            frame = nullptr;
            speech_start_i=-1;
            speech_end_i=-1;
        }
    }
    // sort
    {
        index_vector.clear();
        index_vector.resize(vad_frames.size());
        for (int i = 0; i < index_vector.size(); ++i) {
            index_vector[i] = i;
        }
        std::sort(index_vector.begin(), index_vector.end(), [&vad_frames](const int a, const int b) {
            return vad_frames[a]->len < vad_frames[b]->len;
        });
        for (int idx : index_vector) {
            frame_queue.push(vad_frames[idx]);
        }
    }
}
runtime/onnxruntime/src/funasrruntime.cpp
@@ -33,9 +33,9 @@
        return mm;
    }
    _FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
    _FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
    {
        funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num);
        funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu, batch_size);
        return mm;
    }
@@ -74,16 +74,11 @@
        if(p_result->snippet_time == 0){
            return p_result;
        }
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        while (audio.Fetch(buff, len, flag) > 0) {
            string msg = recog_obj->Forward(buff, len, input_finished);
            p_result->msg += msg;
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        return p_result;
    }
@@ -109,8 +104,6 @@
        float* buff;
        int len;
        int flag = 0;
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
        p_result->snippet_time = audio.GetTimeLen();
        if(p_result->snippet_time == 0){
@@ -119,11 +112,7 @@
        while (audio.Fetch(buff, len, flag) > 0) {
            string msg = recog_obj->Forward(buff, len, true);
            p_result->msg += msg;
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        return p_result;
    }
@@ -244,26 +233,53 @@
        if(p_result->snippet_time == 0){
            return p_result;
        }
        std::vector<int> index_vector={0};
        int msg_idx = 0;
        if(offline_stream->UseVad()){
            audio.CutSplit(offline_stream);
            audio.CutSplit(offline_stream, index_vector);
        }
        std::vector<string> msgs(index_vector.size());
        std::vector<float> msg_stimes(index_vector.size());
        float* buff;
        int len;
        int flag = 0;
        float** buff;
        int* len;
        int* flag;
        float* start_time;
        int batch_size = offline_stream->asr_handle->GetBatchSize();
        int batch_in = 0;
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        float start_time = 0.0;
        std::string cur_stamp = "[";
        std::string lang = (offline_stream->asr_handle)->GetLang();
        while (audio.Fetch(buff, len, flag, start_time) > 0) {
        while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
            // dec reset
            funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
            string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
            vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
            for(int idx=0; idx<batch_in; idx++){
                string msg = msg_batch[idx];
                if(msg_idx < index_vector.size()){
                    msgs[index_vector[msg_idx]] = msg;
                    msg_stimes[index_vector[msg_idx]] = start_time[idx];
                    msg_idx++;
                }else{
                    LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
                }
            }
            // release
            delete[] buff;
            buff = nullptr;
            delete[] len;
            len = nullptr;
            delete[] flag;
            flag = nullptr;
            delete[] start_time;
            start_time = nullptr;
        }
        for(int idx=0; idx<msgs.size(); idx++){
            string msg = msgs[idx];
            std::vector<std::string> msg_vec = funasr::split(msg, '|');
            if(msg_vec.size()==0){
                continue;
@@ -276,14 +292,11 @@
            if(msg_vec.size() > 1){
                std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
                for(int i=0; i<msg_stamp.size()-1; i+=2){
                    float begin = std::stof(msg_stamp[i])+start_time;
                    float end = std::stof(msg_stamp[i+1])+start_time;
                    float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
                    float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
                    cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
                }
            }
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        if(cur_stamp != "["){
            cur_stamp.erase(cur_stamp.length() - 1);
@@ -342,25 +355,53 @@
        if(p_result->snippet_time == 0){
            return p_result;
        }
        std::vector<int> index_vector={0};
        int msg_idx = 0;
        if(offline_stream->UseVad()){
            audio.CutSplit(offline_stream);
            audio.CutSplit(offline_stream, index_vector);
        }
        std::vector<string> msgs(index_vector.size());
        std::vector<float> msg_stimes(index_vector.size());
        float* buff;
        int len;
        int flag = 0;
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        float start_time = 0.0;
        float** buff;
        int* len;
        int* flag;
        float* start_time;
        int batch_size = offline_stream->asr_handle->GetBatchSize();
        int batch_in = 0;
        std::string cur_stamp = "[";
        std::string lang = (offline_stream->asr_handle)->GetLang();
        while (audio.Fetch(buff, len, flag, start_time) > 0) {
        while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
            // dec reset
            funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
            string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
            vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
            for(int idx=0; idx<batch_in; idx++){
                string msg = msg_batch[idx];
                if(msg_idx < index_vector.size()){
                    msgs[index_vector[msg_idx]] = msg;
                    msg_stimes[index_vector[msg_idx]] = start_time[idx];
                    msg_idx++;
                }else{
                    LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
                }
            }
            // release
            delete[] buff;
            buff = nullptr;
            delete[] len;
            len = nullptr;
            delete[] flag;
            flag = nullptr;
            delete[] start_time;
            start_time = nullptr;
        }
        for(int idx=0; idx<msgs.size(); idx++){
            string msg = msgs[idx];
            std::vector<std::string> msg_vec = funasr::split(msg, '|');
            if(msg_vec.size()==0){
                continue;
@@ -373,15 +414,11 @@
            if(msg_vec.size() > 1){
                std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
                for(int i=0; i<msg_stamp.size()-1; i+=2){
                    float begin = std::stof(msg_stamp[i])+start_time;
                    float end = std::stof(msg_stamp[i+1])+start_time;
                    float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
                    float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
                    cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
                }
            }
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        if(cur_stamp != "["){
            cur_stamp.erase(cur_stamp.length() - 1);
@@ -518,8 +555,14 @@
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
            string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
            float** buff;
            int* len;
            buff = new float*[1];
            len = new int[1];
            buff[0] = frame->data;
            len[0] = frame->len;
            vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
            string msg = msgs.size()>0?msgs[0]:"";
            std::vector<std::string> msg_vec = funasr::split(msg, '|');  // split with timestamp
            if(msg_vec.size()==0){
                continue;
@@ -767,16 +810,45 @@
        funasr::WfstDecoder* mm = nullptr;
        if (asr_type == ASR_OFFLINE) {
            funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
            funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
            if (paraformer->lm_)
                mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                    paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
            auto paraformer = dynamic_cast<funasr::Paraformer*>(offline_stream->asr_handle.get());
            if(paraformer !=nullptr){
                if (paraformer->lm_){
                    mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                        paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #ifdef USE_GPU
            auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
            if(paraformer_torch !=nullptr){
                if (paraformer_torch->lm_){
                    mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
                        paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #endif
        } else if (asr_type == ASR_TWO_PASS){
            funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
            funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
            if (paraformer->lm_)
                mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                    paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
            auto paraformer = dynamic_cast<funasr::Paraformer*>(tpass_stream->asr_handle.get());
            if(paraformer !=nullptr){
                if (paraformer->lm_){
                    mm = new funasr::WfstDecoder(paraformer->lm_.get(),
                        paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #ifdef USE_GPU
            auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(tpass_stream->asr_handle.get());
            if(paraformer_torch !=nullptr){
                if (paraformer_torch->lm_){
                    mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
                        paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
                }
                return mm;
            }
            #endif
        }
        return mm;
    }
runtime/onnxruntime/src/offline-stream.cpp
@@ -1,7 +1,7 @@
#include "precomp.h"
namespace funasr {
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
    // VAD model
    if(model_path.find(VAD_DIR) != model_path.end()){
@@ -36,7 +36,19 @@
        string hw_compile_model_path;
        string seg_dict_path;
    
        asr_handle = make_unique<Paraformer>();
        if(use_gpu){
            #ifdef USE_GPU
            asr_handle = make_unique<ParaformerTorch>();
            asr_handle->SetBatchSize(batch_size);
            #else
            LOG(ERROR) <<"GPU is not supported! CPU will be used! If you want to use GPU, please add -DGPU=ON when cmake";
            asr_handle = make_unique<Paraformer>();
            use_gpu = false;
            #endif
        }else{
            asr_handle = make_unique<Paraformer>();
        }
        bool enable_hotword = false;
        hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
        seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT);
@@ -54,6 +66,15 @@
          am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
          if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
            am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
          }
          if(use_gpu){
            am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_MODEL_NAME);
            if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
                am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_QUANT_MODEL_NAME);
            }
            if(model_path.find(BLADEDISC) != model_path.end() && model_path.at(BLADEDISC) == "true"){
                am_model_path = PathAppend(model_path.at(MODEL_DIR), BLADE_MODEL_NAME);
            }
          }
        }
        am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
@@ -120,10 +141,10 @@
#endif
}
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
    OfflineStream *mm;
    mm = new OfflineStream(model_path, thread_num);
    mm = new OfflineStream(model_path, thread_num, use_gpu, batch_size);
    return mm;
}
runtime/onnxruntime/src/paraformer-torch.cpp
New file
@@ -0,0 +1,415 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#include "precomp.h"
#include "paraformer-torch.h"
#include "encode_converter.h"
#include <cstddef>
using namespace std;
namespace funasr {
// Construct a torch-backend Paraformer. Hotword support stays disabled
// until InitHwCompiler() is called; all other members use their in-class
// defaults and are populated by the Init* methods.
ParaformerTorch::ParaformerTorch()
:use_hotword(false){
}
// offline
// Initialize the offline ASR engine:
//  1) read sample rate / language from the yaml config,
//  2) configure kaldi-native-fbank feature extraction,
//  3) build vocab / phone set and load CMVN statistics,
//  4) load the TorchScript acoustic model onto GPU (USE_GPU) or CPU.
// Exits the process on any fatal load error.
// NOTE(review): thread_num is currently unused by the torch backend —
// intra-op threading is left to libtorch defaults; confirm if intended.
void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
    LoadConfigFromYaml(am_config.c_str());
    // knf options
    fbank_opts_.frame_opts.dither = 0;
    fbank_opts_.mel_opts.num_bins = n_mels;
    fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
    fbank_opts_.frame_opts.window_type = window_type;
    fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
    fbank_opts_.frame_opts.frame_length_ms = frame_length;
    fbank_opts_.energy_floor = 0;
    fbank_opts_.mel_opts.debug_mel = false;
    // Both the vocabulary and the phone set are read from the same yaml file.
    vocab = new Vocab(am_config.c_str());
    phone_set_ = new PhoneSet(am_config.c_str());
    LoadCmvn(am_cmvn.c_str());
    torch::DeviceType device = at::kCPU;
    #ifdef USE_GPU
    // GPU build is all-or-nothing: refuse to fall back silently to CPU.
    if (!torch::cuda::is_available()) {
        LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
        exit(-1);
    } else {
        LOG(INFO) << "CUDA is available, running on GPU";
        device = at::kCUDA;
    }
    #endif
    #ifdef USE_IPEX
    // IPEX and the TensorExpr fuser conflict; disable the fuser.
    torch::jit::setTensorExprFuserEnabled(false);
    #endif
    try {
        torch::jit::script::Module model = torch::jit::load(am_model, device);
        model_ = std::make_shared<TorchModule>(std::move(model));
        LOG(INFO) << "Successfully load model from " << am_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load am model: " << am_model << e.what();
        exit(-1);
    }
}
void ParaformerTorch::InitLm(const std::string &lm_file,
                        const std::string &lm_cfg_file,
                        const std::string &lex_file) {
    try {
        lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
            fst::Fst<fst::StdArc>::Read(lm_file));
        if (lm_){
            lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
            LOG(INFO) << "Successfully load lm file " << lm_file;
        }else{
            LOG(ERROR) << "Failed to load lm file " << lm_file;
        }
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load lm file: " << e.what();
        exit(0);
    }
}
void ParaformerTorch::LoadConfigFromYaml(const char* filename){
    YAML::Node config;
    try{
        config = YAML::LoadFile(filename);
    }catch(exception const &e){
        LOG(ERROR) << "Error loading file, yaml file error or not exist.";
        exit(-1);
    }
    try{
        YAML::Node frontend_conf = config["frontend_conf"];
        this->asr_sample_rate = frontend_conf["fs"].as<int>();
        YAML::Node lang_conf = config["lang"];
        if (lang_conf.IsDefined()){
            language = lang_conf.as<string>();
        }
    }catch(exception const &e){
        LOG(ERROR) << "Error when load argument from vad config YAML.";
        exit(-1);
    }
}
// Enable hotword support. The actual hotword-embedding compiler is not yet
// implemented for the torch backend (see CompileHotwordEmbedding); this only
// flips the use_hotword flag. hw_model and thread_num are currently ignored.
void ParaformerTorch::InitHwCompiler(const std::string &hw_model, int thread_num) {
    // TODO
    use_hotword = true;
}
// Load the word-segmentation dictionary used for hotword processing.
// Ownership of the SegDict is held by this object (freed in the destructor).
void ParaformerTorch::InitSegDict(const std::string &seg_dict_model) {
    seg_dict = new SegDict(seg_dict_model.c_str());
}
// Release the lookup tables owned via raw pointers. model_ and lm_ are
// shared_ptrs and release themselves; `delete` on a null pointer is a no-op,
// so no null checks are needed.
ParaformerTorch::~ParaformerTorch()
{
    delete vocab;
    delete lm_vocab;
    delete seg_dict;
    delete phone_set_;
}
// Intentionally a no-op: the offline torch backend keeps no per-utterance
// state of its own (WFST decoder state is managed by the caller).
void ParaformerTorch::StartUtterance()
{
}
// Intentionally a no-op: nothing to finalize per utterance in this backend.
void ParaformerTorch::EndUtterance()
{
}
// Intentionally a no-op: the offline model is stateless between calls.
void ParaformerTorch::Reset()
{
}
void ParaformerTorch::FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats) {
    knf::OnlineFbank fbank_(fbank_opts_);
    std::vector<float> buf(len);
    for (int32_t i = 0; i != len; ++i) {
        buf[i] = waves[i] * 32768;
    }
    fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
    int32_t frames = fbank_.NumFramesReady();
    for (int32_t i = 0; i != frames; ++i) {
        const float *frame = fbank_.GetFrame(i);
        std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
        asr_feats.emplace_back(frame_vector);
    }
}
void ParaformerTorch::LoadCmvn(const char *filename)
{
    ifstream cmvn_stream(filename);
    if (!cmvn_stream.is_open()) {
        LOG(ERROR) << "Failed to open file: " << filename;
        exit(-1);
    }
    string line;
    while (getline(cmvn_stream, line)) {
        istringstream iss(line);
        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
        if (line_item[0] == "<AddShift>") {
            getline(cmvn_stream, line);
            istringstream means_lines_stream(line);
            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
            if (means_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < means_lines.size() - 1; j++) {
                    means_list_.push_back(stof(means_lines[j]));
                }
                continue;
            }
        }
        else if (line_item[0] == "<Rescale>") {
            getline(cmvn_stream, line);
            istringstream vars_lines_stream(line);
            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
            if (vars_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < vars_lines.size() - 1; j++) {
                    vars_list_.push_back(stof(vars_lines[j])*scale);
                }
                continue;
            }
        }
    }
}
string ParaformerTorch::GreedySearch(float * in, int n_len,  int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
{
    vector<int> hyps;
    int Tmax = n_len;
    for (int i = 0; i < Tmax; i++) {
        int max_idx;
        float max_val;
        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
        hyps.push_back(max_idx);
    }
    if(!is_stamp){
        return vocab->Vector2StringV2(hyps, language);
    }else{
        std::vector<string> char_list;
        std::vector<std::vector<float>> timestamp_list;
        std::string res_str;
        vocab->Vector2String(hyps, char_list);
        std::vector<string> raw_char(char_list);
        TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
        return PostProcess(raw_char, timestamp_list);
    }
}
// Feed one utterance's AM scores (`len` frames x `token_nums` tokens) into
// the WFST decoder and return the current (partial) hypothesis.
string ParaformerTorch::BeamSearch(WfstDecoder* &wfst_decoder, float *in, int len, int64_t token_nums)
{
  return wfst_decoder->Search(in, len, token_nums);
}
// Finish WFST decoding for the utterance and return the final hypothesis;
// when is_stamp is true the CIF alphas/peaks are used to attach timestamps.
string ParaformerTorch::FinalizeDecode(WfstDecoder* &wfst_decoder,
                                  bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
{
  return wfst_decoder->FinalizeDecode(is_stamp, us_alphas, us_cif_peak);
}
void ParaformerTorch::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
    std::vector<std::vector<float>> out_feats;
    int T = asr_feats.size();
    int T_lrf = ceil(1.0 * T / lfr_n);
    // Pad frames at start(copy first frame)
    for (int i = 0; i < (lfr_m - 1) / 2; i++) {
        asr_feats.insert(asr_feats.begin(), asr_feats[0]);
    }
    // Merge lfr_m frames as one,lfr_n frames per window
    T = T + (lfr_m - 1) / 2;
    std::vector<float> p;
    for (int i = 0; i < T_lrf; i++) {
        if (lfr_m <= T - i * lfr_n) {
            for (int j = 0; j < lfr_m; j++) {
                p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
            }
            out_feats.emplace_back(p);
            p.clear();
        } else {
            // Fill to lfr_m frames at last window if less than lfr_m frames  (copy last frame)
            int num_padding = lfr_m - (T - i * lfr_n);
            for (int j = 0; j < (asr_feats.size() - i * lfr_n); j++) {
                p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
            }
            for (int j = 0; j < num_padding; j++) {
                p.insert(p.end(), asr_feats[asr_feats.size() - 1].begin(), asr_feats[asr_feats.size() - 1].end());
            }
            out_feats.emplace_back(p);
            p.clear();
        }
    }
    // Apply cmvn
    for (auto &out_feat: out_feats) {
        for (int j = 0; j < means_list_.size(); j++) {
            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
        }
    }
    asr_feats = out_feats;
}
// Batched inference over `batch_in` utterances.
// din[i]/len[i]: normalized waveform + sample count for utterance i.
// hw_emb is accepted for interface parity but unused here (hotwords are a
// TODO for the torch backend). Returns one result string per utterance.
// Pipeline: fbank + LFR/CMVN per utterance -> zero-pad to the longest
// utterance -> one TorchScript forward -> greedy or WFST decode per item.
std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
    WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
    // Feature dim after LFR stacking: lfr_m fbank frames per model frame.
    int32_t feature_dim = lfr_m*in_feat_dim;
    std::vector<vector<float>> feats_batch;    // flattened features per utterance
    std::vector<int32_t> paraformer_length;    // valid frame count per utterance
    int max_size = 0;      // largest flattened feature length (floats)
    int max_frames = 0;    // frame count of the longest utterance
    for(int index=0; index<batch_in; index++){
        std::vector<std::vector<float>> asr_feats;
        FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
        if(asr_feats.size() != 0){
            LfrCmvn(asr_feats);
        }
        int32_t num_frames  = asr_feats.size();
        paraformer_length.emplace_back(num_frames);
        if(max_size < asr_feats.size()*feature_dim){
            max_size = asr_feats.size()*feature_dim;
            max_frames = num_frames;
        }
        // Flatten [frames][feature_dim] into one contiguous row.
        std::vector<float> flattened;
        for (const auto& sub_vector : asr_feats) {
            flattened.insert(flattened.end(), sub_vector.begin(), sub_vector.end());
        }
        feats_batch.emplace_back(flattened);
    }
    torch::NoGradGuard no_grad;
    model_->eval();
    // padding
    // resize() zero-fills shorter utterances up to max_size, so the padded
    // region of the batch tensor is all zeros.
    std::vector<float> all_feats(batch_in * max_frames * feature_dim);
    for(int index=0; index<batch_in; index++){
        feats_batch[index].resize(max_size);
        std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
                        max_frames * feature_dim * sizeof(float));
    }
    // from_blob does not copy: all_feats and paraformer_length must outlive
    // the forward call (they do — both are locals of this function).
    torch::Tensor feats =
        torch::from_blob(all_feats.data(),
                {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
    torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
                        {batch_in}, torch::kInt32);
    // 2. forward
    #ifdef USE_GPU
    feats = feats.to(at::kCUDA);
    feat_lens = feat_lens.to(at::kCUDA);
    #endif
    std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
    vector<std::string> results;
    try {
        // Model returns (am_scores, valid_token_lens[, us_alphas, us_peaks]);
        // the 4-tuple variant carries CIF signals for timestamp prediction.
        auto outputs = model_->forward(inputs).toTuple()->elements();
        torch::Tensor am_scores;
        torch::Tensor valid_token_lens;
        #ifdef USE_GPU
        am_scores = outputs[0].toTensor().to(at::kCPU);
        valid_token_lens = outputs[1].toTensor().to(at::kCPU);
        #else
        am_scores = outputs[0].toTensor();
        valid_token_lens = outputs[1].toTensor();
        #endif
        // timestamp
        for(int index=0; index<batch_in; index++){
            string result="";
            if(outputs.size() == 4){
                torch::Tensor us_alphas_tensor;
                torch::Tensor us_peaks_tensor;
                #ifdef USE_GPU
                us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
                us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
                #else
                us_alphas_tensor = outputs[2].toTensor();
                us_peaks_tensor = outputs[3].toTensor();
                #endif
                // NOTE(review): only the first paraformer_length[index] values
                // are read from each us_alphas/us_peaks row — assumes the
                // upsampled CIF rows have at least that many valid entries;
                // confirm against the exported model's output shapes.
                float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
                std::vector<float> us_alphas(paraformer_length[index]);
                for (int i = 0; i < us_alphas.size(); i++) {
                    us_alphas[i] = us_alphas_data[i];
                }
                float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
                std::vector<float> us_peaks(paraformer_length[index]);
                for (int i = 0; i < us_peaks.size(); i++) {
                    us_peaks[i] = us_peaks_data[i];
                }
                if (lm_ == nullptr) {
                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
                } else {
                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
                    if (input_finished) {
                        result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
                    }
                }
            }else{
                if (lm_ == nullptr) {
                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
                } else {
                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
                    if (input_finished) {
                        result = FinalizeDecode(wfst_decoder);
                    }
                }
            }
            results.push_back(result);
            // Reset the shared decoder between batch items so state from one
            // utterance does not leak into the next.
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
        }
    }
    catch (std::exception const &e)
    {
        // On inference failure, results may hold fewer than batch_in entries.
        LOG(ERROR)<<e.what();
    }
    return results;
}
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
    // TODO
    std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
    return result;
}
// ASR output vocabulary (owned by this object; set in InitAsr).
Vocab* ParaformerTorch::GetVocab()
{
    return vocab;
}
// LM vocabulary (owned by this object; null unless InitLm succeeded).
Vocab* ParaformerTorch::GetLmVocab()
{
    return lm_vocab;
}
// Phone set used by the WFST decoder (owned by this object; set in InitAsr).
PhoneSet* ParaformerTorch::GetPhoneSet()
{
    return phone_set_;
}
// Rescoring is not implemented for the torch backend; logs and returns "".
string ParaformerTorch::Rescoring()
{
    LOG(ERROR)<<"Not Imp!!!!!!";
    return "";
}
} // namespace funasr
runtime/onnxruntime/src/paraformer-torch.h
New file
@@ -0,0 +1,96 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#pragma once
#define C10_USE_GLOG
#include <torch/serialize.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include "precomp.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "bias-lm.h"
#include "phone-set.h"
namespace funasr {
    // GPU/TorchScript backend of the offline Paraformer recognizer.
    // Mirrors the ONNX `Paraformer` interface so OfflineStream can hold
    // either behind the `Model` base class; built only when USE_GPU is set.
    class ParaformerTorch : public Model {
    /**
     * Author: Speech Lab of DAMO Academy, Alibaba Group
     * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     * https://arxiv.org/pdf/2206.08317.pdf
    */
    private:
        Vocab* vocab = nullptr;        // ASR output vocab (owned)
        Vocab* lm_vocab = nullptr;     // LM vocab (owned, null without LM)
        SegDict* seg_dict = nullptr;   // hotword segmentation dict (owned)
        PhoneSet* phone_set_ = nullptr; // phone set for WFST decoding (owned)
        //const float scale = 22.6274169979695;
        // Variance scaling applied when loading CMVN stats; 1.0 = unscaled.
        const float scale = 1.0;
        void LoadConfigFromYaml(const char* filename);
        void LoadCmvn(const char *filename);
        void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
        using TorchModule = torch::jit::script::Module;
        std::shared_ptr<TorchModule> model_ = nullptr;
        std::vector<torch::Tensor> encoder_outs_;
        bool use_hotword;  // set by InitHwCompiler(); embedding compile is a TODO
    public:
        ParaformerTorch();
        ~ParaformerTorch();
        void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
        void InitHwCompiler(const std::string &hw_model, int thread_num);
        void InitSegDict(const std::string &seg_dict_model);
        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
        void Reset();
        void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
        // Batched inference: din/len hold batch_in waveforms; returns one
        // result string per utterance.
        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
        string GreedySearch( float* in, int n_len, int64_t token_nums,
                             bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
        string Rescoring();
        string GetLang(){return language;};
        int GetAsrSampleRate() { return asr_sample_rate; };
        void SetBatchSize(int batch_size) {batch_size_ = batch_size;};
        int GetBatchSize() {return batch_size_;};
        void StartUtterance();
        void EndUtterance();
        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
        string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
        string FinalizeDecode(WfstDecoder* &wfst_decoder,
                          bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
        Vocab* GetVocab();
        Vocab* GetLmVocab();
        PhoneSet* GetPhoneSet();
        knf::FbankOptions fbank_opts_;  // fbank frontend config, filled in InitAsr
        vector<float> means_list_;      // CMVN add-shift values
        vector<float> vars_list_;       // CMVN rescale values (times `scale`)
        int lfr_m = PARA_LFR_M;  // frames stacked per LFR output frame
        int lfr_n = PARA_LFR_N;  // LFR hop in fbank frames
        // paraformer-offline
        std::string language="zh-cn";
        // lm
        std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
        string window_type = "hamming";
        int frame_length = 25;  // fbank window, ms
        int frame_shift = 10;   // fbank hop, ms
        int n_mels = 80;        // mel filterbank bins
        // NOTE(review): encoder_size/fsmn_*/cif_threshold/tail_alphas appear
        // unused by paraformer-torch.cpp — likely carried over from the
        // streaming model's config; confirm before relying on them.
        int encoder_size = 512;
        int fsmn_layers = 16;
        int fsmn_lorder = 10;
        int fsmn_dims = 512;
        float cif_threshold = 1.0;
        float tail_alphas = 0.45;
        int asr_sample_rate = MODEL_SAMPLE_RATE;
        int batch_size_ = 1;  // GPU batch size, set via SetBatchSize()
    };
} // namespace funasr
runtime/onnxruntime/src/paraformer.cpp
@@ -462,15 +462,23 @@
    asr_feats = out_feats;
}
string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
std::vector<std::string> Paraformer::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
    std::vector<std::string> results;
    string result="";
    WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
    if(batch_in != 1){
        results.push_back(result);
        return results;
    }
    std::vector<std::vector<float>> asr_feats;
    FbankKaldi(asr_sample_rate, din, len, asr_feats);
    FbankKaldi(asr_sample_rate, din[0], len[0], asr_feats);
    if(asr_feats.size() == 0){
      return "";
        results.push_back(result);
        return results;
    }
    LfrCmvn(asr_feats);
    int32_t feat_dim = lfr_m*in_feat_dim;
@@ -509,7 +517,8 @@
        if (use_hotword) {
            if(hw_emb.size()<=0){
                LOG(ERROR) << "hw_emb is null";
                return "";
                results.push_back(result);
                return results;
            }
            //PrintMat(hw_emb, "input_clas_emb");
            const int64_t hotword_shape[3] = {1, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())};
@@ -526,10 +535,10 @@
    }catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
        return "";
        results.push_back(result);
        return results;
    }
    string result="";
    try {
        auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
        std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
@@ -577,7 +586,8 @@
        LOG(ERROR)<<e.what();
    }
    return result;
    results.push_back(result);
    return results;
}
runtime/onnxruntime/src/paraformer.h
@@ -52,13 +52,14 @@
        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
        void Reset();
        void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
        string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
        std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
        string GreedySearch( float* in, int n_len, int64_t token_nums,
                             bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
        string Rescoring();
        string GetLang(){return language;};
        int GetAsrSampleRate() { return asr_sample_rate; };
        int GetBatchSize() {return batch_size_;};
        void StartUtterance();
        void EndUtterance();
        void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@@ -110,6 +111,7 @@
        float cif_threshold = 1.0;
        float tail_alphas = 0.45;
        int asr_sample_rate = MODEL_SAMPLE_RATE;
        int batch_size_ = 1;
    };
} // namespace funasr
runtime/onnxruntime/src/precomp.h
@@ -64,6 +64,9 @@
#include "seg_dict.h"
#include "resample.h"
#include "paraformer.h"
#ifdef USE_GPU
#include "paraformer-torch.h"
#endif
#include "paraformer-online.h"
#include "offline-stream.h"
#include "tpass-stream.h"
runtime/onnxruntime/third_party/jieba/include/limonp/StdExtension.hpp
@@ -70,13 +70,13 @@
  return os;
}
#ifndef USE_GPU
template<class T1, class T2>
ostream& operator << (ostream& os, const pair<T1, T2>& pr) {
  os << pr.first << ":" << pr.second ;
  return os;
}
#endif
template<class T>
string& operator << (string& str, const T& obj) {
runtime/run_server.sh
@@ -1,6 +1,10 @@
TORCH_DIR=$(python3 -c "import torch; import os; print(os.path.dirname(torch.__file__))")
BLADE_DIR=$(python3 -c "import torch_blade; import os; print(os.path.dirname(torch_blade.__file__))")
export LD_LIBRARY_PATH=/usr/local/lib:${TORCH_DIR}/lib:${BLADE_DIR}:${LD_LIBRARY_PATH}
download_model_dir="/workspace/models"
model_dir="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx"
model_dir="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx"
vad_dir="damo/speech_fsmn_vad_zh-cn-16k-common-onnx"
punc_dir="damo/punc_ct-transformer_cn-en-common-vocab471067-large-onnx"
itn_dir="thuduj12/fst_itn_zh"
@@ -38,5 +42,6 @@
  --port ${port} \
  --certfile  "${certfile}" \
  --keyfile "${keyfile}" \
  --gpu \
  --hotword "${hotword}" &
runtime/websocket/CMakeLists.txt
@@ -8,6 +8,10 @@
option(ENABLE_WEBSOCKET "Whether to build websocket server" ON)
option(ENABLE_PORTAUDIO "Whether to build portaudio" ON)
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(GPU "Whether to build with GPU" OFF)
if(WIN32)
  file(REMOVE ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/config.h 
@@ -20,12 +24,16 @@
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")
endif()
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
if(GPU)
    add_definitions(-DUSE_GPU)
    set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
    set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
    include_directories(${TORCH_DIR}/include)
    include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
    link_directories(${TORCH_DIR}/lib)
    link_directories(${TORCH_BLADE_DIR})
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
 
if(ENABLE_WEBSOCKET)
  # cmake_policy(SET CMP0135 NEW)
runtime/websocket/bin/CMakeLists.txt
@@ -1,5 +1,4 @@
if(WIN32)
  include_directories(${ONNXRUNTIME_DIR}/include)
  include_directories(${FFMPEG_DIR}/include)
@@ -12,15 +11,14 @@
  SET(RELATION_SOURCE "../../onnxruntime/src/resample.cpp" "../../onnxruntime/src/util.cpp" "../../onnxruntime/src/alignedmem.cpp" "../../onnxruntime/src/encode_converter.cpp")
endif()
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-client "funasr-wss-client.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp" "microphone.cpp" ${RELATION_SOURCE})
target_link_options(funasr-wss-server PRIVATE "-Wl,--no-as-needed")
target_link_options(funasr-wss-server-2pass PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-wss-client PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY} portaudio)
target_link_libraries(funasr-wss-server PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
runtime/websocket/bin/funasr-wss-server.cpp
@@ -56,6 +56,10 @@
        "true (Default), load the model of model_quant.onnx in model_dir. If set "
        "false, load the model of model.onnx in model_dir",
        false, "true", "string");
    TCLAP::ValueArg<std::string> bladedisc(
        "", BLADEDISC,
        "true (Default), load the model of bladedisc in model_dir.",
        false, "true", "string");
    TCLAP::ValueArg<std::string> vad_dir(
        "", VAD_DIR,
        "default: /workspace/models/vad, the vad model path, which contains model_quant.onnx, vad.yaml, vad.mvn",
@@ -121,6 +125,8 @@
        false, "/workspace/resources/hotwords.txt", "string");
    TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, 
        "the fst hotwords incremental bias", false, 20, "int32_t");
    TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU, default is false", false);
    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
    // add file
    cmd.add(hotword);
@@ -135,6 +141,7 @@
    cmd.add(model_dir);
    cmd.add(model_revision);
    cmd.add(quantize);
    cmd.add(bladedisc);
    cmd.add(vad_dir);
    cmd.add(vad_revision);
    cmd.add(vad_quant);
@@ -151,11 +158,14 @@
    cmd.add(io_thread_num);
    cmd.add(decoder_thread_num);
    cmd.add(model_thread_num);
    cmd.add(use_gpu);
    cmd.add(batch_size);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(bladedisc, BLADEDISC, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
@@ -173,6 +183,8 @@
    global_beam_ = global_beam.getValue();
    lattice_beam_ = lattice_beam.getValue();
    am_scale_ = am_scale.getValue();
    bool use_gpu_ = use_gpu.getValue();
    int batch_size_ = batch_size.getValue();
    // Download model form Modelscope
    try{
@@ -468,7 +480,7 @@
    WebSocketServer websocket_srv(
        io_decoder, is_ssl, server, wss_server, s_certfile,
        s_keyfile);  // websocket server for asr engine
    websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
    websocket_srv.initAsr(model_path, s_model_thread_num, use_gpu_, batch_size_);  // init asr model
    LOG(INFO) << "decoder-thread-num: " << s_decoder_thread_num;
    LOG(INFO) << "io-thread-num: " << s_io_thread_num;
runtime/websocket/bin/websocket-server.cpp
@@ -402,11 +402,11 @@
// init asr model
void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
                              int thread_num) {
                              int thread_num, bool use_gpu, int batch_size) {
  try {
    // init model with api
    asr_handle = FunOfflineInit(model_path, thread_num);
    asr_handle = FunOfflineInit(model_path, thread_num, use_gpu, batch_size);
    LOG(INFO) << "model successfully inited";
    
    LOG(INFO) << "initAsr run check_and_clean_connection";
runtime/websocket/bin/websocket-server.h
@@ -124,7 +124,7 @@
                  std::string wav_format,
                  FUNASR_DEC_HANDLE& decoder_handle);
  void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
  void initAsr(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
  void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
  void on_open(websocketpp::connection_hdl hdl);
  void on_close(websocketpp::connection_hdl hdl);