funasr/runtime/onnxruntime/include/Audio.h
@@ -5,6 +5,7 @@ #include <ComDefine.h> #include <queue> #include <stdint.h> #include "Model.h" #ifndef model_sample_rate #define model_sample_rate 16000 @@ -27,7 +28,7 @@ ~AudioFrame(); int set_start(int val); int set_end(int val, int max_len); int set_end(int val); int get_start(); int get_len(); int disp(); @@ -57,7 +58,7 @@ int fetch_chunck(float *&dout, int len); int fetch(float *&dout, int &len, int &flag); void padding(); void split(); void split(Model* pRecogObj); float get_time_len(); int get_queue_size() { return (int)frame_queue.size(); } funasr/runtime/onnxruntime/include/Model.h
@@ -11,7 +11,8 @@ virtual std::string forward_chunk(float *din, int len, int flag) = 0; virtual std::string forward(float *din, int len, int flag) = 0; virtual std::string rescoring() = 0; virtual std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data)=0; }; Model *create_model(const char *path,int nThread=0,bool quantize=false); Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false); #endif funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -48,18 +48,18 @@ typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step. // APIs for qmasr _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize); // APIs for funasr _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false); // if not give a fnCallback ,it should be NULL _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback); _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false); _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback); _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false); _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback); _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false); _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback); _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false); _FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex); funasr/runtime/onnxruntime/src/Audio.cpp
@@ -134,19 +134,10 @@ return start; }; int AudioFrame::set_end(int val, int max_len) int AudioFrame::set_end(int val) { float num_samples = val - start; float frame_length = 400; float frame_shift = 160; float num_new_samples = ceil((num_samples - frame_length) / frame_shift) * frame_shift + frame_length; end = start + num_new_samples; len = (int)num_new_samples; if (end > max_len) printf("frame end > max_len!!!!!!!\n"); end = val; len = end - start; return end; }; @@ -473,7 +464,6 @@ void Audio::padding() { float num_samples = speech_len; float frame_length = 400; float frame_shift = 160; @@ -509,71 +499,27 @@ delete frame; } #define UNTRIGGERED 0 #define TRIGGERED 1 #define SPEECH_LEN_5S (16000 * 5) #define SPEECH_LEN_10S (16000 * 10) #define SPEECH_LEN_20S (16000 * 20) #define SPEECH_LEN_30S (16000 * 30) /* void Audio::split() void Audio::split(Model* pRecogObj) { VadInst *handle = WebRtcVad_Create(); WebRtcVad_Init(handle); WebRtcVad_set_mode(handle, 2); int window_size = 10; AudioWindow audiowindow(window_size); int status = UNTRIGGERED; int offset = 0; int fs = 16000; int step = 480; AudioFrame *frame; frame = frame_queue.front(); frame_queue.pop(); int sp_len = frame->get_len(); delete frame; frame = NULL; while (offset < speech_len - step) { int n = WebRtcVad_Process(handle, fs, speech_buff + offset, step); if (status == UNTRIGGERED && audiowindow.put(n) >= window_size - 1) { frame = new AudioFrame(); int start = offset - step * (window_size - 1); frame->set_start(start); status = TRIGGERED; } else if (status == TRIGGERED) { int win_weight = audiowindow.put(n); int voice_len = (offset - frame->get_start()); int gap = 0; if (voice_len < SPEECH_LEN_5S) { offset += step; continue; } else if (voice_len < SPEECH_LEN_10S) { gap = 1; } else if (voice_len < SPEECH_LEN_20S) { gap = window_size / 5; } else { gap = window_size / 2; } if (win_weight < gap) { status = UNTRIGGERED; offset = frame->set_end(offset, speech_align_len); frame_queue.push(frame); frame = NULL; 
} } offset += step; } if (frame != NULL) { frame->set_end(speech_len, speech_align_len); std::vector<float> pcm_data(speech_data, speech_data+sp_len); vector<std::vector<int>> vad_segments = pRecogObj->vad_seg(pcm_data); int seg_sample = model_sample_rate/1000; for(vector<int> segment:vad_segments) { frame = new AudioFrame(); int start = segment[0]*seg_sample; int end = segment[1]*seg_sample; frame->set_start(start); frame->set_end(end); frame_queue.push(frame); frame = NULL; } WebRtcVad_Free(handle); } */ funasr/runtime/onnxruntime/src/FsmnVad.cc
New file @@ -0,0 +1,268 @@ // // Created by root on 4/9/23. // #include <fstream> #include "FsmnVad.h" #include "precomp.h" //#include "glog/logging.h" void FsmnVad::init_vad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len, float vad_speech_noise_thres) { session_options_.SetIntraOpNumThreads(1); session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL); session_options_.DisableCpuMemArena(); this->vad_sample_rate_ = vad_sample_rate; this->vad_silence_duration_=vad_silence_duration; this->vad_max_len_=vad_max_len; this->vad_speech_noise_thres_=vad_speech_noise_thres; read_model(vad_model); load_cmvn(vad_cmvn.c_str()); fbank_opts.frame_opts.dither = 0; fbank_opts.mel_opts.num_bins = 80; fbank_opts.frame_opts.samp_freq = vad_sample_rate; fbank_opts.frame_opts.window_type = "hamming"; fbank_opts.frame_opts.frame_shift_ms = 10; fbank_opts.frame_opts.frame_length_ms = 25; fbank_opts.energy_floor = 0; fbank_opts.mel_opts.debug_mel = false; } void FsmnVad::read_model(const std::string &vad_model) { try { vad_session_ = std::make_shared<Ort::Session>( env_, vad_model.c_str(), session_options_); } catch (std::exception const &e) { //LOG(ERROR) << "Error when load onnx model: " << e.what(); exit(0); } //LOG(INFO) << "vad onnx:"; GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_); } void FsmnVad::GetInputOutputInfo( const std::shared_ptr<Ort::Session> &session, std::vector<const char *> *in_names, std::vector<const char *> *out_names) { Ort::AllocatorWithDefaultOptions allocator; // Input info int num_nodes = session->GetInputCount(); in_names->resize(num_nodes); for (int i = 0; i < num_nodes; ++i) { std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetInputNameAllocated(i, allocator); Ort::TypeInfo type_info = session->GetInputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); ONNXTensorElementDataType type = tensor_info.GetElementType(); 
std::vector<int64_t> node_dims = tensor_info.GetShape(); std::stringstream shape; for (auto j: node_dims) { shape << j; shape << " "; } // LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type // << " dims=" << shape.str(); (*in_names)[i] = name.get(); name.release(); } // Output info num_nodes = session->GetOutputCount(); out_names->resize(num_nodes); for (int i = 0; i < num_nodes; ++i) { std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetOutputNameAllocated(i, allocator); Ort::TypeInfo type_info = session->GetOutputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); ONNXTensorElementDataType type = tensor_info.GetElementType(); std::vector<int64_t> node_dims = tensor_info.GetShape(); std::stringstream shape; for (auto j: node_dims) { shape << j; shape << " "; } // LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type // << " dims=" << shape.str(); (*out_names)[i] = name.get(); name.release(); } } void FsmnVad::Forward( const std::vector<std::vector<float>> &chunk_feats, std::vector<std::vector<float>> *out_prob) { Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); int num_frames = chunk_feats.size(); const int feature_dim = chunk_feats[0].size(); // 2. Generate input nodes tensor // vad node { batch,frame number,feature dim } const int64_t vad_feats_shape[3] = {1, num_frames, feature_dim}; std::vector<float> vad_feats; for (const auto &chunk_feat: chunk_feats) { vad_feats.insert(vad_feats.end(), chunk_feat.begin(), chunk_feat.end()); } Ort::Value vad_feats_ort = Ort::Value::CreateTensor<float>( memory_info, vad_feats.data(), vad_feats.size(), vad_feats_shape, 3); // cache node {batch,128,19,1} const int64_t cache_feats_shape[4] = {1, 128, 19, 1}; std::vector<float> cache_feats(128 * 19 * 1, 0); Ort::Value cache_feats_ort = Ort::Value::CreateTensor<float>( memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4); // 3. 
Put nodes into onnx input vector std::vector<Ort::Value> vad_inputs; vad_inputs.emplace_back(std::move(vad_feats_ort)); // 4 caches for (int i = 0; i < 4; i++) { vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>( memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4))); } // 4. Onnx infer std::vector<Ort::Value> vad_ort_outputs; try { // VLOG(3) << "Start infer"; vad_ort_outputs = vad_session_->Run( Ort::RunOptions{nullptr}, vad_in_names_.data(), vad_inputs.data(), vad_inputs.size(), vad_out_names_.data(), vad_out_names_.size()); } catch (std::exception const &e) { // LOG(ERROR) << e.what(); return; } // 5. Change infer result to output shapes float *logp_data = vad_ort_outputs[0].GetTensorMutableData<float>(); auto type_info = vad_ort_outputs[0].GetTensorTypeAndShapeInfo(); int num_outputs = type_info.GetShape()[1]; int output_dim = type_info.GetShape()[2]; out_prob->resize(num_outputs); for (int i = 0; i < num_outputs; i++) { (*out_prob)[i].resize(output_dim); memcpy((*out_prob)[i].data(), logp_data + i * output_dim, sizeof(float) * output_dim); } } void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats, const std::vector<float> &waves) { knf::OnlineFbank fbank(fbank_opts); fbank.AcceptWaveform(sample_rate, &waves[0], waves.size()); int32_t frames = fbank.NumFramesReady(); for (int32_t i = 0; i != frames; ++i) { const float *frame = fbank.GetFrame(i); std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins); vad_feats.emplace_back(frame_vector); } } void FsmnVad::load_cmvn(const char *filename) { using namespace std; ifstream cmvn_stream(filename); string line; while (getline(cmvn_stream, line)) { istringstream iss(line); vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}}; if (line_item[0] == "<AddShift>") { getline(cmvn_stream, line); istringstream means_lines_stream(line); vector<string> 
means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}}; if (means_lines[0] == "<LearnRateCoef>") { for (int j = 3; j < means_lines.size() - 1; j++) { means_list.push_back(stof(means_lines[j])); } continue; } } else if (line_item[0] == "<Rescale>") { getline(cmvn_stream, line); istringstream vars_lines_stream(line); vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}}; if (vars_lines[0] == "<LearnRateCoef>") { for (int j = 3; j < vars_lines.size() - 1; j++) { // vars_list.push_back(stof(vars_lines[j])*scale); vars_list.push_back(stof(vars_lines[j])); } continue; } } } } std::vector<std::vector<float>> &FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n) { std::vector<std::vector<float>> out_feats; int T = vad_feats.size(); int T_lrf = ceil(1.0 * T / lfr_n); // Pad frames at start(copy first frame) for (int i = 0; i < (lfr_m - 1) / 2; i++) { vad_feats.insert(vad_feats.begin(), vad_feats[0]); } // Merge lfr_m frames as one,lfr_n frames per window T = T + (lfr_m - 1) / 2; std::vector<float> p; for (int i = 0; i < T_lrf; i++) { if (lfr_m <= T - i * lfr_n) { for (int j = 0; j < lfr_m; j++) { p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end()); } out_feats.emplace_back(p); p.clear(); } else { // Fill to lfr_m frames at last window if less than lfr_m frames (copy last frame) int num_padding = lfr_m - (T - i * lfr_n); for (int j = 0; j < (vad_feats.size() - i * lfr_n); j++) { p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end()); } for (int j = 0; j < num_padding; j++) { p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end()); } out_feats.emplace_back(p); } } // Apply cmvn for (auto &out_feat: out_feats) { for (int j = 0; j < means_list.size(); j++) { out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j]; } } vad_feats = out_feats; return vad_feats; } 
std::vector<std::vector<int>> FsmnVad::infer(const std::vector<float> &waves) { std::vector<std::vector<float>> vad_feats; std::vector<std::vector<float>> vad_probs; FbankKaldi(vad_sample_rate_, vad_feats, waves); vad_feats = LfrCmvn(vad_feats, 5, 1); Forward(vad_feats, &vad_probs); E2EVadModel vad_scorer = E2EVadModel(); std::vector<std::vector<int>> vad_segments; vad_segments = vad_scorer(vad_probs, waves, true, vad_silence_duration_, vad_max_len_, vad_speech_noise_thres_, vad_sample_rate_); return vad_segments; } void FsmnVad::test() { } FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} { } funasr/runtime/onnxruntime/src/FsmnVad.h
New file @@ -0,0 +1,57 @@ // // Created by zyf on 4/9/23. // #ifndef VAD_SERVER_FSMNVAD_H #define VAD_SERVER_FSMNVAD_H #include "e2e_vad.h" #include "onnxruntime_cxx_api.h" #include "kaldi-native-fbank/csrc/feature-fbank.h" #include "kaldi-native-fbank/csrc/online-feature.h" class FsmnVad { public: FsmnVad(); void test(); void init_vad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len, float vad_speech_noise_thres); std::vector<std::vector<int>> infer(const std::vector<float> &waves); private: void read_model(const std::string &vad_model); static void GetInputOutputInfo( const std::shared_ptr<Ort::Session> &session, std::vector<const char *> *in_names, std::vector<const char *> *out_names); void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats, const std::vector<float> &waves); std::vector<std::vector<float>> &LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n); void Forward( const std::vector<std::vector<float>> &chunk_feats, std::vector<std::vector<float>> *out_prob); void load_cmvn(const char *filename); std::shared_ptr<Ort::Session> vad_session_ = nullptr; Ort::Env env_; Ort::SessionOptions session_options_; std::vector<const char *> vad_in_names_; std::vector<const char *> vad_out_names_; knf::FbankOptions fbank_opts; std::vector<float> means_list; std::vector<float> vars_list; int vad_sample_rate_ = 16000; int vad_silence_duration_ = 800; int vad_max_len_ = 15000; double vad_speech_noise_thres_ = 0.9; }; #endif //VAD_SERVER_FSMNVAD_H funasr/runtime/onnxruntime/src/Model.cpp
@@ -1,10 +1,10 @@ #include "precomp.h" Model *create_model(const char *path, int nThread, bool quantize) Model *create_model(const char *path, int nThread, bool quantize, bool use_vad) { Model *mm; mm = new paraformer::ModelImp(path, nThread, quantize); mm = new paraformer::ModelImp(path, nThread, quantize, use_vad); return mm; } funasr/runtime/onnxruntime/src/e2e_vad.h
New file @@ -0,0 +1,782 @@ // // Created by root on 3/31/23. // #include <utility> #include <vector> #include <string> #include <map> #include <cmath> #include <algorithm> #include <iostream> #include <numeric> #include <cassert> enum class VadStateMachine { kVadInStateStartPointNotDetected = 1, kVadInStateInSpeechSegment = 2, kVadInStateEndPointDetected = 3 }; enum class FrameState { kFrameStateInvalid = -1, kFrameStateSpeech = 1, kFrameStateSil = 0 }; // final voice/unvoice state per frame enum class AudioChangeState { kChangeStateSpeech2Speech = 0, kChangeStateSpeech2Sil = 1, kChangeStateSil2Sil = 2, kChangeStateSil2Speech = 3, kChangeStateNoBegin = 4, kChangeStateInvalid = 5 }; enum class VadDetectMode { kVadSingleUtteranceDetectMode = 0, kVadMutipleUtteranceDetectMode = 1 }; class VADXOptions { public: int sample_rate; int detect_mode; int snr_mode; int max_end_silence_time; int max_start_silence_time; bool do_start_point_detection; bool do_end_point_detection; int window_size_ms; int sil_to_speech_time_thres; int speech_to_sil_time_thres; float speech_2_noise_ratio; int do_extend; int lookback_time_start_point; int lookahead_time_end_point; int max_single_segment_time; int nn_eval_block_size; int dcd_block_size; float snr_thres; int noise_frame_num_used_for_snr; float decibel_thres; float speech_noise_thres; float fe_prior_thres; int silence_pdf_num; std::vector<int> sil_pdf_ids; float speech_noise_thresh_low; float speech_noise_thresh_high; bool output_frame_probs; int frame_in_ms; int frame_length_ms; explicit VADXOptions( int sr = 16000, int dm = static_cast<int>(VadDetectMode::kVadMutipleUtteranceDetectMode), int sm = 0, int mset = 800, int msst = 3000, bool dspd = true, bool depd = true, int wsm = 200, int ststh = 150, int sttsh = 150, float s2nr = 1.0, int de = 1, int lbtps = 200, int latsp = 100, int mss = 15000, int nebs = 8, int dbs = 4, float st = -100.0, int nfnus = 100, float dt = -100.0, float snt = 0.9, float fept = 1e-4, int spn = 1, 
std::vector<int> spids = {0}, float sntl = -0.1, float snth = 0.3, bool ofp = false, int fim = 10, int flm = 25 ) : sample_rate(sr), detect_mode(dm), snr_mode(sm), max_end_silence_time(mset), max_start_silence_time(msst), do_start_point_detection(dspd), do_end_point_detection(depd), window_size_ms(wsm), sil_to_speech_time_thres(ststh), speech_to_sil_time_thres(sttsh), speech_2_noise_ratio(s2nr), do_extend(de), lookback_time_start_point(lbtps), lookahead_time_end_point(latsp), max_single_segment_time(mss), nn_eval_block_size(nebs), dcd_block_size(dbs), snr_thres(st), noise_frame_num_used_for_snr(nfnus), decibel_thres(dt), speech_noise_thres(snt), fe_prior_thres(fept), silence_pdf_num(spn), sil_pdf_ids(std::move(spids)), speech_noise_thresh_low(sntl), speech_noise_thresh_high(snth), output_frame_probs(ofp), frame_in_ms(fim), frame_length_ms(flm) {} }; class E2EVadSpeechBufWithDoa { public: int start_ms; int end_ms; std::vector<float> buffer; bool contain_seg_start_point; bool contain_seg_end_point; int doa; E2EVadSpeechBufWithDoa() : start_ms(0), end_ms(0), buffer(), contain_seg_start_point(false), contain_seg_end_point(false), doa(0) {} void Reset() { start_ms = 0; end_ms = 0; buffer.clear(); contain_seg_start_point = false; contain_seg_end_point = false; doa = 0; } }; class E2EVadFrameProb { public: double noise_prob; double speech_prob; double score; int frame_id; int frm_state; E2EVadFrameProb() : noise_prob(0.0), speech_prob(0.0), score(0.0), frame_id(0), frm_state(0) {} }; class WindowDetector { public: int window_size_ms; int sil_to_speech_time; int speech_to_sil_time; int frame_size_ms; int win_size_frame; int win_sum; std::vector<int> win_state; int cur_win_pos; FrameState pre_frame_state; FrameState cur_frame_state; int sil_to_speech_frmcnt_thres; int speech_to_sil_frmcnt_thres; int voice_last_frame_count; int noise_last_frame_count; int hydre_frame_count; WindowDetector(int window_size_ms, int sil_to_speech_time, int speech_to_sil_time, int frame_size_ms) 
: window_size_ms(window_size_ms), sil_to_speech_time(sil_to_speech_time), speech_to_sil_time(speech_to_sil_time), frame_size_ms(frame_size_ms), win_size_frame(window_size_ms / frame_size_ms), win_sum(0), win_state(std::vector<int>(win_size_frame, 0)), cur_win_pos(0), pre_frame_state(FrameState::kFrameStateSil), cur_frame_state(FrameState::kFrameStateSil), sil_to_speech_frmcnt_thres(sil_to_speech_time / frame_size_ms), speech_to_sil_frmcnt_thres(speech_to_sil_time / frame_size_ms), voice_last_frame_count(0), noise_last_frame_count(0), hydre_frame_count(0) {} void Reset() { cur_win_pos = 0; win_sum = 0; win_state = std::vector<int>(win_size_frame, 0); pre_frame_state = FrameState::kFrameStateSil; cur_frame_state = FrameState::kFrameStateSil; voice_last_frame_count = 0; noise_last_frame_count = 0; hydre_frame_count = 0; } int GetWinSize() { return win_size_frame; } AudioChangeState DetectOneFrame(FrameState frameState, int frame_count) { int cur_frame_state = 0; if (frameState == FrameState::kFrameStateSpeech) { cur_frame_state = 1; } else if (frameState == FrameState::kFrameStateSil) { cur_frame_state = 0; } else { return AudioChangeState::kChangeStateInvalid; } win_sum -= win_state[cur_win_pos]; win_sum += cur_frame_state; win_state[cur_win_pos] = cur_frame_state; cur_win_pos = (cur_win_pos + 1) % win_size_frame; if (pre_frame_state == FrameState::kFrameStateSil && win_sum >= sil_to_speech_frmcnt_thres) { pre_frame_state = FrameState::kFrameStateSpeech; return AudioChangeState::kChangeStateSil2Speech; } if (pre_frame_state == FrameState::kFrameStateSpeech && win_sum <= speech_to_sil_frmcnt_thres) { pre_frame_state = FrameState::kFrameStateSil; return AudioChangeState::kChangeStateSpeech2Sil; } if (pre_frame_state == FrameState::kFrameStateSil) { return AudioChangeState::kChangeStateSil2Sil; } if (pre_frame_state == FrameState::kFrameStateSpeech) { return AudioChangeState::kChangeStateSpeech2Speech; } return AudioChangeState::kChangeStateInvalid; } int FrameSizeMs() 
{ return frame_size_ms; } }; class E2EVadModel { public: E2EVadModel() { this->vad_opts = VADXOptions(); // this->windows_detector = WindowDetector(200,150,150,10); // this->encoder = encoder; // init variables this->is_final = false; this->data_buf_start_frame = 0; this->frm_cnt = 0; this->latest_confirmed_speech_frame = 0; this->lastest_confirmed_silence_frame = -1; this->continous_silence_frame_count = 0; this->vad_state_machine = VadStateMachine::kVadInStateStartPointNotDetected; this->confirmed_start_frame = -1; this->confirmed_end_frame = -1; this->number_end_time_detected = 0; this->sil_frame = 0; this->sil_pdf_ids = this->vad_opts.sil_pdf_ids; this->noise_average_decibel = -100.0; this->pre_end_silence_detected = false; this->next_seg = true; // this->output_data_buf = []; this->output_data_buf_offset = 0; // this->frame_probs = []; this->max_end_sil_frame_cnt_thresh = this->vad_opts.max_end_silence_time - this->vad_opts.speech_to_sil_time_thres; this->speech_noise_thres = this->vad_opts.speech_noise_thres; this->max_time_out = false; // this->decibel = []; this->ResetDetection(); } std::vector<std::vector<int>> operator()(const std::vector<std::vector<float>> &score, const std::vector<float> &waveform, bool is_final = false, int max_end_sil = 800, int max_single_segment_time = 15000, float speech_noise_thres = 0.9, int sample_rate = 16000) { max_end_sil_frame_cnt_thresh = max_end_sil - vad_opts.speech_to_sil_time_thres; this->waveform = waveform; this->vad_opts.max_single_segment_time = max_single_segment_time; this->vad_opts.speech_noise_thres = speech_noise_thres; this->vad_opts.sample_rate = sample_rate; ComputeDecibel(); ComputeScores(score); if (!is_final) { DetectCommonFrames(); } else { DetectLastFrames(); } // std::vector<std::vector<int>> segments; // for (size_t batch_num = 0; batch_num < score.size(); batch_num++) { std::vector<std::vector<int>> segment_batch; if (output_data_buf.size() > 0) { for (size_t i = output_data_buf_offset; i < 
output_data_buf.size(); i++) { if (!output_data_buf[i].contain_seg_start_point) { continue; } if (!next_seg && !output_data_buf[i].contain_seg_end_point) { continue; } int start_ms = next_seg ? output_data_buf[i].start_ms : -1; int end_ms; if (output_data_buf[i].contain_seg_end_point) { end_ms = output_data_buf[i].end_ms; next_seg = true; output_data_buf_offset += 1; } else { end_ms = -1; next_seg = false; } std::vector<int> segment = {start_ms, end_ms}; segment_batch.push_back(segment); } } // } if (is_final) { AllResetDetection(); } return segment_batch; } private: VADXOptions vad_opts; WindowDetector windows_detector = WindowDetector(200, 150, 150, 10); bool is_final; int data_buf_start_frame; int frm_cnt; int latest_confirmed_speech_frame; int lastest_confirmed_silence_frame; int continous_silence_frame_count; VadStateMachine vad_state_machine; int confirmed_start_frame; int confirmed_end_frame; int number_end_time_detected; int sil_frame; std::vector<int> sil_pdf_ids; float noise_average_decibel; bool pre_end_silence_detected; bool next_seg; std::vector<E2EVadSpeechBufWithDoa> output_data_buf; int output_data_buf_offset; std::vector<E2EVadFrameProb> frame_probs; int max_end_sil_frame_cnt_thresh; float speech_noise_thres; std::vector<std::vector<float>> scores; bool max_time_out; std::vector<float> decibel; std::vector<float> data_buf; std::vector<float> data_buf_all; std::vector<float> waveform; void AllResetDetection() { is_final = false; data_buf_start_frame = 0; frm_cnt = 0; latest_confirmed_speech_frame = 0; lastest_confirmed_silence_frame = -1; continous_silence_frame_count = 0; vad_state_machine = VadStateMachine::kVadInStateStartPointNotDetected; confirmed_start_frame = -1; confirmed_end_frame = -1; number_end_time_detected = 0; sil_frame = 0; sil_pdf_ids = vad_opts.sil_pdf_ids; noise_average_decibel = -100.0; pre_end_silence_detected = false; next_seg = true; output_data_buf.clear(); output_data_buf_offset = 0; frame_probs.clear(); 
max_end_sil_frame_cnt_thresh = vad_opts.max_end_silence_time - vad_opts.speech_to_sil_time_thres; speech_noise_thres = vad_opts.speech_noise_thres; scores.clear(); max_time_out = false; decibel.clear(); data_buf.clear(); data_buf_all.clear(); waveform.clear(); ResetDetection(); } void ResetDetection() { continous_silence_frame_count = 0; latest_confirmed_speech_frame = 0; lastest_confirmed_silence_frame = -1; confirmed_start_frame = -1; confirmed_end_frame = -1; vad_state_machine = VadStateMachine::kVadInStateStartPointNotDetected; windows_detector.Reset(); sil_frame = 0; frame_probs.clear(); } void ComputeDecibel() { int frame_sample_length = int(vad_opts.frame_length_ms * vad_opts.sample_rate / 1000); int frame_shift_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000); if (data_buf_all.empty()) { data_buf_all = waveform; data_buf = data_buf_all; } else { data_buf_all.insert(data_buf_all.end(), waveform.begin(), waveform.end()); } for (int offset = 0; offset < waveform.size() - frame_sample_length + 1; offset += frame_shift_length) { float sum = 0.0; for (int i = 0; i < frame_sample_length; i++) { sum += waveform[offset + i] * waveform[offset + i]; } // float decibel = 10 * log10(sum + 0.000001); this->decibel.push_back(10 * log10(sum + 0.000001)); } } void ComputeScores(const std::vector<std::vector<float>> &scores) { vad_opts.nn_eval_block_size = scores.size(); frm_cnt += scores.size(); if (this->scores.empty()) { this->scores = scores; // the first calculation } else { this->scores.insert(this->scores.end(), scores.begin(), scores.end()); } } void PopDataBufTillFrame(int frame_idx) { while (data_buf_start_frame < frame_idx) { int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000); if (data_buf.size() >= frame_sample_length) { data_buf_start_frame += 1; data_buf.erase(data_buf.begin(), data_buf.begin() + frame_sample_length); } else { break; } } } void PopDataToOutputBuf(int start_frm, int frm_cnt, bool 
first_frm_is_start_point, bool last_frm_is_end_point, bool end_point_is_sent_end) { PopDataBufTillFrame(start_frm); int expected_sample_number = int(frm_cnt * vad_opts.sample_rate * vad_opts.frame_in_ms / 1000); if (last_frm_is_end_point) { int extra_sample = std::max(0, int(vad_opts.frame_length_ms * vad_opts.sample_rate / 1000 - vad_opts.sample_rate * vad_opts.frame_in_ms / 1000)); expected_sample_number += int(extra_sample); } if (end_point_is_sent_end) { expected_sample_number = std::max(expected_sample_number, int(data_buf.size())); } if (data_buf.size() < expected_sample_number) { std::cout << "error in calling pop data_buf\n"; } if (output_data_buf.size() == 0 || first_frm_is_start_point) { output_data_buf.push_back(E2EVadSpeechBufWithDoa()); output_data_buf[output_data_buf.size() - 1].Reset(); output_data_buf[output_data_buf.size() - 1].start_ms = start_frm * vad_opts.frame_in_ms; output_data_buf[output_data_buf.size() - 1].end_ms = output_data_buf[output_data_buf.size() - 1].start_ms; output_data_buf[output_data_buf.size() - 1].doa = 0; } E2EVadSpeechBufWithDoa &cur_seg = output_data_buf.back(); if (cur_seg.end_ms != start_frm * vad_opts.frame_in_ms) { std::cout << "warning\n"; } int out_pos = (int) cur_seg.buffer.size(); int data_to_pop; if (end_point_is_sent_end) { data_to_pop = expected_sample_number; } else { data_to_pop = int(frm_cnt * vad_opts.frame_in_ms * vad_opts.sample_rate / 1000); } if (data_to_pop > int(data_buf.size())) { std::cout << "VAD data_to_pop is bigger than data_buf.size()!!!\n"; data_to_pop = (int) data_buf.size(); expected_sample_number = (int) data_buf.size(); } cur_seg.doa = 0; for (int sample_cpy_out = 0; sample_cpy_out < data_to_pop; sample_cpy_out++) { cur_seg.buffer.push_back(data_buf.back()); out_pos++; } for (int sample_cpy_out = data_to_pop; sample_cpy_out < expected_sample_number; sample_cpy_out++) { cur_seg.buffer.push_back(data_buf.back()); out_pos++; } if (cur_seg.end_ms != start_frm * vad_opts.frame_in_ms) { std::cout 
// NOTE(review): this region is the tail of the E2EVadModel class (a C++ port
// of FunASR's Python FSMN-VAD post-processing). The class header and the
// beginning of PopDataToOutputBuf() lie outside this chunk.
          << "Something wrong with the VAD algorithm\n";
    }
    // Tail of PopDataToOutputBuf(): advance the consumed-frame cursor and
    // extend the segment currently being emitted; mark whether this pop
    // carried the segment's start and/or end boundary.
    data_buf_start_frame += frm_cnt;
    cur_seg.end_ms = (start_frm + frm_cnt) * vad_opts.frame_in_ms;
    if (first_frm_is_start_point) {
        cur_seg.contain_seg_start_point = true;
    }
    if (last_frm_is_end_point) {
        cur_seg.contain_seg_end_point = true;
    }
}

