Yabin Li
2023-05-08 cf2ac1ff7b628abe6a7b43d41c3c3e0c1f7f470f
Merge pull request #470 from alibaba-damo-academy/dev_apis

fix wavhead reader; modify punc input to int32; add vad/punc/offline-stream apis; modify option parser
21个文件已修改
8个文件已添加
1025 ■■■■ 已修改文件
funasr/runtime/onnxruntime/include/audio.h 7 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/com-define.h 35 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/libfunasrapi.h 22 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/model.h 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/offline-stream.h 28 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/punc-model.h 18 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/vad-model.h 27 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/readme.md 95 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/CMakeLists.txt 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/audio.cpp 28 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/commonfunc.h 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/ct-transformer.cpp 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/ct-transformer.h 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/fsmn-vad.cpp 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/fsmn-vad.h 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp 98 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp 32 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp 143 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp 56 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/libfunasrapi.cpp 175 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/model.cpp 14 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/offline-stream.cpp 61 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer.cpp 76 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer.h 16 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/precomp.h 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/punc-model.cpp 19 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/tokenizer.cpp 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/tokenizer.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/vad-model.cpp 21 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/audio.h
@@ -1,10 +1,10 @@
#ifndef AUDIO_H
#define AUDIO_H
#include <queue>
#include <stdint.h>
#include "model.h"
#include "vad-model.h"
#include "offline-stream.h"
#ifndef WAV_HEADER_SIZE
#define WAV_HEADER_SIZE 44
@@ -54,7 +54,8 @@
    int FetchChunck(float *&dout, int len);
    int Fetch(float *&dout, int &len, int &flag);
    void Padding();
    void Split(Model* recog_obj);
    void Split(OfflineStream* offline_streamj);
    void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments);
    float GetTimeLen();
    int GetQueueSize() { return (int)frame_queue.size(); }
};
funasr/runtime/onnxruntime/include/com-define.h
@@ -12,20 +12,37 @@
#define MODEL_SAMPLE_RATE 16000
#endif
// model path
#define VAD_MODEL_PATH "vad-model"
#define VAD_CMVN_PATH "vad-cmvn"
#define VAD_CONFIG_PATH "vad-config"
#define AM_MODEL_PATH "am-model"
#define AM_CMVN_PATH "am-cmvn"
#define AM_CONFIG_PATH "am-config"
#define PUNC_MODEL_PATH "punc-model"
#define PUNC_CONFIG_PATH "punc-config"
// parser option
#define MODEL_DIR "model-dir"
#define VAD_DIR "vad-dir"
#define PUNC_DIR "punc-dir"
#define QUANTIZE "quantize"
#define VAD_QUANT "vad-quant"
#define PUNC_QUANT "punc-quant"
#define WAV_PATH "wav-path"
#define WAV_SCP "wav-scp"
#define TXT_PATH "txt-path"
#define THREAD_NUM "thread-num"
#define PORT_ID "port-id"
// #define VAD_MODEL_PATH "vad-model"
// #define VAD_CMVN_PATH "vad-cmvn"
// #define VAD_CONFIG_PATH "vad-config"
// #define AM_MODEL_PATH "am-model"
// #define AM_CMVN_PATH "am-cmvn"
// #define AM_CONFIG_PATH "am-config"
// #define PUNC_MODEL_PATH "punc-model"
// #define PUNC_CONFIG_PATH "punc-config"
#define MODEL_NAME "model.onnx"
#define QUANT_MODEL_NAME "model_quant.onnx"
#define VAD_CMVN_NAME "vad.mvn"
#define VAD_CONFIG_NAME "vad.yaml"
#define AM_CMVN_NAME "am.mvn"
#define AM_CONFIG_NAME "config.yaml"
#define PUNC_CONFIG_NAME "punc.yaml"
// vad
#ifndef VAD_SILENCE_DURATION
#define VAD_SILENCE_DURATION 800
funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -1,5 +1,6 @@
#pragma once
#include <map>
#include <vector>
#ifdef WIN32
#ifdef _FUNASR_API_EXPORT
@@ -47,7 +48,7 @@
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
    
