From 3372b13d24aceef7002cfa0fc8222b3085c15110 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 02 六月 2023 22:02:31 +0800
Subject: [PATCH] add fsmn-vad-online
---
funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp | 0
funasr/runtime/onnxruntime/src/audio.cpp | 78 ++++++
funasr/runtime/onnxruntime/bin/CMakeLists.txt | 16 +
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp | 2
funasr/runtime/onnxruntime/src/CMakeLists.txt | 17 -
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 0
funasr/runtime/onnxruntime/src/funasrruntime.cpp | 18 +
funasr/runtime/onnxruntime/src/vad-model.cpp | 15
funasr/runtime/onnxruntime/src/fsmn-vad-online.h | 88 +++++++
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp | 0
funasr/runtime/onnxruntime/src/fsmn-vad.cpp | 51 ++-
funasr/runtime/onnxruntime/src/paraformer.h | 4
/dev/null | 58 ----
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp | 75 ++++-
funasr/runtime/onnxruntime/include/audio.h | 13
funasr/runtime/onnxruntime/include/funasrruntime.h | 13
funasr/runtime/onnxruntime/CMakeLists.txt | 13
funasr/runtime/onnxruntime/src/precomp.h | 3
funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp | 198 ++++++++++++++++
funasr/runtime/onnxruntime/include/vad-model.h | 9
funasr/runtime/onnxruntime/src/fsmn-vad.h | 45 +-
21 files changed, 534 insertions(+), 182 deletions(-)
diff --git a/funasr/runtime/onnxruntime/CMakeLists.txt b/funasr/runtime/onnxruntime/CMakeLists.txt
index 9f6013f..0847d1f 100644
--- a/funasr/runtime/onnxruntime/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/CMakeLists.txt
@@ -7,6 +7,8 @@
# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+
include(TestBigEndian)
test_big_endian(BIG_ENDIAN)
@@ -30,12 +32,13 @@
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi-native-fbank)
include_directories(${PROJECT_SOURCE_DIR}/third_party/yaml-cpp/include)
-add_subdirectory(third_party/yaml-cpp)
-add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc)
-add_subdirectory(src)
-
if(ENABLE_GLOG)
include_directories(${PROJECT_SOURCE_DIR}/third_party/glog)
set(BUILD_TESTING OFF)
add_subdirectory(third_party/glog)
-endif()
\ No newline at end of file
+endif()
+
+add_subdirectory(third_party/yaml-cpp)
+add_subdirectory(third_party/kaldi-native-fbank/kaldi-native-fbank/csrc)
+add_subdirectory(src)
+add_subdirectory(bin)
diff --git a/funasr/runtime/onnxruntime/bin/CMakeLists.txt b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
new file mode 100644
index 0000000..962da0b
--- /dev/null
+++ b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
@@ -0,0 +1,16 @@
+include_directories(${CMAKE_SOURCE_DIR}/include)
+
+add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
+target_link_libraries(funasr-onnx-offline PUBLIC funasr)
+
+add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
+target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
+
+add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp")
+target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
+
+add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
+target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
+
+add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
+target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
similarity index 100%
rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline-punc.cpp
rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
similarity index 100%
rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline-rtf.cpp
rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
similarity index 98%
rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
index 0f606c6..912630b 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
@@ -125,7 +125,7 @@
long taking_micros = 0;
for(auto& wav_file : wav_list){
gettimeofday(&start, NULL);
- FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), FSMN_VAD_OFFLINE, NULL, 16000);
+ FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), NULL, 16000);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
similarity index 100%
rename from funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
rename to funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
similarity index 67%
copy from funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
copy to funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
index 0f606c6..d9944a0 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
@@ -18,6 +18,7 @@
#include "funasrruntime.h"
#include "tclap/CmdLine.h"
#include "com-define.h"
+#include "audio.h"
using namespace std;
@@ -39,9 +40,15 @@
}
void print_segs(vector<vector<int>>* vec) {
+ if((*vec).size() == 0){
+ return;
+ }
string seg_out="[";
for (int i = 0; i < vec->size(); i++) {
vector<int> inner_vec = (*vec)[i];
+ if(inner_vec.size() == 0){
+ continue;
+ }
seg_out += "[";
for (int j = 0; j < inner_vec.size(); j++) {
seg_out += to_string(inner_vec[j]);
@@ -120,32 +127,66 @@
LOG(ERROR)<<"Please check the wav extension!";
exit(-1);
}
-
+ // init online features
+ FUNASR_HANDLE online_hanlde=FsmnVadOnlineInit(vad_hanlde);
float snippet_time = 0.0f;
long taking_micros = 0;
for(auto& wav_file : wav_list){
- gettimeofday(&start, NULL);
- FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), FSMN_VAD_OFFLINE, NULL, 16000);
- gettimeofday(&end, NULL);
- seconds = (end.tv_sec - start.tv_sec);
- taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
- if (result)
- {
- vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
- print_segs(vad_segments);
- snippet_time += FsmnVadGetRetSnippetTime(result);
- FsmnVadFreeResult(result);
- }
- else
- {
- LOG(ERROR) << ("No return data!\n");
+ int32_t sampling_rate_ = -1;
+ funasr::Audio audio(1);
+ if(is_target_file(wav_file.c_str(), "wav")){
+ int32_t sampling_rate_ = -1;
+ if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
+ LOG(ERROR)<<"Failed to load "<< wav_file;
+ exit(-1);
+ }
+ }else if(is_target_file(wav_file.c_str(), "pcm")){
+ if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
+ LOG(ERROR)<<"Failed to load "<< wav_file;
+ exit(-1);
+ }
+ }else{
+ LOG(ERROR)<<"Wrong wav extension";
+ exit(-1);
+ }
+ char* speech_buff = audio.GetSpeechChar();
+ int buff_len = audio.GetSpeechLen()*2;
+
+ int step = 3200;
+ bool is_final = false;
+
+ for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+ if (sample_offset + step >= buff_len - 1) {
+ step = buff_len - sample_offset;
+ is_final = true;
+ } else {
+ is_final = false;
+ }
+ gettimeofday(&start, NULL);
+ FUNASR_RESULT result = FsmnVadInferBuffer(online_hanlde, speech_buff+sample_offset, step, NULL, is_final, 16000);
+ gettimeofday(&end, NULL);
+ seconds = (end.tv_sec - start.tv_sec);
+ taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+
+ if (result)
+ {
+ vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
+ print_segs(vad_segments);
+ snippet_time += FsmnVadGetRetSnippetTime(result);
+ FsmnVadFreeResult(result);
+ }
+ else
+ {
+ LOG(ERROR) << ("No return data!\n");
+ }
}
}
-
+
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
+ FsmnVadUninit(online_hanlde);
FsmnVadUninit(vad_hanlde);
return 0;
}
diff --git a/funasr/runtime/onnxruntime/include/audio.h b/funasr/runtime/onnxruntime/include/audio.h
index 1eabd3e..d2100a4 100644
--- a/funasr/runtime/onnxruntime/include/audio.h
+++ b/funasr/runtime/onnxruntime/include/audio.h
@@ -33,8 +33,9 @@
class Audio {
private:
- float *speech_data;
- int16_t *speech_buff;
+ float *speech_data=nullptr;
+ int16_t *speech_buff=nullptr;
+ char* speech_char=nullptr;
int speech_len;
int speech_align_len;
int offset;
@@ -47,18 +48,22 @@
Audio(int data_type, int size);
~Audio();
void Disp();
- bool LoadWav(const char* filename, int32_t* sampling_rate);
void WavResample(int32_t sampling_rate, const float *waveform, int32_t n);
bool LoadWav(const char* buf, int n_len, int32_t* sampling_rate);
+ bool LoadWav(const char* filename, int32_t* sampling_rate);
+ bool LoadWav2Char(const char* filename, int32_t* sampling_rate);
bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate);
bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
+ bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
int FetchChunck(float *&dout, int len);
int Fetch(float *&dout, int &len, int &flag);
void Padding();
void Split(OfflineStream* offline_streamj);
- void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments);
+ void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
float GetTimeLen();
int GetQueueSize() { return (int)frame_queue.size(); }
+ char* GetSpeechChar(){return speech_char;}
+ int GetSpeechLen(){return speech_len;}
};
} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h
index 5cfdb47..af430f7 100644
--- a/funasr/runtime/onnxruntime/include/funasrruntime.h
+++ b/funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -46,12 +46,6 @@
FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;
-typedef enum
-{
- FSMN_VAD_OFFLINE=0,
- FSMN_VAD_ONLINE = 1,
-}FSMN_VAD_MODE;
-
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
// ASR
@@ -68,11 +62,12 @@
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT result);
// VAD
-_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num, FSMN_VAD_MODE mode=FSMN_VAD_OFFLINE);
+_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI FUNASR_HANDLE FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle);
// buffer
-_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000);
// file, support wav & pcm
-_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI std::vector<std::vector<int>>* FsmnVadGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI void FsmnVadFreeResult(FUNASR_RESULT result);
diff --git a/funasr/runtime/onnxruntime/include/vad-model.h b/funasr/runtime/onnxruntime/include/vad-model.h
index b1b1e9d..07f1833 100644
--- a/funasr/runtime/onnxruntime/include/vad-model.h
+++ b/funasr/runtime/onnxruntime/include/vad-model.h
@@ -12,14 +12,9 @@
virtual ~VadModel(){};
virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
- virtual void ReadModel(const char* vad_model)=0;
- virtual void LoadConfigFromYaml(const char* filename)=0;
- virtual void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
- std::vector<float> &waves)=0;
- virtual void LoadCmvn(const char *filename)=0;
- virtual void InitCache()=0;
};
-VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num, int mode);
+VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num);
+VadModel *CreateVadModel(void* fsmnvad_handle);
} // namespace funasr
#endif
diff --git a/funasr/runtime/onnxruntime/src/CMakeLists.txt b/funasr/runtime/onnxruntime/src/CMakeLists.txt
index 341a16a..d083d8e 100644
--- a/funasr/runtime/onnxruntime/src/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -1,11 +1,8 @@
file(GLOB files1 "*.cpp")
-file(GLOB files2 "*.cc")
+set(files ${files1})
-set(files ${files1} ${files2})
-set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
-
-add_library(funasr ${files})
+add_library(funasr SHARED ${files})
if(WIN32)
set(EXTRA_LIBS pthread yaml-cpp csrc glog)
@@ -24,13 +21,3 @@
include_directories(${CMAKE_SOURCE_DIR}/include)
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
-
-add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp")
-add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp")
-add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
-add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
-target_link_libraries(funasr-onnx-offline PUBLIC funasr)
-target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
-target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
-target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
-
diff --git a/funasr/runtime/onnxruntime/src/audio.cpp b/funasr/runtime/onnxruntime/src/audio.cpp
index 6d63d67..23d0010 100644
--- a/funasr/runtime/onnxruntime/src/audio.cpp
+++ b/funasr/runtime/onnxruntime/src/audio.cpp
@@ -176,12 +176,12 @@
{
if (speech_buff != NULL) {
free(speech_buff);
-
}
-
if (speech_data != NULL) {
-
free(speech_data);
+ }
+ if (speech_char != NULL) {
+ free(speech_char);
}
}
@@ -296,8 +296,47 @@
return false;
}
-bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
+bool Audio::LoadWav2Char(const char *filename, int32_t* sampling_rate)
{
+ WaveHeader header;
+ if (speech_char != NULL) {
+ free(speech_char);
+ }
+ offset = 0;
+ std::ifstream is(filename, std::ifstream::binary);
+ is.read(reinterpret_cast<char *>(&header), sizeof(header));
+ if(!is){
+ LOG(ERROR) << "Failed to read " << filename;
+ return false;
+ }
+ if (!header.Validate()) {
+ return false;
+ }
+ header.SeekToDataChunk(is);
+ if (!is) {
+ return false;
+ }
+ if (!header.Validate()) {
+ return false;
+ }
+ header.SeekToDataChunk(is);
+ if (!is) {
+ return false;
+ }
+
+ *sampling_rate = header.sample_rate;
+ // header.subchunk2_size contains the number of bytes in the data.
+ // As we assume each sample contains two bytes, so it is divided by 2 here
+ speech_len = header.subchunk2_size / 2;
+ speech_char = (char *)malloc(header.subchunk2_size);
+ memset(speech_char, 0, header.subchunk2_size);
+ is.read(speech_char, header.subchunk2_size);
+
+ return true;
+}
+
+bool Audio::LoadWav(const char* buf, int n_file_len, int32_t* sampling_rate)
+{
WaveHeader header;
if (speech_data != NULL) {
free(speech_data);
@@ -441,6 +480,33 @@
}
+bool Audio::LoadPcmwav2Char(const char* filename, int32_t* sampling_rate)
+{
+ if (speech_char != NULL) {
+ free(speech_char);
+ }
+ offset = 0;
+
+ FILE* fp;
+ fp = fopen(filename, "rb");
+ if (fp == nullptr)
+ {
+ LOG(ERROR) << "Failed to read " << filename;
+ return false;
+ }
+ fseek(fp, 0, SEEK_END);
+ uint32_t n_file_len = ftell(fp);
+ fseek(fp, 0, SEEK_SET);
+
+ speech_len = (n_file_len) / 2;
+ speech_char = (char *)malloc(n_file_len);
+ memset(speech_char, 0, n_file_len);
+ fread(speech_char, sizeof(int16_t), n_file_len/2, fp);
+ fclose(fp);
+
+ return true;
+}
+
int Audio::FetchChunck(float *&dout, int len)
{
if (offset >= speech_align_len) {
@@ -541,7 +607,7 @@
}
-void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments)
+void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
{
AudioFrame *frame;
@@ -552,7 +618,7 @@
frame = NULL;
std::vector<float> pcm_data(speech_data, speech_data+sp_len);
- vad_segments = vad_obj->Infer(pcm_data);
+ vad_segments = vad_obj->Infer(pcm_data, input_finished);
}
} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp
new file mode 100644
index 0000000..0346916
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp
@@ -0,0 +1,198 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#include <fstream>
+#include "precomp.h"
+
+namespace funasr {
+
+void FsmnVadOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
+ std::vector<float> &waves) {
+ knf::OnlineFbank fbank(fbank_opts_);
+ // cache merge
+ waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
+ int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
+ // Send the audio after the last frame shift position to the cache
+ input_cache_.clear();
+ input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
+ if (frame_number == 0) {
+ return;
+ }
+ // Delete audio that haven't undergone fbank processing
+ waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
+
+ std::vector<float> buf(waves.size());
+ for (int32_t i = 0; i != waves.size(); ++i) {
+ buf[i] = waves[i] * 32768;
+ }
+ fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
+ // fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
+ int32_t frames = fbank.NumFramesReady();
+ for (int32_t i = 0; i != frames; ++i) {
+ const float *frame = fbank.GetFrame(i);
+ vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
+ vad_feats.emplace_back(frame_vector);
+ }
+}
+
+void FsmnVadOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &vad_feats,
+ vector<float> &waves, bool input_finished) {
+ FbankKaldi(sample_rate, vad_feats, waves);
+ // cache deal & online lfr,cmvn
+ if (vad_feats.size() > 0) {
+ if (!reserve_waveforms_.empty()) {
+ waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
+ }
+ if (lfr_splice_cache_.empty()) {
+ for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+ lfr_splice_cache_.emplace_back(vad_feats[0]);
+ }
+ }
+ if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
+ vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
+ int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
+ int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
+ int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats, input_finished);
+ int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
+ reserve_waveforms_.clear();
+ reserve_waveforms_.insert(reserve_waveforms_.begin(),
+ waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
+ waves.begin() + frame_from_waves * frame_shift_sample_length_);
+ int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
+ waves.erase(waves.begin() + sample_length, waves.end());
+ } else {
+ reserve_waveforms_.clear();
+ reserve_waveforms_.insert(reserve_waveforms_.begin(),
+ waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
+ lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
+ }
+ } else {
+ if (input_finished) {
+ if (!reserve_waveforms_.empty()) {
+ waves = reserve_waveforms_;
+ }
+ vad_feats = lfr_splice_cache_;
+ OnlineLfrCmvn(vad_feats, input_finished);
+ }
+ }
+ if(input_finished){
+ Reset();
+ ResetCache();
+ }
+}
+
+int FsmnVadOnline::OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished) {
+ vector<vector<float>> out_feats;
+ int T = vad_feats.size();
+ int T_lrf = ceil((T - (lfr_m - 1) / 2) / lfr_n);
+ int lfr_splice_frame_idxs = T_lrf;
+ vector<float> p;
+ for (int i = 0; i < T_lrf; i++) {
+ if (lfr_m <= T - i * lfr_n) {
+ for (int j = 0; j < lfr_m; j++) {
+ p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
+ }
+ out_feats.emplace_back(p);
+ p.clear();
+ } else {
+ if (input_finished) {
+ int num_padding = lfr_m - (T - i * lfr_n);
+ for (int j = 0; j < (vad_feats.size() - i * lfr_n); j++) {
+ p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
+ }
+ for (int j = 0; j < num_padding; j++) {
+ p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end());
+ }
+ out_feats.emplace_back(p);
+ } else {
+ lfr_splice_frame_idxs = i;
+ break;
+ }
+ }
+ }
+ lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n);
+ lfr_splice_cache_.clear();
+ lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());
+
+ // Apply cmvn
+ for (auto &out_feat: out_feats) {
+ for (int j = 0; j < means_list_.size(); j++) {
+ out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
+ }
+ }
+ vad_feats = out_feats;
+ return lfr_splice_frame_idxs;
+}
+
+std::vector<std::vector<int>>
+FsmnVadOnline::Infer(std::vector<float> &waves, bool input_finished) {
+ std::vector<std::vector<float>> vad_feats;
+ std::vector<std::vector<float>> vad_probs;
+ ExtractFeats(vad_sample_rate_, vad_feats, waves, input_finished);
+ fsmnvad_handle_->Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
+
+ std::vector<std::vector<int>> vad_segments;
+ vad_segments = vad_scorer(vad_probs, waves, input_finished, true, vad_silence_duration_, vad_max_len_,
+ vad_speech_noise_thres_, vad_sample_rate_);
+ return vad_segments;
+}
+
+void FsmnVadOnline::InitCache(){
+ std::vector<float> cache_feats(128 * 19 * 1, 0);
+ for (int i=0;i<4;i++){
+ in_cache_.emplace_back(cache_feats);
+ }
+};
+
+void FsmnVadOnline::Reset(){
+ in_cache_.clear();
+ InitCache();
+};
+
+void FsmnVadOnline::Test() {
+}
+
+void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
+ Ort::Env &env,
+ std::vector<const char *> &vad_in_names,
+ std::vector<const char *> &vad_out_names,
+ knf::FbankOptions &fbank_opts,
+ std::vector<float> &means_list,
+ std::vector<float> &vars_list,
+ int vad_sample_rate,
+ int vad_silence_duration,
+ int vad_max_len,
+ double vad_speech_noise_thres) {
+ vad_session_ = vad_session;
+ vad_in_names_ = vad_in_names;
+ vad_out_names_ = vad_out_names;
+ fbank_opts_ = fbank_opts;
+ means_list_ = means_list;
+ vars_list_ = vars_list;
+ vad_sample_rate_ = vad_sample_rate;
+ vad_silence_duration_ = vad_silence_duration;
+ vad_max_len_ = vad_max_len;
+ vad_speech_noise_thres_ = vad_speech_noise_thres;
+}
+
+FsmnVadOnline::~FsmnVadOnline() {
+}
+
+FsmnVadOnline::FsmnVadOnline(FsmnVad* fsmnvad_handle):fsmnvad_handle_(std::move(fsmnvad_handle)),session_options_{}{
+ InitCache();
+ InitOnline(fsmnvad_handle_->vad_session_,
+ fsmnvad_handle_->env_,
+ fsmnvad_handle_->vad_in_names_,
+ fsmnvad_handle_->vad_out_names_,
+ fsmnvad_handle_->fbank_opts_,
+ fsmnvad_handle_->means_list_,
+ fsmnvad_handle_->vars_list_,
+ fsmnvad_handle_->vad_sample_rate_,
+ fsmnvad_handle_->vad_silence_duration_,
+ fsmnvad_handle_->vad_max_len_,
+ fsmnvad_handle_->vad_speech_noise_thres_);
+}
+
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad-online.h b/funasr/runtime/onnxruntime/src/fsmn-vad-online.h
new file mode 100644
index 0000000..4d429b6
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad-online.h
@@ -0,0 +1,88 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#pragma once
+#include "precomp.h"
+
+namespace funasr {
+class FsmnVadOnline : public VadModel {
+/**
+ * Author: Speech Lab of DAMO Academy, Alibaba Group
+ * Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ * https://arxiv.org/abs/1803.05030
+*/
+
+public:
+ explicit FsmnVadOnline(FsmnVad* fsmnvad_handle);
+ ~FsmnVadOnline();
+ void Test();
+ std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
+ void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
+ void Reset();
+
+private:
+ E2EVadModel vad_scorer = E2EVadModel();
+ // std::unique_ptr<FsmnVad> fsmnvad_handle_;
+ FsmnVad* fsmnvad_handle_ = nullptr;
+
+ void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
+ std::vector<float> &waves);
+ int OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished);
+ void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num){}
+ void InitCache();
+ void InitOnline(std::shared_ptr<Ort::Session> &vad_session,
+ Ort::Env &env,
+ std::vector<const char *> &vad_in_names,
+ std::vector<const char *> &vad_out_names,
+ knf::FbankOptions &fbank_opts,
+ std::vector<float> &means_list,
+ std::vector<float> &vars_list,
+ int vad_sample_rate,
+ int vad_silence_duration,
+ int vad_max_len,
+ double vad_speech_noise_thres);
+
+ static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
+ int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
+ if (frame_num >= 1 && sample_length >= frame_sample_length)
+ return frame_num;
+ else
+ return 0;
+ }
+ void ResetCache() {
+ reserve_waveforms_.clear();
+ input_cache_.clear();
+ lfr_splice_cache_.clear();
+ }
+
+ // from fsmnvad_handle_
+ std::shared_ptr<Ort::Session> vad_session_ = nullptr;
+ Ort::Env env_;
+ Ort::SessionOptions session_options_;
+ std::vector<const char *> vad_in_names_;
+ std::vector<const char *> vad_out_names_;
+ knf::FbankOptions fbank_opts_;
+ std::vector<float> means_list_;
+ std::vector<float> vars_list_;
+
+ std::vector<std::vector<float>> in_cache_;
+ // The reserved waveforms by fbank
+ std::vector<float> reserve_waveforms_;
+ // waveforms reserved after last shift position
+ std::vector<float> input_cache_;
+ // lfr reserved cache
+ std::vector<std::vector<float>> lfr_splice_cache_;
+
+ int vad_sample_rate_ = MODEL_SAMPLE_RATE;
+ int vad_silence_duration_ = VAD_SILENCE_DURATION;
+ int vad_max_len_ = VAD_MAX_LEN;
+ double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
+ int lfr_m = VAD_LFR_M;
+ int lfr_n = VAD_LFR_N;
+ int frame_sample_length_ = vad_sample_rate_ / 1000 * 25;;
+ int frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;
+};
+
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
index 516dc88..697828b 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -37,14 +37,14 @@
this->vad_max_len_ = post_conf["max_single_segment_time"].as<int>();
this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as<double>();
- fbank_opts.frame_opts.dither = frontend_conf["dither"].as<float>();
- fbank_opts.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
- fbank_opts.frame_opts.samp_freq = (float)vad_sample_rate_;
- fbank_opts.frame_opts.window_type = frontend_conf["window"].as<string>();
- fbank_opts.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
- fbank_opts.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
- fbank_opts.energy_floor = 0;
- fbank_opts.mel_opts.debug_mel = false;
+ fbank_opts_.frame_opts.dither = frontend_conf["dither"].as<float>();
+ fbank_opts_.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
+ fbank_opts_.frame_opts.samp_freq = (float)vad_sample_rate_;
+ fbank_opts_.frame_opts.window_type = frontend_conf["window"].as<string>();
+ fbank_opts_.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
+ fbank_opts_.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
+ fbank_opts_.energy_floor = 0;
+ fbank_opts_.mel_opts.debug_mel = false;
}catch(exception const &e){
LOG(ERROR) << "Error when load argument from vad config YAML.";
exit(-1);
@@ -55,6 +55,7 @@
try {
vad_session_ = std::make_shared<Ort::Session>(
env_, vad_model, session_options_);
+ LOG(INFO) << "Successfully load model from " << vad_model;
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load vad onnx model: " << e.what();
exit(0);
@@ -109,7 +110,9 @@
void FsmnVad::Forward(
const std::vector<std::vector<float>> &chunk_feats,
- std::vector<std::vector<float>> *out_prob) {
+ std::vector<std::vector<float>> *out_prob,
+ std::vector<std::vector<float>> *in_cache,
+ bool is_final) {
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
@@ -132,9 +135,9 @@
// 4 caches
// cache node {batch,128,19,1}
const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
- for (int i = 0; i < in_cache_.size(); i++) {
+ for (int i = 0; i < in_cache->size(); i++) {
vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
- memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
+ memory_info, (*in_cache)[i].data(), (*in_cache)[i].size(), cache_feats_shape, 4)));
}
// 4. Onnx infer
@@ -162,15 +165,17 @@
}
// get 4 caches outputs,each size is 128*19
- // for (int i = 1; i < 5; i++) {
- // float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
- // memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
- // }
+ if(!is_final){
+ for (int i = 1; i < 5; i++) {
+ float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
+ memcpy((*in_cache)[i-1].data(), data, sizeof(float) * 128*19);
+ }
+ }
}
void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
std::vector<float> &waves) {
- knf::OnlineFbank fbank(fbank_opts);
+ knf::OnlineFbank fbank(fbank_opts_);
std::vector<float> buf(waves.size());
for (int32_t i = 0; i != waves.size(); ++i) {
@@ -180,7 +185,7 @@
int32_t frames = fbank.NumFramesReady();
for (int32_t i = 0; i != frames; ++i) {
const float *frame = fbank.GetFrame(i);
- std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins);
+ std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
vad_feats.emplace_back(frame_vector);
}
}
@@ -205,7 +210,7 @@
vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
if (means_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < means_lines.size() - 1; j++) {
- means_list.push_back(stof(means_lines[j]));
+ means_list_.push_back(stof(means_lines[j]));
}
continue;
}
@@ -216,8 +221,8 @@
vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
if (vars_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < vars_lines.size() - 1; j++) {
- // vars_list.push_back(stof(vars_lines[j])*scale);
- vars_list.push_back(stof(vars_lines[j]));
+ // vars_list_.push_back(stof(vars_lines[j])*scale);
+ vars_list_.push_back(stof(vars_lines[j]));
}
continue;
}
@@ -263,8 +268,8 @@
}
// Apply cmvn
for (auto &out_feat: out_feats) {
- for (int j = 0; j < means_list.size(); j++) {
- out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j];
+ for (int j = 0; j < means_list_.size(); j++) {
+ out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
}
}
vad_feats = out_feats;
@@ -276,7 +281,7 @@
std::vector<std::vector<float>> vad_probs;
FbankKaldi(vad_sample_rate_, vad_feats, waves);
LfrCmvn(vad_feats);
- Forward(vad_feats, &vad_probs);
+ Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
E2EVadModel vad_scorer = E2EVadModel();
std::vector<std::vector<int>> vad_segments;
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.h b/funasr/runtime/onnxruntime/src/fsmn-vad.h
index a8ec4ce..adceb1f 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.h
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.h
@@ -22,7 +22,30 @@
void Test();
void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num);
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true);
+ void Forward(
+ const std::vector<std::vector<float>> &chunk_feats,
+ std::vector<std::vector<float>> *out_prob,
+ std::vector<std::vector<float>> *in_cache,
+ bool is_final);
void Reset();
+
+ std::shared_ptr<Ort::Session> vad_session_ = nullptr;
+ Ort::Env env_;
+ Ort::SessionOptions session_options_;
+ std::vector<const char *> vad_in_names_;
+ std::vector<const char *> vad_out_names_;
+ std::vector<std::vector<float>> in_cache_;
+
+ knf::FbankOptions fbank_opts_;
+ std::vector<float> means_list_;
+ std::vector<float> vars_list_;
+
+ int vad_sample_rate_ = MODEL_SAMPLE_RATE;
+ int vad_silence_duration_ = VAD_SILENCE_DURATION;
+ int vad_max_len_ = VAD_MAX_LEN;
+ double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
+ int lfr_m = VAD_LFR_M;
+ int lfr_n = VAD_LFR_N;
private:
@@ -37,31 +60,9 @@
std::vector<float> &waves);
void LfrCmvn(std::vector<std::vector<float>> &vad_feats);
-
- void Forward(
- const std::vector<std::vector<float>> &chunk_feats,
- std::vector<std::vector<float>> *out_prob);
-
void LoadCmvn(const char *filename);
void InitCache();
- std::shared_ptr<Ort::Session> vad_session_ = nullptr;
- Ort::Env env_;
- Ort::SessionOptions session_options_;
- std::vector<const char *> vad_in_names_;
- std::vector<const char *> vad_out_names_;
- std::vector<std::vector<float>> in_cache_;
-
- knf::FbankOptions fbank_opts;
- std::vector<float> means_list;
- std::vector<float> vars_list;
-
- int vad_sample_rate_ = MODEL_SAMPLE_RATE;
- int vad_silence_duration_ = VAD_SILENCE_DURATION;
- int vad_max_len_ = VAD_MAX_LEN;
- double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
- int lfr_m = VAD_LFR_M;
- int lfr_n = VAD_LFR_N;
};
} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
index adef504..f504b39 100644
--- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -11,9 +11,15 @@
return mm;
}
- _FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num, FSMN_VAD_MODE mode)
+ _FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num)
{
- funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num, mode);
+ funasr::VadModel* mm = funasr::CreateVadModel(model_path, thread_num);
+ return mm;
+ }
+
+ _FUNASRAPI FUNASR_HANDLE FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle)
+ {
+ funasr::VadModel* mm = funasr::CreateVadModel(fsmnvad_handle);
return mm;
}
@@ -96,7 +102,7 @@
}
// APIs for VAD Infer
- _FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
+ _FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate)
{
funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
if (!vad_obj)
@@ -110,13 +116,13 @@
p_result->snippet_time = audio.GetTimeLen();
vector<std::vector<int>> vad_segments;
- audio.Split(vad_obj, vad_segments);
+ audio.Split(vad_obj, vad_segments, input_finished);
p_result->segments = new vector<std::vector<int>>(vad_segments);
return p_result;
}
- _FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, FSMN_VAD_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
+ _FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate)
{
funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
if (!vad_obj)
@@ -139,7 +145,7 @@
p_result->snippet_time = audio.GetTimeLen();
vector<std::vector<int>> vad_segments;
- audio.Split(vad_obj, vad_segments);
+ audio.Split(vad_obj, vad_segments, true);
p_result->segments = new vector<std::vector<int>>(vad_segments);
return p_result;
diff --git a/funasr/runtime/onnxruntime/src/online-feature.cpp b/funasr/runtime/onnxruntime/src/online-feature.cpp
deleted file mode 100644
index a21589c..0000000
--- a/funasr/runtime/onnxruntime/src/online-feature.cpp
+++ /dev/null
@@ -1,137 +0,0 @@
-/**
- * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
- * MIT License (https://opensource.org/licenses/MIT)
- * Contributed by zhuzizyf(China Telecom).
-*/
-
-#include "online-feature.h"
-#include <utility>
-
-namespace funasr {
-OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m, int lfr_n,
- std::vector<std::vector<float>> cmvns)
- : sample_rate_(sample_rate),
- fbank_opts_(std::move(fbank_opts)),
- lfr_m_(lfr_m),
- lfr_n_(lfr_n),
- cmvns_(std::move(cmvns)) {
- frame_sample_length_ = sample_rate_ / 1000 * 25;;
- frame_shift_sample_length_ = sample_rate_ / 1000 * 10;
-}
-
-void OnlineFeature::ExtractFeats(vector<std::vector<float>> &vad_feats,
- vector<float> waves, bool input_finished) {
- input_finished_ = input_finished;
- OnlineFbank(vad_feats, waves);
- // cache deal & online lfr,cmvn
- if (vad_feats.size() > 0) {
- if (!reserve_waveforms_.empty()) {
- waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
- }
- if (lfr_splice_cache_.empty()) {
- for (int i = 0; i < (lfr_m_ - 1) / 2; i++) {
- lfr_splice_cache_.emplace_back(vad_feats[0]);
- }
- }
- if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m_) {
- vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
- int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
- int minus_frame = reserve_waveforms_.empty() ? (lfr_m_ - 1) / 2 : 0;
- int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats);
- int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
- reserve_waveforms_.clear();
- reserve_waveforms_.insert(reserve_waveforms_.begin(),
- waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
- waves.begin() + frame_from_waves * frame_shift_sample_length_);
- int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
- waves.erase(waves.begin() + sample_length, waves.end());
- } else {
- reserve_waveforms_.clear();
- reserve_waveforms_.insert(reserve_waveforms_.begin(),
- waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
- lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
- }
-
- } else {
- if (input_finished_) {
- if (!reserve_waveforms_.empty()) {
- waves = reserve_waveforms_;
- }
- vad_feats = lfr_splice_cache_;
- OnlineLfrCmvn(vad_feats);
- ResetCache();
- }
- }
-
-}
-
-int OnlineFeature::OnlineLfrCmvn(vector<vector<float>> &vad_feats) {
- vector<vector<float>> out_feats;
- int T = vad_feats.size();
- int T_lrf = ceil((T - (lfr_m_ - 1) / 2) / lfr_n_);
- int lfr_splice_frame_idxs = T_lrf;
- vector<float> p;
- for (int i = 0; i < T_lrf; i++) {
- if (lfr_m_ <= T - i * lfr_n_) {
- for (int j = 0; j < lfr_m_; j++) {
- p.insert(p.end(), vad_feats[i * lfr_n_ + j].begin(), vad_feats[i * lfr_n_ + j].end());
- }
- out_feats.emplace_back(p);
- p.clear();
- } else {
- if (input_finished_) {
- int num_padding = lfr_m_ - (T - i * lfr_n_);
- for (int j = 0; j < (vad_feats.size() - i * lfr_n_); j++) {
- p.insert(p.end(), vad_feats[i * lfr_n_ + j].begin(), vad_feats[i * lfr_n_ + j].end());
- }
- for (int j = 0; j < num_padding; j++) {
- p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end());
- }
- out_feats.emplace_back(p);
- } else {
- lfr_splice_frame_idxs = i;
- break;
- }
- }
- }
- lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n_);
- lfr_splice_cache_.clear();
- lfr_splice_cache_.insert(lfr_splice_cache_.begin(), vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());
-
- // Apply cmvn
- for (auto &out_feat: out_feats) {
- for (int j = 0; j < cmvns_[0].size(); j++) {
- out_feat[j] = (out_feat[j] + cmvns_[0][j]) * cmvns_[1][j];
- }
- }
- vad_feats = out_feats;
- return lfr_splice_frame_idxs;
-}
-
-void OnlineFeature::OnlineFbank(vector<std::vector<float>> &vad_feats,
- vector<float> &waves) {
-
- knf::OnlineFbank fbank(fbank_opts_);
- // cache merge
- waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
- int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
- // Send the audio after the last frame shift position to the cache
- input_cache_.clear();
- input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
- if (frame_number == 0) {
- return;
- }
- // Delete audio that haven't undergone fbank processing
- waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
-
- fbank.AcceptWaveform(sample_rate_, &waves[0], waves.size());
- int32_t frames = fbank.NumFramesReady();
- for (int32_t i = 0; i != frames; ++i) {
- const float *frame = fbank.GetFrame(i);
- vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
- vad_feats.emplace_back(frame_vector);
- }
-
-}
-
-} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/online-feature.h b/funasr/runtime/onnxruntime/src/online-feature.h
deleted file mode 100644
index 16e6e4b..0000000
--- a/funasr/runtime/onnxruntime/src/online-feature.h
+++ /dev/null
@@ -1,58 +0,0 @@
-/**
- * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
- * MIT License (https://opensource.org/licenses/MIT)
- * Contributed by zhuzizyf(China Telecom).
-*/
-#pragma once
-#include <vector>
-#include "precomp.h"
-
-using namespace std;
-namespace funasr {
-class OnlineFeature {
-
-public:
- OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m_, int lfr_n_,
- std::vector<std::vector<float>> cmvns_);
-
- void ExtractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
-
-private:
- void OnlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
- int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
-
- static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
- int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
- if (frame_num >= 1 && sample_length >= frame_sample_length)
- return frame_num;
- else
- return 0;
- }
-
- void ResetCache() {
- reserve_waveforms_.clear();
- input_cache_.clear();
- lfr_splice_cache_.clear();
- input_finished_ = false;
-
- }
-
- knf::FbankOptions fbank_opts_;
- // The reserved waveforms by fbank
- std::vector<float> reserve_waveforms_;
- // waveforms reserved after last shift position
- std::vector<float> input_cache_;
- // lfr reserved cache
- std::vector<std::vector<float>> lfr_splice_cache_;
- std::vector<std::vector<float>> cmvns_;
-
- int sample_rate_ = 16000;
- int frame_sample_length_ = sample_rate_ / 1000 * 25;;
- int frame_shift_sample_length_ = sample_rate_ / 1000 * 10;
- int lfr_m_;
- int lfr_n_;
- bool input_finished_ = false;
-
-};
-
-} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h
index 533c16f..9df0977 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.h
+++ b/funasr/runtime/onnxruntime/src/paraformer.h
@@ -18,7 +18,7 @@
//std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions fbank_opts;
- Vocab* vocab;
+ Vocab* vocab = nullptr;
vector<float> means_list;
vector<float> vars_list;
const float scale = 22.6274169979695;
@@ -30,7 +30,7 @@
void ApplyCmvn(vector<float> *v);
string GreedySearch( float* in, int n_len, int64_t token_nums);
- std::shared_ptr<Ort::Session> m_session;
+ std::shared_ptr<Ort::Session> m_session = nullptr;
Ort::Env env_;
Ort::SessionOptions session_options;
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index e607dbf..838dddc 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -36,8 +36,9 @@
#include "offline-stream.h"
#include "tokenizer.h"
#include "ct-transformer.h"
-#include "fsmn-vad.h"
#include "e2e-vad.h"
+#include "fsmn-vad.h"
+#include "fsmn-vad-online.h"
#include "vocab.h"
#include "audio.h"
#include "tensor.h"
diff --git a/funasr/runtime/onnxruntime/src/vad-model.cpp b/funasr/runtime/onnxruntime/src/vad-model.cpp
index 336758f..c164c3e 100644
--- a/funasr/runtime/onnxruntime/src/vad-model.cpp
+++ b/funasr/runtime/onnxruntime/src/vad-model.cpp
@@ -1,14 +1,10 @@
#include "precomp.h"
namespace funasr {
-VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num, int mode)
+VadModel *CreateVadModel(std::map<std::string, std::string>& model_path, int thread_num)
{
VadModel *mm;
- if(mode == FSMN_VAD_OFFLINE){
- mm = new FsmnVad();
- }else{
- LOG(ERROR)<<"Online fsmn vad not imp!";
- }
+ mm = new FsmnVad();
string vad_model_path;
string vad_cmvn_path;
@@ -25,4 +21,11 @@
return mm;
}
+VadModel *CreateVadModel(void* fsmnvad_handle)
+{
+ VadModel *mm;
+ mm = new FsmnVadOnline((FsmnVad*)fsmnvad_handle);
+ return mm;
+}
+
} // namespace funasr
\ No newline at end of file
--
Gitblit v1.9.1