// Record frame `valid_frame` as confirmed silence. Before any speech start
// point has been detected, audio buffered up to that frame is discarded
// (it can never belong to a segment).
void OnSilenceDetected(int valid_frame) {
    lastest_confirmed_silence_frame = valid_frame;
    if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
        PopDataBufTillFrame(valid_frame);
    }
    // silence_detected_callback_
    // pass
}

// Record frame `valid_frame` as confirmed speech and emit it to the output
// buffer (no start/end boundary flags set).
void OnVoiceDetected(int valid_frame) {
    latest_confirmed_speech_frame = valid_frame;
    PopDataToOutputBuf(valid_frame, 1, false, false, false);
}

// Confirm a speech start point at `start_frame`. `fake_result` marks the
// synthetic start emitted when the whole input turned out to be silence.
void OnVoiceStart(int start_frame, bool fake_result = false) {
    if (vad_opts.do_start_point_detection) {
        // pass
    }
    // confirmed_start_frame must have been reset (-1) by ResetDetection();
    // a stale value indicates a state-machine bookkeeping error.
    if (confirmed_start_frame != -1) {
        std::cout << "not reset vad properly\n";
    } else {
        confirmed_start_frame = start_frame;
    }
    if (!fake_result && vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
        // Emit the start frame with the "contains segment start" flag set.
        PopDataToOutputBuf(confirmed_start_frame, 1, true, false, false);
    }
}