// // ASR
// ASR
_FUNASRAPI FUNASR_HANDLE  FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_RESULT    FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
@@ -64,10 +65,21 @@
// VAD
_FUNASRAPI FUNASR_HANDLE  FunVadInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_RESULT    FunASRVadBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI FUNASR_RESULT    FunASRVadPCMBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI FUNASR_RESULT    FunASRVadPCMFile(FUNASR_HANDLE handle, const char* sz_filename, int sampling_rate, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI FUNASR_RESULT    FunASRVadFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI FUNASR_RESULT    FunVadWavFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI std::vector<std::vector<int>>*    FunVadGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI void                 FunVadFreeResult(FUNASR_RESULT result);
_FUNASRAPI void                FunVadUninit(FUNASR_HANDLE handle);
_FUNASRAPI const float        FunVadGetRetSnippetTime(FUNASR_RESULT result);
// PUNC
_FUNASRAPI FUNASR_HANDLE          FunPuncInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI const std::string    FunPuncInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI void                    FunPuncUninit(FUNASR_HANDLE handle);
//OfflineStream
_FUNASRAPI FUNASR_HANDLE      FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_RESULT     FunOfflineStream(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback);
_FUNASRAPI void                FunOfflineUninit(FUNASR_HANDLE handle);
#ifdef __cplusplus 
funasr/runtime/onnxruntime/include/model.h
@@ -9,13 +9,10 @@
  public:
    virtual ~Model(){};
    virtual void Reset() = 0;
    virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num)=0;
    virtual std::string ForwardChunk(float *din, int len, int flag) = 0;
    virtual std::string Forward(float *din, int len, int flag) = 0;
    virtual std::string Rescoring() = 0;
    virtual std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data)=0;
    virtual std::string AddPunc(const char* sz_input)=0;
    virtual bool UseVad() =0;
    virtual bool UsePunc() =0;
};
Model *CreateModel(std::map<std::string, std::string>& model_path,int thread_num=1);
funasr/runtime/onnxruntime/include/offline-stream.h
New file
@@ -0,0 +1,28 @@
#ifndef OFFLINE_STREAM_H
#define OFFLINE_STREAM_H
#include <memory>
#include <string>
#include <map>
#include "model.h"
#include "punc-model.h"
#include "vad-model.h"
class OfflineStream {
  public:
    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
    ~OfflineStream(){};
    std::unique_ptr<VadModel> vad_handle;
    std::unique_ptr<Model> asr_handle;
    std::unique_ptr<PuncModel> punc_handle;
    bool UseVad(){return use_vad;};
    bool UsePunc(){return use_punc;};
  private:
    bool use_vad=false;
    bool use_punc=false;
};
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1);
#endif
funasr/runtime/onnxruntime/include/punc-model.h
New file
@@ -0,0 +1,18 @@
#ifndef PUNC_MODEL_H
#define PUNC_MODEL_H
#include <cstdint>
#include <map>
#include <string>
#include <vector>
/**
 * Abstract interface for punctuation models. CTTransformer implements it;
 * CreatePuncModel() (declared below) returns an instance of this interface.
 */
class PuncModel {
  public:
    virtual ~PuncModel() = default;

    // Configures the model from its model/config paths; thread_num is forwarded
    // to the implementation.
    virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num) = 0;

    // Runs the network on a sequence of int32 token ids. Presumably returns one
    // punctuation label id per input token — confirm against the implementation.
    virtual std::vector<int> Infer(std::vector<std::int32_t> input_data) = 0;

    // Takes a NUL-terminated input sentence and returns the text produced by the
    // model (presumably the sentence with punctuation inserted — confirm).
    virtual std::string AddPunc(const char* sz_input) = 0;
};
PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num);
#endif
funasr/runtime/onnxruntime/include/vad-model.h
New file
@@ -0,0 +1,27 @@
#ifndef VAD_MODEL_H
#define VAD_MODEL_H
#include <string>
#include <map>
#include <vector>
/**
 * Abstract interface for voice-activity-detection models. FsmnVad derives
 * from it; CreateVadModel() (declared below) returns an instance of it.
 * Declaration order of the virtual functions is intentionally unchanged.
 */
class VadModel {
public:
    virtual ~VadModel() = default;

    // Configures the model from its model / CMVN / config paths. Implementations
    // such as FsmnVad use thread_num for the ONNX Runtime intra-op thread count.
    virtual void InitVad(const std::string &vad_model,
                         const std::string &vad_cmvn,
                         const std::string &vad_config,
                         int thread_num) = 0;

    // Runs detection on `waves` and returns the detected segments.
    virtual std::vector<std::vector<int>> Infer(const std::vector<float> &waves) = 0;

    // Lower-level stages; see the concrete implementation for details.
    virtual void ReadModel(const char *vad_model) = 0;
    virtual void LoadConfigFromYaml(const char *filename) = 0;
    virtual void FbankKaldi(float sample_rate,
                            std::vector<std::vector<float>> &vad_feats,
                            const std::vector<float> &waves) = 0;
    virtual std::vector<std::vector<float>> &LfrCmvn(std::vector<std::vector<float>> &vad_feats) = 0;
    virtual void Forward(const std::vector<std::vector<float>> &chunk_feats,
                         std::vector<std::vector<float>> *out_prob) = 0;
    virtual void LoadCmvn(const char *filename) = 0;
    virtual void InitCache() = 0;
};
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
#endif
funasr/runtime/onnxruntime/readme.md
@@ -41,41 +41,88 @@
```
## Run the demo
### funasr-onnx-offline
```shell
./funasr-onnx-offline     [--wav-scp <string>] [--wav-path <string>]
                          [--punc-config <string>] [--punc-model <string>]
                          --am-config <string> --am-cmvn <string>
                          --am-model <string> [--vad-config <string>]
                          [--vad-cmvn <string>] [--vad-model <string>] [--]
                          [--version] [-h]
                          [--punc-quant <string>] [--punc-dir <string>]
                          [--vad-quant <string>] [--vad-dir <string>]
                          [--quantize <string>] --model-dir <string>
                          [--] [--version] [-h]
Where:
   --model-dir <string>
     (required)  the asr model path, which contains model.onnx, config.yaml, am.mvn
   --quantize <string>
     false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
   --vad-dir <string>
     the vad model path, which contains model.onnx, vad.yaml, vad.mvn
   --vad-quant <string>
     false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
   --punc-dir <string>
     the punc model path, which contains model.onnx, punc.yaml
   --punc-quant <string>
     false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
   --wav-scp <string>
     wave scp path
   --wav-path <string>
     wave file path
   --punc-config <string>
     punc config path
   --punc-model <string>
     punc model path
   Required: --model-dir <string>
   To enable VAD, also pass: --vad-dir <string>
   To enable punctuation restoration, also pass: --punc-dir <string>
   --am-config <string>
     (required)  am config path
   --am-cmvn <string>
     (required)  am cmvn path
   --am-model <string>
     (required)  am model path
For example:
./funasr-onnx-offline \
    --model-dir    ./asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch \
    --quantize  true \
    --vad-dir   ./asrmodel/speech_fsmn_vad_zh-cn-16k-common-pytorch \
    --punc-dir  ./asrmodel/punc_ct-transformer_zh-cn-common-vocab272727-pytorch \
    --wav-path    ./vad_example.wav
```
   --vad-config <string>
     vad config path
   --vad-cmvn <string>
     vad cmvn path
   --vad-model <string>
     vad model path
### funasr-onnx-offline-vad
```shell
./funasr-onnx-offline-vad     [--wav-scp <string>] [--wav-path <string>]
                              [--quantize <string>] --model-dir <string>
                              [--] [--version] [-h]
Where:
   --model-dir <string>
     (required)  the vad model path, which contains model.onnx, vad.yaml, vad.mvn
   --quantize <string>
     false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
   --wav-scp <string>
     wave scp path
   --wav-path <string>
     wave file path
  
   Required: --am-config <string> --am-cmvn <string> --am-model <string>
   If use vad, please add: [--vad-config <string>] [--vad-cmvn <string>] [--vad-model <string>]
   If use punc, please add: [--punc-config <string>] [--punc-model <string>]
   Required: --model-dir <string>
For example:
./funasr-onnx-offline-vad \
    --model-dir   ./asrmodel/speech_fsmn_vad_zh-cn-16k-common-pytorch \
    --wav-path    ./vad_example.wav
```
### funasr-onnx-offline-punc
```shell
./funasr-onnx-offline-punc    [--txt-path <string>] [--quantize <string>]
                               --model-dir <string> [--] [--version] [-h]
Where:
   --model-dir <string>
     (required)  the punc model path, which contains model.onnx, punc.yaml
   --quantize <string>
     false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
   --txt-path <string>
     txt file path, one sentence per line
   Required: --model-dir <string>
For example:
./funasr-onnx-offline-punc \
    --model-dir  ./asrmodel/punc_ct-transformer_zh-cn-common-vocab272727-pytorch \
    --txt-path   ./punc_example.txt