// Confirm a speech end point at `end_frame`: first flush every not-yet-emitted
// speech frame before it, then emit the end frame itself with the
// "contains segment end" flag. `fake_result` end points update counters only.
void OnVoiceEnd(int end_frame, bool fake_result, bool is_last_frame) {
    for (int t = latest_confirmed_speech_frame + 1; t < end_frame; t++) {
        OnVoiceDetected(t);
    }
    if (vad_opts.do_end_point_detection) {
        // pass
    }
    if (confirmed_end_frame != -1) {
        std::cout << "not reset vad properly\n";
    } else {
        confirmed_end_frame = end_frame;
    }
    if (!fake_result) {
        sil_frame = 0;
        PopDataToOutputBuf(confirmed_end_frame, 1, false, true, is_last_frame);
    }
    number_end_time_detected++;
}

// If `cur_frm_idx` is the final frame of the stream, force-close the current
// segment there and move the state machine to "end point detected".
void MaybeOnVoiceEndIfLastFrame(bool is_final_frame, int cur_frm_idx) {
    if (is_final_frame) {
        OnVoiceEnd(cur_frm_idx, false, true);
        vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
    }
}

// Detector latency in milliseconds (frame latency * frame shift).
int GetLatency() {
    return int(LatencyFrmNumAtStartPoint() * vad_opts.frame_in_ms);
}

// Number of frames of look-back needed before a start point can be
// confirmed: the sliding-window size, plus the extension look-back when
// do_extend is enabled.
int LatencyFrmNumAtStartPoint() {
    int vad_latency = windows_detector.GetWinSize();
    if (vad_opts.do_extend) {
        vad_latency += int(vad_opts.lookback_time_start_point / vad_opts.frame_in_ms);
    }
    return vad_latency;
}

// Classify frame `t` as speech or silence from its decibel level and the
// model's silence-pdf posteriors; also maintains the running noise-decibel
// average used for the SNR test.
FrameState GetFrameState(int t) {
    FrameState frame_state = FrameState::kFrameStateInvalid;
    float cur_decibel = decibel[t];
    float cur_snr = cur_decibel - noise_average_decibel;
    // Energy gate: anything below the decibel threshold is silence outright.
    if (cur_decibel < vad_opts.decibel_thres) {
        frame_state = FrameState::kFrameStateSil;
        DetectOneFrame(frame_state, t, false);
        return frame_state;
    }
    float sum_score = 0.0;
    float noise_prob = 0.0;
    assert(sil_pdf_ids.size() == vad_opts.silence_pdf_num);
    if (sil_pdf_ids.size() > 0) {
        // Sum the posteriors of all silence pdfs; speech probability is the
        // complement of that mass.
        std::vector<float> sil_pdf_scores;
        for (auto sil_pdf_id: sil_pdf_ids) {
            sil_pdf_scores.push_back(scores[t][sil_pdf_id]);
        }
        sum_score = accumulate(sil_pdf_scores.begin(), sil_pdf_scores.end(), 0.0);
        noise_prob = log(sum_score) * vad_opts.speech_2_noise_ratio;
        float total_score = 1.0;
        sum_score = total_score - sum_score;
    }
    // NOTE(review): if sil_pdf_ids is empty, sum_score stays 0.0 and
    // log(0) yields -inf here — presumably sil_pdf_ids is always non-empty
    // in practice (see the assert above); confirm against the config.
    float speech_prob = log(sum_score);
    if (vad_opts.output_frame_probs) {
        E2EVadFrameProb frame_prob;
        frame_prob.noise_prob = noise_prob;
        frame_prob.speech_prob = speech_prob;
        frame_prob.score = sum_score;
        frame_prob.frame_id = t;
        frame_probs.push_back(frame_prob);
    }
    if (exp(speech_prob) >= exp(noise_prob) + speech_noise_thres) {
        // Model says speech; still require the energy/SNR gates to pass.
        if (cur_snr >= vad_opts.snr_thres && cur_decibel >= vad_opts.decibel_thres) {
            frame_state = FrameState::kFrameStateSpeech;
        } else {
            frame_state = FrameState::kFrameStateSil;
        }
    } else {
        frame_state = FrameState::kFrameStateSil;
        // Update the running noise floor estimate from silence frames
        // (exponential-style average over noise_frame_num_used_for_snr frames).
        if (noise_average_decibel < -99.9) {
            noise_average_decibel = cur_decibel;
        } else {
            noise_average_decibel = (cur_decibel + noise_average_decibel * (vad_opts.noise_frame_num_used_for_snr - 1)) / vad_opts.noise_frame_num_used_for_snr;
        }
    }
    return frame_state;
}