```
## Acknowledgements
funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -26,6 +26,11 @@
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
target_link_libraries(funasr-onnx-offline PUBLIC funasr)
target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
funasr/runtime/onnxruntime/src/audio.cpp
@@ -247,6 +247,15 @@
        return false;
    }
    
    if (!header.Validate()) {
        return false;
    }
    header.SeekToDataChunk(is);
    if (!is) {
        return false;
    }
    *sampling_rate = header.sample_rate;
    // header.subchunk2_size contains the number of bytes in the data.
    // As we assume each sample contains two bytes, so it is divided by 2 here
@@ -505,7 +514,7 @@
    delete frame;
}
void Audio::Split(Model* recog_obj)
void Audio::Split(OfflineStream* offline_stream)
{
    AudioFrame *frame;
@@ -516,7 +525,7 @@
    frame = NULL;
    std::vector<float> pcm_data(speech_data, speech_data+sp_len);
    vector<std::vector<int>> vad_segments = recog_obj->VadSeg(pcm_data);
    vector<std::vector<int>> vad_segments = (offline_stream->vad_handle)->Infer(pcm_data);
    int seg_sample = MODEL_SAMPLE_RATE/1000;
    for(vector<int> segment:vad_segments)
    {
@@ -529,3 +538,18 @@
        frame = NULL;
    }
}
void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments)
{
    AudioFrame *frame;
    frame = frame_queue.front();
    frame_queue.pop();
    int sp_len = frame->GetLen();
    delete frame;
    frame = NULL;
    std::vector<float> pcm_data(speech_data, speech_data+sp_len);
    vad_segments = vad_obj->Infer(pcm_data);
}
funasr/runtime/onnxruntime/src/commonfunc.h
@@ -6,6 +6,12 @@
    float  snippet_time;
}FUNASR_RECOG_RESULT;
// Result record for the VAD API. Its fields appear to be read back through
// FunVadGetResult / FunVadGetRetSnippetTime and released by FunVadFreeResult —
// confirm against libfunasrapi.cpp.
typedef struct
{
    // Detected speech segments. Presumably each inner vector is a
    // {start, end} pair; units are not visible here — TODO confirm.
    // Ownership is not visible here; presumably this pointer is owned by the
    // result and freed in FunVadFreeResult — verify.
    std::vector<std::vector<int>>* segments;
    // Length of the processed audio; the vad demo sums it and logs it as
    // "Audio length: ... s", so presumably seconds — confirm.
    float  snippet_time;
}FUNASR_VAD_RESULT;
#ifdef _WIN32
#include <codecvt>
funasr/runtime/onnxruntime/src/ct-transformer.cpp
@@ -54,7 +54,7 @@
    int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
    int nCurBatch = -1;
    int nSentEnd = -1, nLastCommaIndex = -1;
    vector<int64_t> RemainIDs; //
    vector<int32_t> RemainIDs; //
    vector<string> RemainStr; //
    vector<int> NewPunctuation; //
    vector<string> NewString; //
@@ -64,7 +64,7 @@
    for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
    {
        nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
        vector<int64_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
        vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
        vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
        InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs;
        InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr;
@@ -141,12 +141,13 @@
    return strResult;
}
vector<int> CTTransformer::Infer(vector<int64_t> input_data)
vector<int> CTTransformer::Infer(vector<int32_t> input_data)
{
    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    vector<int> punction;
    std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
    Ort::Value onnx_input = Ort::Value::CreateTensor<int64_t>(m_memoryInfo,
    Ort::Value onnx_input = Ort::Value::CreateTensor<int32_t>(
        m_memoryInfo,
        input_data.data(),
        input_data.size(),
        input_shape_.data(),
funasr/runtime/onnxruntime/src/ct-transformer.h
@@ -5,7 +5,7 @@
#pragma once 
class CTTransformer {
class CTTransformer : public PuncModel {
/**
 * Author: Speech Lab of DAMO Academy, Alibaba Group
 * CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -27,6 +27,6 @@
    CTTransformer();
    void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
    ~CTTransformer();
    vector<int>  Infer(vector<int64_t> input_data);
    vector<int>  Infer(vector<int32_t> input_data);
    string AddPunc(const char* sz_input);
};
funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -6,8 +6,8 @@
#include <fstream>
#include "precomp.h"
void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config) {
    session_options_.SetIntraOpNumThreads(1);
void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num) {
    session_options_.SetIntraOpNumThreads(thread_num);
    session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    session_options_.DisableCpuMemArena();
@@ -296,5 +296,8 @@
void FsmnVad::Test() {
}
FsmnVad::~FsmnVad() {
}
FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} {
}
funasr/runtime/onnxruntime/src/fsmn-vad.h
@@ -8,7 +8,7 @@
#include "precomp.h"
class FsmnVad {
class FsmnVad : public VadModel {
/**
 * Author: Speech Lab of DAMO Academy, Alibaba Group
 * Deep-FSMN for Large Vocabulary Continuous Speech Recognition
@@ -17,9 +17,9 @@
public:
    FsmnVad();
    ~FsmnVad();
    void Test();
    void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config);
    void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num);
    std::vector<std::vector<int>> Infer(const std::vector<float> &waves);
    void Reset();
funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp
New file
@@ -0,0 +1,98 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#ifndef _WIN32
#include <sys/time.h>
#else
#include <win_func.h>
#endif
#include <iostream>
#include <fstream>
#include <sstream>
#include <map>
#include <glog/logging.h>
#include "libfunasrapi.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
using namespace std;
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    if (value_arg.isSet()){
        model_path.insert({key, value_arg.getValue()});
        LOG(INFO)<< key << " : " << value_arg.getValue();
    }
}
int main(int argc, char *argv[])
{
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;
    TCLAP::CmdLine cmd("funasr-onnx-offline-punc", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", false, "", "string");
    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(txt_path);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(txt_path, TXT_PATH, model_path);
    struct timeval start, end;
    gettimeofday(&start, NULL);
    int thread_num = 1;
    FUNASR_HANDLE punc_hanlde=FunPuncInit(model_path, thread_num);
    if (!punc_hanlde)
    {
        LOG(ERROR) << "FunASR init failed";
        exit(-1);
    }
    gettimeofday(&end, NULL);
    long seconds = (end.tv_sec - start.tv_sec);
    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
    // read txt_path
    vector<string> txt_list;
    if(model_path.find(TXT_PATH)!=model_path.end()){
        ifstream in(model_path.at(TXT_PATH));
        if (!in.is_open()) {
            LOG(ERROR) << "Failed to open file: " << model_path.at(TXT_PATH) ;
            return 0;
        }
        string line;
        while(getline(in, line))
        {
            txt_list.emplace_back(line);
        }
        in.close();
    }
    long taking_micros = 0;
    for(auto& txt_str : txt_list){
        gettimeofday(&start, NULL);
        string result=FunPuncInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL);
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
        LOG(INFO)<<"Results: "<<result;
    }
    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
    FunPuncUninit(punc_hanlde);
    return 0;
}
funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
@@ -91,41 +91,21 @@
    FLAGS_logtostderr = true;
    TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
    TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
    TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
    TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
    TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", false, "", "string");
    TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", false, "", "string");
    TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", false, "", "string");
    TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
    TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", true, "", "string");
    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
    cmd.add(vad_model);
    cmd.add(vad_cmvn);
    cmd.add(vad_config);
    cmd.add(am_model);
    cmd.add(am_cmvn);
    cmd.add(am_config);
    cmd.add(punc_model);
    cmd.add(punc_config);
    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(wav_scp);
    cmd.add(thread_num);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(vad_model, VAD_MODEL_PATH, model_path);
    GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
    GetValue(vad_config, VAD_CONFIG_PATH, model_path);
    GetValue(am_model, AM_MODEL_PATH, model_path);
    GetValue(am_cmvn, AM_CMVN_PATH, model_path);
    GetValue(am_config, AM_CONFIG_PATH, model_path);
    GetValue(punc_model, PUNC_MODEL_PATH, model_path);
    GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(wav_scp, WAV_SCP, model_path);
    struct timeval start, end;
funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
New file
@@ -0,0 +1,143 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#ifndef _WIN32
#include <sys/time.h>
#else
#include <win_func.h>
#endif
#include <iostream>
#include <fstream>
#include <sstream>
#include <map>
#include <vector>
#include <glog/logging.h>
#include "libfunasrapi.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
using namespace std;
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    if (value_arg.isSet()){
        model_path.insert({key, value_arg.getValue()});
        LOG(INFO)<< key << " : " << value_arg.getValue();
    }
}
void print_segs(vector<vector<int>>* vec) {
    string seg_out="[";
    for (int i = 0; i < vec->size(); i++) {
        vector<int> inner_vec = (*vec)[i];
        seg_out += "[";
        for (int j = 0; j < inner_vec.size(); j++) {
            seg_out += to_string(inner_vec[j]);
            if (j != inner_vec.size() - 1) {
                seg_out += ",";
            }
        }
        seg_out += "]";
        if (i != vec->size() - 1) {
            seg_out += ",";
        }
    }
    seg_out += "]";
    LOG(INFO)<<seg_out;
}
int main(int argc, char *argv[])
{
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;
    TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "wave file path", false, "", "string");
    TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", false, "", "string");
    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(wav_path);
    cmd.add(wav_scp);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(wav_path, WAV_PATH, model_path);
    GetValue(wav_scp, WAV_SCP, model_path);
    struct timeval start, end;
    gettimeofday(&start, NULL);
    int thread_num = 1;
    FUNASR_HANDLE vad_hanlde=FunVadInit(model_path, thread_num);
    if (!vad_hanlde)
    {
        LOG(ERROR) << "FunVad init failed";
        exit(-1);
    }
    gettimeofday(&end, NULL);
    long seconds = (end.tv_sec - start.tv_sec);
    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
    // read wav_path and wav_scp
    vector<string> wav_list;
    if(model_path.find(WAV_PATH)!=model_path.end()){
        wav_list.emplace_back(model_path.at(WAV_PATH));
    }
    if(model_path.find(WAV_SCP)!=model_path.end()){
        ifstream in(model_path.at(WAV_SCP));
        if (!in.is_open()) {
            LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
            return 0;
        }
        string line;
        while(getline(in, line))
        {
            istringstream iss(line);
            string column1, column2;
            iss >> column1 >> column2;
            wav_list.emplace_back(column2);
        }
        in.close();
    }
    float snippet_time = 0.0f;
    long taking_micros = 0;
    for(auto& wav_file : wav_list){
        gettimeofday(&start, NULL);
        FUNASR_RESULT result=FunVadWavFile(vad_hanlde, wav_file.c_str(), RASR_NONE, NULL);
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
        if (result)
        {
            vector<std::vector<int>>* vad_segments = FunVadGetResult(result, 0);
            print_segs(vad_segments);
            snippet_time += FunVadGetRetSnippetTime(result);
            FunVadFreeResult(result);
        }
        else
        {
            LOG(ERROR) << ("No return data!\n");
        }
    }
    LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
    LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
    FunVadUninit(vad_hanlde);
    return 0;
}
funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
@@ -28,55 +28,46 @@
    }
}
int main(int argc, char *argv[])
int main(int argc, char** argv)
{
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;
    TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
    TCLAP::ValueArg<std::string> vad_model("", VAD_MODEL_PATH, "vad model path", false, "", "string");
    TCLAP::ValueArg<std::string> vad_cmvn("", VAD_CMVN_PATH, "vad cmvn path", false, "", "string");
    TCLAP::ValueArg<std::string> vad_config("", VAD_CONFIG_PATH, "vad config path", false, "", "string");
    TCLAP::ValueArg<std::string> am_model("", AM_MODEL_PATH, "am model path", true, "", "string");
    TCLAP::ValueArg<std::string> am_cmvn("", AM_CMVN_PATH, "am cmvn path", true, "", "string");
    TCLAP::ValueArg<std::string> am_config("", AM_CONFIG_PATH, "am config path", true, "", "string");
    TCLAP::ValueArg<std::string> punc_model("", PUNC_MODEL_PATH, "punc model path", false, "", "string");
    TCLAP::ValueArg<std::string> punc_config("", PUNC_CONFIG_PATH, "punc config path", false, "", "string");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "wave file path", false, "", "string");
    TCLAP::ValueArg<std::string> wav_scp("", WAV_SCP, "wave scp path", false, "", "string");
    cmd.add(vad_model);
    cmd.add(vad_cmvn);
    cmd.add(vad_config);
    cmd.add(am_model);
    cmd.add(am_cmvn);
    cmd.add(am_config);
    cmd.add(punc_model);
    cmd.add(punc_config);
    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(vad_dir);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
    cmd.add(punc_quant);
    cmd.add(wav_path);
    cmd.add(wav_scp);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
    GetValue(vad_model, VAD_MODEL_PATH, model_path);
    GetValue(vad_cmvn, VAD_CMVN_PATH, model_path);
    GetValue(vad_config, VAD_CONFIG_PATH, model_path);
    GetValue(am_model, AM_MODEL_PATH, model_path);
    GetValue(am_cmvn, AM_CMVN_PATH, model_path);
    GetValue(am_config, AM_CONFIG_PATH, model_path);
    GetValue(punc_model, PUNC_MODEL_PATH, model_path);
    GetValue(punc_config, PUNC_CONFIG_PATH, model_path);
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
    GetValue(punc_quant, PUNC_QUANT, model_path);
    GetValue(wav_path, WAV_PATH, model_path);
    GetValue(wav_scp, WAV_SCP, model_path);
    struct timeval start, end;
    gettimeofday(&start, NULL);
    int thread_num = 1;
    FUNASR_HANDLE asr_hanlde=FunASRInit(model_path, thread_num);
    FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num);
    if (!asr_hanlde)
    {
@@ -116,7 +107,7 @@
    long taking_micros = 0;
    for(auto& wav_file : wav_list){
        gettimeofday(&start, NULL);
        FUNASR_RESULT result=FunASRRecogFile(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL);
        FUNASR_RESULT result=FunOfflineStream(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL);
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@@ -124,8 +115,7 @@
        if (result)
        {
            string msg = FunASRGetResult(result, 0);
            setbuf(stdout, NULL);
            printf("Result: %s \n", msg.c_str());
            LOG(INFO)<<"Result: "<<msg;
            snippet_time += FunASRGetRetSnippetTime(result);
            FunASRFreeResult(result);
        }
@@ -138,7 +128,7 @@
    LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
    LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
    FunASRUninit(asr_hanlde);
    FunOfflineUninit(asr_hanlde);
    return 0;
}
funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -4,7 +4,7 @@
extern "C" {
#endif
    // APIs for funasr
    // APIs for Init
    _FUNASRAPI FUNASR_HANDLE  FunASRInit(std::map<std::string, std::string>& model_path, int thread_num)
    {
        Model* mm = CreateModel(model_path, thread_num);
@@ -13,10 +13,23 @@
    _FUNASRAPI FUNASR_HANDLE  FunVadInit(std::map<std::string, std::string>& model_path, int thread_num)
    {
        Model* mm = CreateModel(model_path, thread_num);
        VadModel* mm = CreateVadModel(model_path, thread_num);
        return mm;
    }
    _FUNASRAPI FUNASR_HANDLE  FunPuncInit(std::map<std::string, std::string>& model_path, int thread_num)
    {
        PuncModel* mm = CreatePuncModel(model_path, thread_num);
        return mm;
    }
    _FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
    {
        OfflineStream* mm = CreateOfflineStream(model_path, thread_num);
        return mm;
    }
    // APIs for ASR Infer
    _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback)
    {
        Model* recog_obj = (Model*)handle;
@@ -27,9 +40,6 @@
        Audio audio(1);
        if (!audio.LoadWav(sz_buf, n_len, &sampling_rate))
            return nullptr;
        if(recog_obj->UseVad()){
            audio.Split(recog_obj);
        }
        float* buff;
        int len;
@@ -44,10 +54,6 @@
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        if(recog_obj->UsePunc()){
            string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
            p_result->msg = punc_res;
        }
        return p_result;
@@ -62,9 +68,6 @@
        Audio audio(1);
        if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
            return nullptr;
        if(recog_obj->UseVad()){
            audio.Split(recog_obj);
        }
        float* buff;
        int len;
@@ -79,10 +82,6 @@
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        if(recog_obj->UsePunc()){
            string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
            p_result->msg = punc_res;
        }
        return p_result;
@@ -97,9 +96,6 @@
        Audio audio(1);
        if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
            return nullptr;
        if(recog_obj->UseVad()){
            audio.Split(recog_obj);
        }
        float* buff;
        int len;
@@ -114,10 +110,6 @@
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        if(recog_obj->UsePunc()){
            string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
            p_result->msg = punc_res;
        }
        return p_result;
@@ -133,9 +125,6 @@
        Audio audio(1);
        if(!audio.LoadWav(sz_wavfile, &sampling_rate))
            return nullptr;
        if(recog_obj->UseVad()){
            audio.Split(recog_obj);
        }
        float* buff;
        int len;
@@ -151,8 +140,74 @@
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        if(recog_obj->UsePunc()){
            string punc_res = recog_obj->AddPunc((p_result->msg).c_str());
        return p_result;
    }
    // APIs for VAD Infer
    _FUNASRAPI FUNASR_RESULT FunVadWavFile(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback)
    {
        VadModel* vad_obj = (VadModel*)handle;
        if (!vad_obj)
            return nullptr;
        int32_t sampling_rate = -1;
        Audio audio(1);
        if(!audio.LoadWav(sz_wavfile, &sampling_rate))
            return nullptr;
        FUNASR_VAD_RESULT* p_result = new FUNASR_VAD_RESULT;
        p_result->snippet_time = audio.GetTimeLen();
        vector<std::vector<int>> vad_segments;
        audio.Split(vad_obj, vad_segments);
        p_result->segments = new vector<std::vector<int>>(vad_segments);
        return p_result;
    }
    // APIs for PUNC Infer
    _FUNASRAPI const std::string FunPuncInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback)
    {
        PuncModel* punc_obj = (PuncModel*)handle;
        if (!punc_obj)
            return nullptr;
        string punc_res = punc_obj->AddPunc(sz_sentence);
        return punc_res;
    }
    // APIs for Offline-stream Infer
    _FUNASRAPI FUNASR_RESULT FunOfflineStream(FUNASR_HANDLE handle, const char* sz_wavfile, FUNASR_MODE mode, QM_CALLBACK fn_callback)
    {
        OfflineStream* offline_stream = (OfflineStream*)handle;
        if (!offline_stream)
            return nullptr;
        int32_t sampling_rate = -1;
        Audio audio(1);
        if(!audio.LoadWav(sz_wavfile, &sampling_rate))
            return nullptr;
        if(offline_stream->UseVad()){
            audio.Split(offline_stream);
        }
        float* buff;
        int len;
        int flag = 0;
        int n_step = 0;
        int n_total = audio.GetQueueSize();
        FUNASR_RECOG_RESULT* p_result = new FUNASR_RECOG_RESULT;
        p_result->snippet_time = audio.GetTimeLen();
        while (audio.Fetch(buff, len, flag) > 0) {
            string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
            p_result->msg+= msg;
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
        }
        if(offline_stream->UsePunc()){
            string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
            p_result->msg = punc_res;
        }
    
@@ -167,7 +222,7 @@
        return 1;
    }
    // APIs for GetRetSnippetTime
    _FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result)
    {
        if (!result)
@@ -176,6 +231,15 @@
        return ((FUNASR_RECOG_RESULT*)result)->snippet_time;
    }
    _FUNASRAPI const float FunVadGetRetSnippetTime(FUNASR_RESULT result)
    {
        if (!result)
            return 0.0f;
        return ((FUNASR_VAD_RESULT*)result)->snippet_time;
    }
    // APIs for GetResult
    _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index)
    {
        FUNASR_RECOG_RESULT * p_result = (FUNASR_RECOG_RESULT*)result;
@@ -185,6 +249,16 @@
        return p_result->msg.c_str();
    }
    _FUNASRAPI vector<std::vector<int>>* FunVadGetResult(FUNASR_RESULT result,int n_index)
    {
        FUNASR_VAD_RESULT * p_result = (FUNASR_VAD_RESULT*)result;
        if(!p_result)
            return nullptr;
        return p_result->segments;
    }
    // APIs for FreeResult
    _FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result)
    {
        if (result)
@@ -193,6 +267,19 @@
        }
    }
    _FUNASRAPI void FunVadFreeResult(FUNASR_RESULT result)
    {
        FUNASR_VAD_RESULT * p_result = (FUNASR_VAD_RESULT*)result;
        if (p_result)
        {
            if(p_result->segments){
                delete p_result->segments;
            }
            delete p_result;
        }
    }
    // APIs for Uninit
    _FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
    {
        Model* recog_obj = (Model*)handle;
@@ -203,6 +290,36 @@
        delete recog_obj;
    }
    _FUNASRAPI void FunVadUninit(FUNASR_HANDLE handle)
    {
        VadModel* recog_obj = (VadModel*)handle;
        if (!recog_obj)
            return;
        delete recog_obj;
    }
    _FUNASRAPI void FunPuncUninit(FUNASR_HANDLE handle)
    {
        PuncModel* punc_obj = (PuncModel*)handle;
        if (!punc_obj)
            return;
        delete punc_obj;
    }
    _FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle)
    {
        OfflineStream* offline_stream = (OfflineStream*)handle;
        if (!offline_stream)
            return;
        delete offline_stream;
    }
#ifdef __cplusplus 
}
funasr/runtime/onnxruntime/src/model.cpp
@@ -2,7 +2,19 @@
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num)
{
    string am_model_path;
    string am_cmvn_path;
    string am_config_path;
    am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
    if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
        am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
    }
    am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
    am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
    Model *mm;
    mm = new paraformer::Paraformer(model_path, thread_num);
    mm = new paraformer::Paraformer();
    mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
    return mm;
}
funasr/runtime/onnxruntime/src/offline-stream.cpp
New file
@@ -0,0 +1,61 @@
#include "precomp.h"
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
{
    // VAD model
    if(model_path.find(VAD_DIR) != model_path.end()){
        use_vad = true;
        string vad_model_path;
        string vad_cmvn_path;
        string vad_config_path;
        vad_model_path = PathAppend(model_path.at(VAD_DIR), MODEL_NAME);
        if(model_path.find(VAD_QUANT) != model_path.end() && model_path.at(VAD_QUANT) == "true"){
            vad_model_path = PathAppend(model_path.at(VAD_DIR), QUANT_MODEL_NAME);
        }
        vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
        vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
        vad_handle = make_unique<FsmnVad>();
        vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
    }
    // AM model
    if(model_path.find(MODEL_DIR) != model_path.end()){
        string am_model_path;
        string am_cmvn_path;
        string am_config_path;
        am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
        if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
            am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
        }
        am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
        am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
        asr_handle = make_unique<Paraformer>();
        asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
    }
    // PUNC model
    if(model_path.find(PUNC_DIR) != model_path.end()){
        use_punc = true;
        string punc_model_path;
        string punc_config_path;
        punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
        if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
            punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
        }
        punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
        punc_handle = make_unique<CTTransformer>();
        punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
    }
}
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
{
    OfflineStream *mm;
    mm = new OfflineStream(model_path, thread_num);
    return mm;
}
funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -8,65 +8,11 @@
using namespace std;
using namespace paraformer;
Paraformer::Paraformer(std::map<std::string, std::string>& model_path,int thread_num)
Paraformer::Paraformer()
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
    // VAD model
    if(model_path.find(VAD_MODEL_PATH) != model_path.end()){
        use_vad = true;
        string vad_model_path;
        string vad_cmvn_path;
        string vad_config_path;
        try{
            vad_model_path = model_path.at(VAD_MODEL_PATH);
            vad_cmvn_path = model_path.at(VAD_CMVN_PATH);
            vad_config_path = model_path.at(VAD_CONFIG_PATH);
        }catch(const out_of_range& e){
            LOG(ERROR) << "Error when read "<< VAD_CMVN_PATH << " or " << VAD_CONFIG_PATH <<" :" << e.what();
            exit(0);
        }
        vad_handle = make_unique<FsmnVad>();
        vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path);
    }
    // AM model
    if(model_path.find(AM_MODEL_PATH) != model_path.end()){
        string am_model_path;
        string am_cmvn_path;
        string am_config_path;
        try{
            am_model_path = model_path.at(AM_MODEL_PATH);
            am_cmvn_path = model_path.at(AM_CMVN_PATH);
            am_config_path = model_path.at(AM_CONFIG_PATH);
        }catch(const out_of_range& e){
            LOG(ERROR) << "Error when read "<< AM_CONFIG_PATH << " or " << AM_CMVN_PATH <<" :" << e.what();
            exit(0);
        }
        InitAM(am_model_path, am_cmvn_path, am_config_path, thread_num);
    }
    // PUNC model
    if(model_path.find(PUNC_MODEL_PATH) != model_path.end()){
        use_punc = true;
        string punc_model_path;
        string punc_config_path;
        try{
            punc_model_path = model_path.at(PUNC_MODEL_PATH);
            punc_config_path = model_path.at(PUNC_CONFIG_PATH);
        }catch(const out_of_range& e){
            LOG(ERROR) << "Error when read "<< PUNC_CONFIG_PATH <<" :" << e.what();
            exit(0);
        }
        punc_handle = make_unique<CTTransformer>();
        punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
    }
}
void Paraformer::InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
    // knf options
    fbank_opts.frame_opts.dither = 0;
    fbank_opts.mel_opts.num_bins = 80;
@@ -118,14 +64,6 @@
// Currently a no-op: the body is empty.
void Paraformer::Reset()
{
}
// Segment `pcm_data` by delegating to the FSMN VAD model held in vad_handle and return its result.
// NOTE(review): vad_handle is only created when a VAD model was configured; calling this without
// one dereferences an empty unique_ptr.
vector<std::vector<int>> Paraformer::VadSeg(std::vector<float>& pcm_data){
    return vad_handle->Infer(pcm_data);
}
// Punctuate `sz_input` by delegating to the CT-Transformer model held in punc_handle.
// NOTE(review): punc_handle is only created when a punctuation model was configured.
string Paraformer::AddPunc(const char* sz_input){
    return punc_handle->AddPunc(sz_input);
}
vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
@@ -282,7 +220,7 @@
    }
    catch (std::exception const &e)
    {
        printf(e.what());
        LOG(ERROR)<<e.what();
    }
    return result;
@@ -291,12 +229,12 @@
string Paraformer::ForwardChunk(float* din, int len, int flag)
{
    printf("Not Imp!!!!!!\n");
    return "Hello";
    LOG(ERROR)<<"Not Imp!!!!!!";
    return "";
}
string Paraformer::Rescoring()
{
    printf("Not Imp!!!!!!\n");
    return "Hello";
    LOG(ERROR)<<"Not Imp!!!!!!";
    return "";
}
funasr/runtime/onnxruntime/src/paraformer.h
@@ -2,9 +2,7 @@
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#pragma once
#ifndef PARAFORMER_MODELIMP_H
#define PARAFORMER_MODELIMP_H
@@ -23,9 +21,6 @@
        //std::unique_ptr<knf::OnlineFbank> fbank_;
        knf::FbankOptions fbank_opts;
        std::unique_ptr<FsmnVad> vad_handle;
        std::unique_ptr<CTTransformer> punc_handle;
        Vocab* vocab;
        vector<float> means_list;
        vector<float> vars_list;
@@ -36,7 +31,6 @@
        void LoadCmvn(const char *filename);
        vector<float> ApplyLfr(const vector<float> &in);
        void ApplyCmvn(vector<float> *v);
        string GreedySearch( float* in, int n_len, int64_t token_nums);
        std::shared_ptr<Ort::Session> m_session;
@@ -46,22 +40,16 @@
        vector<string> m_strInputNames, m_strOutputNames;
        vector<const char*> m_szInputNames;
        vector<const char*> m_szOutputNames;
        bool use_vad=false;
        bool use_punc=false;
    public:
        Paraformer(std::map<std::string, std::string>& model_path, int thread_num=0);
        Paraformer();
        ~Paraformer();
        void InitAM(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
        void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
        void Reset();
        vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
        string ForwardChunk(float* din, int len, int flag);
        string Forward(float* din, int len, int flag);
        string Rescoring();
        std::vector<std::vector<int>> VadSeg(std::vector<float>& pcm_data);
        string AddPunc(const char* sz_input);
        bool UseVad(){return use_vad;};
        bool UsePunc(){return use_punc;};
    };
} // namespace paraformer
funasr/runtime/onnxruntime/src/precomp.h
@@ -30,6 +30,10 @@
#include "com-define.h"
#include "commonfunc.h"
#include "predefine-coe.h"
#include "model.h"
#include "vad-model.h"
#include "punc-model.h"
#include "offline-stream.h"
#include "tokenizer.h"
#include "ct-transformer.h"
#include "fsmn-vad.h"
@@ -39,9 +43,8 @@
#include "tensor.h"
#include "util.h"
#include "resample.h"
#include "model.h"
//#include "vad-model.h"
#include "paraformer.h"
#include "offline-stream.h"
#include "libfunasrapi.h"
using namespace paraformer;
funasr/runtime/onnxruntime/src/punc-model.cpp
New file
@@ -0,0 +1,19 @@
#include "precomp.h"
PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num)
{
    PuncModel *mm;
    mm = new CTTransformer();
    string punc_model_path;
    string punc_config_path;
    punc_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
    if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
        punc_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
    }
    punc_config_path = PathAppend(model_path.at(MODEL_DIR), PUNC_CONFIG_NAME);
    mm->InitPunc(punc_model_path, punc_config_path, thread_num);
    return mm;
}
funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -14,6 +14,10 @@
{
}
// Destructor: the body is empty; members are released by their own destructors.
CTokenizer::~CTokenizer()
{
}
void CTokenizer::ReadYaml(const YAML::Node& node) 
{
    if (node.IsMap()) 
funasr/runtime/onnxruntime/src/tokenizer.h
@@ -17,6 +17,7 @@
    CTokenizer(const char* sz_yamlfile);
    CTokenizer();
    ~CTokenizer();
    bool OpenYaml(const char* sz_yamlfile);
    void ReadYaml(const YAML::Node& node);
    vector<string> Id2String(vector<int> input);
funasr/runtime/onnxruntime/src/vad-model.cpp
New file
@@ -0,0 +1,21 @@
#include "precomp.h"
VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num)
{
    VadModel *mm;
    mm = new FsmnVad();
    string vad_model_path;
    string vad_cmvn_path;
    string vad_config_path;
    vad_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
    if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
        vad_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
    }
    vad_cmvn_path = PathAppend(model_path.at(MODEL_DIR), VAD_CMVN_NAME);
    vad_config_path = PathAppend(model_path.at(MODEL_DIR), VAD_CONFIG_NAME);
    mm->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
    return mm;
}