// Run detection over the latest nn_eval_block_size frames of a non-final
// chunk. Returns 0 always (return value unused by visible callers).
int DetectCommonFrames() {
    if (vad_state_machine == VadStateMachine::kVadInStateEndPointDetected) {
        return 0;
    }
    for (int i = vad_opts.nn_eval_block_size - 1; i >= 0; i--) {
        FrameState frame_state = FrameState::kFrameStateInvalid;
        frame_state = GetFrameState(frm_cnt - 1 - i);
        DetectOneFrame(frame_state, frm_cnt - 1 - i, false);
    }
    return 0;
}

// Same as DetectCommonFrames() but for the final chunk: the very last frame
// is flagged is_final_frame=true so any open segment gets closed.
int DetectLastFrames() {
    if (vad_state_machine == VadStateMachine::kVadInStateEndPointDetected) {
        return 0;
    }
    for (int i = vad_opts.nn_eval_block_size - 1; i >= 0; i--) {
        FrameState frame_state = FrameState::kFrameStateInvalid;
        frame_state = GetFrameState(frm_cnt - 1 - i);
        if (i != 0) {
            DetectOneFrame(frame_state, frm_cnt - 1 - i, false);
        } else {
            DetectOneFrame(frame_state, frm_cnt - 1, true);
        }
    }
    return 0;
}

// Core state-machine step: feed one classified frame into the sliding-window
// detector and react to the resulting sil/speech transition.
void DetectOneFrame(FrameState cur_frm_state, int cur_frm_idx, bool is_final_frame) {
    FrameState tmp_cur_frm_state = FrameState::kFrameStateInvalid;
    if (cur_frm_state == FrameState::kFrameStateSpeech) {
        // NOTE(review): std::fabs(1.0) is the constant 1.0, so this branch is
        // taken whenever fe_prior_thres < 1.0 — looks like a literal port of
        // the Python reference (math.fabs(1.0)); confirm intent upstream.
        if (std::fabs(1.0) > vad_opts.fe_prior_thres) {
            tmp_cur_frm_state = FrameState::kFrameStateSpeech;
        } else {
            tmp_cur_frm_state = FrameState::kFrameStateSil;
        }
    } else if (cur_frm_state == FrameState::kFrameStateSil) {
        tmp_cur_frm_state = FrameState::kFrameStateSil;
    }
    AudioChangeState state_change = windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx);
    int frm_shift_in_ms = vad_opts.frame_in_ms;
    if (AudioChangeState::kChangeStateSil2Speech == state_change) {
        // Silence -> speech: possibly a new segment start.
        int silence_frame_count = continous_silence_frame_count;
        continous_silence_frame_count = 0;
        pre_end_silence_detected = false;
        int start_frame = 0;
        if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
            // Back-date the start by the detector latency, clamped to the
            // oldest buffered frame, then emit every frame up to now.
            start_frame = std::max(data_buf_start_frame, cur_frm_idx - LatencyFrmNumAtStartPoint());
            OnVoiceStart(start_frame);
            vad_state_machine = VadStateMachine::kVadInStateInSpeechSegment;
            for (int t = start_frame + 1; t <= cur_frm_idx; t++) {
                OnVoiceDetected(t);
            }
        } else if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
            for (int t = latest_confirmed_speech_frame + 1; t < cur_frm_idx; t++) {
                OnVoiceDetected(t);
            }
            // Force an end point if the segment exceeds its maximum length.
            if (cur_frm_idx - confirmed_start_frame + 1 > vad_opts.max_single_segment_time / frm_shift_in_ms) {
                OnVoiceEnd(cur_frm_idx, false, false);
                vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
            } else if (!is_final_frame) {
                OnVoiceDetected(cur_frm_idx);
            } else {
                MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
            }
        }
    } else if (AudioChangeState::kChangeStateSpeech2Sil == state_change) {
        // Speech -> silence: keep emitting until a real end point is decided
        // by the Sil2Sil handling below (or the max-length cutoff here).
        continous_silence_frame_count = 0;
        if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
            // do nothing
        } else if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
            if (cur_frm_idx - confirmed_start_frame + 1 > vad_opts.max_single_segment_time / frm_shift_in_ms) {
                OnVoiceEnd(cur_frm_idx, false, false);
                vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
            } else if (!is_final_frame) {
                OnVoiceDetected(cur_frm_idx);
            } else {
                MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
            }
        }
    } else if (AudioChangeState::kChangeStateSpeech2Speech == state_change) {
        // Ongoing speech: emit the frame unless the segment hit max length.
        continous_silence_frame_count = 0;
        if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
            if (cur_frm_idx - confirmed_start_frame + 1 > vad_opts.max_single_segment_time / frm_shift_in_ms) {
                max_time_out = true;
                OnVoiceEnd(cur_frm_idx, false, false);
                vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
            } else if (!is_final_frame) {
                OnVoiceDetected(cur_frm_idx);
            } else {
                MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
            }
        }
    } else if (AudioChangeState::kChangeStateSil2Sil == state_change) {
        continous_silence_frame_count += 1;
        if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
            // Leading silence: give up (emit a fake empty segment) if the
            // max start-silence budget is exhausted in single-utterance mode,
            // or if the stream ends without any detected speech.
            if ((vad_opts.detect_mode == static_cast<int>(VadDetectMode::kVadSingleUtteranceDetectMode) && (continous_silence_frame_count * frm_shift_in_ms > vad_opts.max_start_silence_time)) || (is_final_frame && number_end_time_detected == 0)) {
                for (int t = lastest_confirmed_silence_frame + 1; t < cur_frm_idx; t++) {
                    OnSilenceDetected(t);
                }
                OnVoiceStart(0, true);
                OnVoiceEnd(0, true, false);
                vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
            } else {
                // Confirm silence only up to (now - latency), so a late
                // start point can still be back-dated.
                if (cur_frm_idx >= LatencyFrmNumAtStartPoint()) {
                    OnSilenceDetected(cur_frm_idx - LatencyFrmNumAtStartPoint());
                }
            }
        } else if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
            // Trailing silence inside a segment: close the segment once the
            // silence run reaches the end-silence threshold, backing the end
            // point up by the look-back window.
            if (continous_silence_frame_count * frm_shift_in_ms >= max_end_sil_frame_cnt_thresh) {
                int lookback_frame = max_end_sil_frame_cnt_thresh / frm_shift_in_ms;
                if (vad_opts.do_extend) {
                    lookback_frame -= vad_opts.lookahead_time_end_point / frm_shift_in_ms;
                    lookback_frame -= 1;
                    lookback_frame = std::max(0, lookback_frame);
                }
                OnVoiceEnd(cur_frm_idx - lookback_frame, false, false);
                vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
            } else if (cur_frm_idx - confirmed_start_frame + 1 > vad_opts.max_single_segment_time / frm_shift_in_ms) {
                OnVoiceEnd(cur_frm_idx, false, false);
                vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
            } else if (vad_opts.do_extend && !is_final_frame) {
                // Within the lookahead window the silence still counts as
                // (extended) speech.
                if (continous_silence_frame_count <= vad_opts.lookahead_time_end_point / frm_shift_in_ms) {
                    OnVoiceDetected(cur_frm_idx);
                }
            } else {
                MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
            }
        }
    }
    // In multi-utterance mode, re-arm the detector after each end point so
    // the next segment can be found.
    if (vad_state_machine == VadStateMachine::kVadInStateEndPointDetected && vad_opts.detect_mode == static_cast<int>(VadDetectMode::kVadMutipleUtteranceDetectMode)) {
        ResetDetection();
    }
}
};
funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -4,14 +4,14 @@ extern "C" { #endif // APIs for qmasr _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize) // APIs for funasr _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad) { Model* mm = create_model(szModelDir, nThreadNum, quantize); Model* mm = create_model(szModelDir, nThreadNum, quantize, use_vad); return mm; } _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback) _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad) { Model* pRecogObj = (Model*)handle; if (!pRecogObj) @@ -21,7 +21,9 @@ Audio audio(1); if (!audio.loadwav(szBuf, nLen, &sampling_rate)) return nullptr; //audio.split(); if(use_vad){ audio.split(pRecogObj); } float* buff; int len; @@ -31,7 +33,6 @@ int nStep = 0; int nTotal = audio.get_queue_size(); while (audio.fetch(buff, len, flag) > 0) { //pRecogObj->reset(); string msg = pRecogObj->forward(buff, len, flag); pResult->msg += msg; nStep++; @@ -42,7 +43,7 @@ return pResult; } _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback) _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad) { Model* pRecogObj = (Model*)handle; if (!pRecogObj) @@ -51,7 +52,9 @@ Audio audio(1); if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate)) return nullptr; //audio.split(); if(use_vad){ audio.split(pRecogObj); } float* buff; int len; @@ -61,7 +64,6 @@ int nStep = 0; int nTotal = audio.get_queue_size(); while (audio.fetch(buff, len, flag) > 0) { //pRecogObj->reset(); string msg = pRecogObj->forward(buff, len, flag); pResult->msg += msg; nStep++; @@ -72,7 +74,7 @@ return pResult; } 
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback) _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad) { Model* pRecogObj = (Model*)handle; if (!pRecogObj) @@ -81,7 +83,9 @@ Audio audio(1); if (!audio.loadpcmwav(szFileName, &sampling_rate)) return nullptr; //audio.split(); if(use_vad){ audio.split(pRecogObj); } float* buff; int len; @@ -91,7 +95,6 @@ int nStep = 0; int nTotal = audio.get_queue_size(); while (audio.fetch(buff, len, flag) > 0) { //pRecogObj->reset(); string msg = pRecogObj->forward(buff, len, flag); pResult->msg += msg; nStep++; @@ -102,7 +105,7 @@ return pResult; } _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback) _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad) { Model* pRecogObj = (Model*)handle; if (!pRecogObj) @@ -112,7 +115,9 @@ Audio audio(1); if(!audio.loadwav(szWavfile, &sampling_rate)) return nullptr; //audio.split(); if(use_vad){ audio.split(pRecogObj); } float* buff; int len; funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -3,11 +3,19 @@ using namespace std; using namespace paraformer; ModelImp::ModelImp(const char* path,int nNumThread, bool quantize) ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad) :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{ string model_path; string cmvn_path; string config_path; // VAD model if(use_vad){ string vad_path = pathAppend(path, "vad_model.onnx"); string mvn_path = pathAppend(path, "vad.mvn"); vadHandle = make_unique<FsmnVad>(); vadHandle->init_vad(vad_path, mvn_path, model_sample_rate, 800, 15000, 0.9); } if(quantize) { @@ -30,8 +38,10 @@ //fbank_ = std::make_unique<knf::OnlineFbank>(fbank_opts); //sessionOptions.SetInterOpNumThreads(1); //sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); sessionOptions.SetIntraOpNumThreads(nNumThread); sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL); sessionOptions.DisableCpuMemArena(); #ifdef _WIN32 wstring wstrPath = strToWstr(model_path); @@ -67,6 +77,10 @@ void ModelImp::reset() { } vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){ return vadHandle->infer(pcm_data); } vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) { @@ -172,66 +186,6 @@ p += dim; } } // void ParaformerOnnxAsrModel::ForwardFunc( // const std::vector<std::vector<float>>& chunk_feats, // std::vector<std::vector<float>>* out_prob) { // Ort::MemoryInfo memory_info = // Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); // // 1. Prepare onnx required data, splice cached_feature_ and chunk_feats // // chunk // // int num_frames = cached_feature_.size() + chunk_feats.size(); // int num_frames = chunk_feats.size(); // const int feature_dim = chunk_feats[0].size(); // // 2. 
Generate 2 input nodes tensor // // speech node { batch,frame number,feature dim } // const int64_t paraformer_feats_shape[3] = {1, num_frames, feature_dim}; // std::vector<float> paraformer_feats; // for (const auto & chunk_feat : chunk_feats) { // paraformer_feats.insert(paraformer_feats.end(), chunk_feat.begin(), chunk_feat.end()); // } // Ort::Value paraformer_feats_ort = Ort::Value::CreateTensor<float>( // memory_info, paraformer_feats.data(), paraformer_feats.size(), paraformer_feats_shape, 3); // // speech_lengths node {speech length,} // const int64_t paraformer_length_shape[1] = {1}; // std::vector<int32_t> paraformer_length; // paraformer_length.emplace_back(num_frames); // Ort::Value paraformer_length_ort = Ort::Value::CreateTensor<int32_t>( // memory_info, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1); // // 3. Put nodes into onnx input vector // std::vector<Ort::Value> paraformer_inputs; // paraformer_inputs.emplace_back(std::move(paraformer_feats_ort)); // paraformer_inputs.emplace_back(std::move(paraformer_length_ort)); // // 4. Onnx infer // std::vector<Ort::Value> paraformer_ort_outputs; // try{ // VLOG(3) << "Start infer"; // paraformer_ort_outputs = paraformer_session_->Run( // Ort::RunOptions{nullptr}, paraformer_in_names_.data(), paraformer_inputs.data(), // paraformer_inputs.size(), paraformer_out_names_.data(), paraformer_out_names_.size()); // }catch (std::exception const& e) { // // Catch "Non-zero status code returned error",usually because there is no asr result. // // Need funasr to resolve. // LOG(ERROR) << e.what(); // return; // } // // 5. 
Change infer result to output shapes // float* logp_data = paraformer_ort_outputs[0].GetTensorMutableData<float>(); // auto type_info = paraformer_ort_outputs[0].GetTensorTypeAndShapeInfo(); // int num_outputs = type_info.GetShape()[1]; // int output_dim = type_info.GetShape()[2]; // out_prob->resize(num_outputs); // for (int i = 0; i < num_outputs; i++) { // (*out_prob)[i].resize(output_dim); // memcpy((*out_prob)[i].data(), logp_data + i * output_dim, // sizeof(float) * output_dim); // } // } string ModelImp::forward(float* din, int len, int flag) { funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -4,8 +4,7 @@ #ifndef PARAFORMER_MODELIMP_H #define PARAFORMER_MODELIMP_H #include "kaldi-native-fbank/csrc/feature-fbank.h" #include "kaldi-native-fbank/csrc/online-feature.h" #include "precomp.h" namespace paraformer { @@ -13,6 +12,8 @@ private: //std::unique_ptr<knf::OnlineFbank> fbank_; knf::FbankOptions fbank_opts; std::unique_ptr<FsmnVad> vadHandle; Vocab* vocab; vector<float> means_list; @@ -27,7 +28,7 @@ string greedy_search( float* in, int nLen); std::unique_ptr<Ort::Session> m_session; std::shared_ptr<Ort::Session> m_session; Ort::Env env_; Ort::SessionOptions sessionOptions; @@ -36,13 +37,14 @@ vector<const char*> m_szOutputNames; public: ModelImp(const char* path, int nNumThread=0, bool quantize=false); ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false); ~ModelImp(); void reset(); vector<float> FbankKaldi(float sample_rate, const float* waves, int len); string forward_chunk(float* din, int len, int flag); string forward(float* din, int len, int flag); string rescoring(); std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data); }; funasr/runtime/onnxruntime/src/precomp.h
@@ -26,6 +26,8 @@ #include <fftw3.h> #include "onnxruntime_run_options_config_keys.h" #include "onnxruntime_cxx_api.h" #include "kaldi-native-fbank/csrc/feature-fbank.h" #include "kaldi-native-fbank/csrc/online-feature.h" // mine @@ -33,6 +35,7 @@ #include "commonfunc.h" #include <ComDefine.h> #include "predefine_coe.h" #include "FsmnVad.h" #include <ComDefine.h> //#include "alignedmem.h" funasr/runtime/onnxruntime/tester/tester.cpp
@@ -14,9 +14,9 @@ int main(int argc, char *argv[]) { if (argc < 4) if (argc < 5) { printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) \n", argv[0]); printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) use_vad(true or false) \n", argv[0]); exit(-1); } struct timeval start, end; @@ -24,8 +24,10 @@ int nThreadNum = 1; // is quantize bool quantize = false; bool use_vad = false; istringstream(argv[3]) >> boolalpha >> quantize; FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize); istringstream(argv[4]) >> boolalpha >> use_vad; FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad); if (!AsrHanlde) { @@ -41,7 +43,7 @@ gettimeofday(&start, NULL); float snippet_time = 0.0f; FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL); FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad); gettimeofday(&end, NULL); funasr/runtime/onnxruntime/tester/tester_rtf.cpp
@@ -19,7 +19,7 @@ std::atomic<int> index(0); std::mutex mtx; void runReg(FUNASR_HANDLE AsrHanlde, vector<string> wav_list, void runReg(FUNASR_HANDLE AsrHandle, vector<string> wav_list, float* total_length, long* total_time, int core_id) { // cpu_set_t cpuset; @@ -37,7 +37,7 @@ // warm up for (size_t i = 0; i < 1; i++) { FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL); FUNASR_RESULT Result=FunASRRecogFile(AsrHandle, wav_list[0].c_str(), RASR_NONE, NULL); } while (true) { @@ -48,7 +48,7 @@ } gettimeofday(&start, NULL); FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL); FUNASR_RESULT Result=FunASRRecogFile(AsrHandle, wav_list[i].c_str(), RASR_NONE, NULL); gettimeofday(&end, NULL); seconds = (end.tv_sec - start.tv_sec); @@ -112,8 +112,8 @@ int nThreadNum = 1; nThreadNum = atoi(argv[4]); FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], 1, quantize); if (!AsrHanlde) FUNASR_HANDLE AsrHandle=FunASRInit(argv[1], 1, quantize); if (!AsrHandle) { printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]); exit(-1); @@ -130,7 +130,7 @@ for (int i = 0; i < nThreadNum; i++) { threads.emplace_back(thread(runReg, AsrHanlde, wav_list, &total_length, &total_time, i)); threads.emplace_back(thread(runReg, AsrHandle, wav_list, &total_length, &total_time, i)); } for (auto& thread : threads) @@ -142,6 +142,6 @@ printf("total_time_comput %ld ms.\n", total_time / 1000); printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000)); FunASRUninit(AsrHanlde); FunASRUninit(AsrHandle); return 0; }