From 1819303f5e8cfc03f4c0ec2495571a54a186d34b Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期二, 29 十月 2024 11:40:18 +0800
Subject: [PATCH] support SenseVoiceSmall in 2pass mode
---
runtime/onnxruntime/include/funasrruntime.h | 3
runtime/onnxruntime/include/tpass-stream.h | 2
runtime/onnxruntime/src/paraformer.h | 3
runtime/onnxruntime/src/sensevoice-small.h | 12 +
runtime/websocket/bin/websocket-server-2pass.h | 4
runtime/websocket/bin/websocket-server-2pass.cpp | 26 ++
runtime/onnxruntime/include/model.h | 4
runtime/onnxruntime/src/paraformer-online.h | 23 ++
runtime/onnxruntime/src/tpass-online-stream.cpp | 2
runtime/onnxruntime/src/paraformer.cpp | 27 ---
runtime/websocket/bin/funasr-wss-client-2pass.cpp | 33 ++-
runtime/onnxruntime/src/sensevoice-small.cpp | 166 ++++++++++++++++++++
runtime/onnxruntime/src/paraformer-online.cpp | 106 +++++++++---
runtime/onnxruntime/src/tpass-stream.cpp | 12 +
runtime/onnxruntime/src/funasrruntime.cpp | 49 +++--
runtime/websocket/bin/funasr-wss-server-2pass.cpp | 6
16 files changed, 375 insertions(+), 103 deletions(-)
diff --git a/runtime/onnxruntime/include/funasrruntime.h b/runtime/onnxruntime/include/funasrruntime.h
index 5dedaf7..685c024 100644
--- a/runtime/onnxruntime/include/funasrruntime.h
+++ b/runtime/onnxruntime/include/funasrruntime.h
@@ -120,7 +120,8 @@
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true,
int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS,
- const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
+ const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr,
+ std::string svs_lang="auto", bool svs_itn=true);
_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle);
_FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle);
diff --git a/runtime/onnxruntime/include/model.h b/runtime/onnxruntime/include/model.h
index a49baeb..5ce1148 100644
--- a/runtime/onnxruntime/include/model.h
+++ b/runtime/onnxruntime/include/model.h
@@ -16,9 +16,11 @@
virtual void StartUtterance() = 0;
virtual void EndUtterance() = 0;
virtual void Reset() = 0;
+ virtual string GreedySearch(float* in, int n_len, int64_t token_nums, bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0}){return "";};
virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
- virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
+ virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn,
+ const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num){};
virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
virtual void InitFstDecoder(){};
virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
diff --git a/runtime/onnxruntime/include/tpass-stream.h b/runtime/onnxruntime/include/tpass-stream.h
index 0276631..a4640a2 100644
--- a/runtime/onnxruntime/include/tpass-stream.h
+++ b/runtime/onnxruntime/include/tpass-stream.h
@@ -26,11 +26,13 @@
bool UseVad(){return use_vad;};
bool UsePunc(){return use_punc;};
bool UseITN(){return use_itn;};
+ std::string GetModelType(){return model_type;};
private:
bool use_vad=false;
bool use_punc=false;
bool use_itn=false;
+ std::string model_type = MODEL_PARA;
};
TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num=1);
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 88a3970..6286412 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -482,7 +482,8 @@
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished,
int sampling_rate, std::string wav_format, ASR_TYPE mode,
- const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle)
+ const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle,
+ std::string svs_lang, bool svs_itn)
{
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
@@ -525,7 +526,7 @@
funasr::AudioFrame* frame = nullptr;
while(audio->FetchChunck(frame) > 0){
- string msg = ((funasr::ParaformerOnline*)asr_online_handle)->Forward(frame->data, frame->len, frame->is_final);
+ string msg = (asr_online_handle)->Forward(frame->data, frame->len, frame->is_final);
if(mode == ASR_ONLINE){
((funasr::ParaformerOnline*)asr_online_handle)->online_res += msg;
if(frame->is_final){
@@ -567,7 +568,12 @@
len = new int[1];
buff[0] = frame->data;
len[0] = frame->len;
- vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
+ vector<string> msgs;
+ if(tpass_stream->GetModelType() == MODEL_SVS){
+ msgs = (tpass_stream->asr_handle)->Forward(buff, len, true, svs_lang, svs_itn, 1);
+ }else{
+ msgs = (tpass_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, 1);
+ }
string msg = msgs.size()>0?msgs[0]:"";
std::vector<std::string> msg_vec = funasr::SplitStr(msg, " | "); // split with timestamp
if(msg_vec.size()==0){
@@ -589,24 +595,29 @@
p_result->stamp += cur_stamp + "]";
}
- string msg_punc = punc_online_handle->AddPunc(msg.c_str(), punc_cache[1]);
- if(input_finished){
- msg_punc += "銆�";
- }
- p_result->tpass_msg = msg_punc;
-#if !defined(__APPLE__)
- if(tpass_stream->UseITN() && itn){
- string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
- // TimestampSmooth
- if(!(p_result->stamp).empty()){
- std::string new_stamp = funasr::TimestampSmooth(p_result->tpass_msg, msg_itn, p_result->stamp);
- if(!new_stamp.empty()){
- p_result->stamp = new_stamp;
- }
+ if (tpass_stream->GetModelType() == MODEL_PARA){
+ string msg_punc = punc_online_handle->AddPunc(msg.c_str(), punc_cache[1]);
+ if(input_finished){
+ msg_punc += "銆�";
}
- p_result->tpass_msg = msg_itn;
- }
+ p_result->tpass_msg = msg_punc;
+
+#if !defined(__APPLE__)
+ if(tpass_stream->UseITN() && itn){
+ string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
+ // TimestampSmooth
+ if(!(p_result->stamp).empty()){
+ std::string new_stamp = funasr::TimestampSmooth(p_result->tpass_msg, msg_itn, p_result->stamp);
+ if(!new_stamp.empty()){
+ p_result->stamp = new_stamp;
+ }
+ }
+ p_result->tpass_msg = msg_itn;
+ }
#endif
+ }else{
+ p_result->tpass_msg = msg;
+ }
if (!(p_result->stamp).empty()){
p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp);
}
diff --git a/runtime/onnxruntime/src/paraformer-online.cpp b/runtime/onnxruntime/src/paraformer-online.cpp
index 55a4fd1..88951aa 100644
--- a/runtime/onnxruntime/src/paraformer-online.cpp
+++ b/runtime/onnxruntime/src/paraformer-online.cpp
@@ -9,18 +9,55 @@
namespace funasr {
-ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
-:para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{
- InitOnline(
- para_handle_->fbank_opts_,
- para_handle_->encoder_session_,
- para_handle_->decoder_session_,
- para_handle_->en_szInputNames_,
- para_handle_->en_szOutputNames_,
- para_handle_->de_szInputNames_,
- para_handle_->de_szOutputNames_,
- para_handle_->means_list_,
- para_handle_->vars_list_);
+ParaformerOnline::ParaformerOnline(Model* offline_handle, std::vector<int> chunk_size, std::string model_type)
+:offline_handle_(std::move(offline_handle)),chunk_size(chunk_size),session_options_{}{
+ if(model_type == MODEL_PARA){
+ Paraformer* para_handle = dynamic_cast<Paraformer*>(offline_handle_);
+ InitOnline(
+ para_handle->fbank_opts_,
+ para_handle->encoder_session_,
+ para_handle->decoder_session_,
+ para_handle->en_szInputNames_,
+ para_handle->en_szOutputNames_,
+ para_handle->de_szInputNames_,
+ para_handle->de_szOutputNames_,
+ para_handle->means_list_,
+ para_handle->vars_list_,
+ para_handle->frame_length,
+ para_handle->frame_shift,
+ para_handle->n_mels,
+ para_handle->lfr_m,
+ para_handle->lfr_n,
+ para_handle->encoder_size,
+ para_handle->fsmn_layers,
+ para_handle->fsmn_lorder,
+ para_handle->fsmn_dims,
+ para_handle->cif_threshold,
+ para_handle->tail_alphas);
+ }else if(model_type == MODEL_SVS){
+ SenseVoiceSmall* svs_handle = dynamic_cast<SenseVoiceSmall*>(offline_handle_);
+ InitOnline(
+ svs_handle->fbank_opts_,
+ svs_handle->encoder_session_,
+ svs_handle->decoder_session_,
+ svs_handle->en_szInputNames_,
+ svs_handle->en_szOutputNames_,
+ svs_handle->de_szInputNames_,
+ svs_handle->de_szOutputNames_,
+ svs_handle->means_list_,
+ svs_handle->vars_list_,
+ svs_handle->frame_length,
+ svs_handle->frame_shift,
+ svs_handle->n_mels,
+ svs_handle->lfr_m,
+ svs_handle->lfr_n,
+ svs_handle->encoder_size,
+ svs_handle->fsmn_layers,
+ svs_handle->fsmn_lorder,
+ svs_handle->fsmn_dims,
+ svs_handle->cif_threshold,
+ svs_handle->tail_alphas);
+ }
InitCache();
}
@@ -33,7 +70,18 @@
vector<const char*> &de_szInputNames,
vector<const char*> &de_szOutputNames,
vector<float> &means_list,
- vector<float> &vars_list){
+ vector<float> &vars_list,
+ int frame_length_,
+ int frame_shift_,
+ int n_mels_,
+ int lfr_m_,
+ int lfr_n_,
+ int encoder_size_,
+ int fsmn_layers_,
+ int fsmn_lorder_,
+ int fsmn_dims_,
+ float cif_threshold_,
+ float tail_alphas_){
fbank_opts_ = fbank_opts;
encoder_session_ = encoder_session;
decoder_session_ = decoder_session;
@@ -44,27 +92,27 @@
means_list_ = means_list;
vars_list_ = vars_list;
- frame_length = para_handle_->frame_length;
- frame_shift = para_handle_->frame_shift;
- n_mels = para_handle_->n_mels;
- lfr_m = para_handle_->lfr_m;
- lfr_n = para_handle_->lfr_n;
- encoder_size = para_handle_->encoder_size;
- fsmn_layers = para_handle_->fsmn_layers;
- fsmn_lorder = para_handle_->fsmn_lorder;
- fsmn_dims = para_handle_->fsmn_dims;
- cif_threshold = para_handle_->cif_threshold;
- tail_alphas = para_handle_->tail_alphas;
+ frame_length = frame_length_;
+ frame_shift = frame_shift_;
+ n_mels = n_mels_;
+ lfr_m = lfr_m_;
+ lfr_n = lfr_n_;
+ encoder_size = encoder_size_;
+ fsmn_layers = fsmn_layers_;
+ fsmn_lorder = fsmn_lorder_;
+ fsmn_dims = fsmn_dims_;
+ cif_threshold = cif_threshold_;
+ tail_alphas = tail_alphas_;
// other vars
sqrt_factor = std::sqrt(encoder_size);
for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
fsmn_init_cache_.emplace_back(0);
}
- chunk_len = chunk_size[1]*frame_shift*lfr_n*para_handle_->asr_sample_rate/1000;
+ chunk_len = chunk_size[1]*frame_shift*lfr_n*offline_handle_->GetAsrSampleRate()/1000;
- frame_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_length;
- frame_shift_sample_length_ = para_handle_->asr_sample_rate / 1000 * frame_shift;
+ frame_sample_length_ = offline_handle_->GetAsrSampleRate() / 1000 * frame_length;
+ frame_shift_sample_length_ = offline_handle_->GetAsrSampleRate() / 1000 * frame_shift;
}
@@ -464,7 +512,7 @@
std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
- result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
+ result = offline_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
}
}catch (std::exception const &e)
{
@@ -493,7 +541,7 @@
if(is_first_chunk){
is_first_chunk = false;
}
- ExtractFeats(para_handle_->asr_sample_rate, wav_feats, waves, input_finished);
+ ExtractFeats(offline_handle_->GetAsrSampleRate(), wav_feats, waves, input_finished);
if(wav_feats.size() == 0){
return result;
}
diff --git a/runtime/onnxruntime/src/paraformer-online.h b/runtime/onnxruntime/src/paraformer-online.h
index 8c9bb88..8ab473d 100644
--- a/runtime/onnxruntime/src/paraformer-online.h
+++ b/runtime/onnxruntime/src/paraformer-online.h
@@ -38,7 +38,18 @@
vector<const char*> &de_szInputNames,
vector<const char*> &de_szOutputNames,
vector<float> &means_list,
- vector<float> &vars_list);
+ vector<float> &vars_list,
+ int frame_length_,
+ int frame_shift_,
+ int n_mels_,
+ int lfr_m_,
+ int lfr_n_,
+ int encoder_size_,
+ int fsmn_layers_,
+ int fsmn_lorder_,
+ int fsmn_dims_,
+ float cif_threshold_,
+ float tail_alphas_);
void StartUtterance()
{
@@ -48,8 +59,8 @@
{
}
- Paraformer* para_handle_ = nullptr;
- // from para_handle_
+ Model* offline_handle_ = nullptr;
+ // from offline_handle_
knf::FbankOptions fbank_opts_;
std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
@@ -60,7 +71,7 @@
vector<const char*> de_szOutputNames_;
vector<float> means_list_;
vector<float> vars_list_;
- // configs from para_handle_
+ // configs from offline_handle_
int frame_length = 25;
int frame_shift = 10;
int n_mels = 80;
@@ -100,7 +111,7 @@
double sqrt_factor;
public:
- ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size);
+ ParaformerOnline(Model* offline_handle, std::vector<int> chunk_size, std::string model_type=MODEL_PARA);
~ParaformerOnline();
void Reset();
void ResetCache();
@@ -112,7 +123,7 @@
string Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
string Rescoring();
- int GetAsrSampleRate() { return para_handle_->asr_sample_rate; };
+ int GetAsrSampleRate() { return offline_handle_->GetAsrSampleRate(); };
// 2pass
std::string online_res;
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index 24f5152..7e1fe96 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -131,9 +131,10 @@
}
// 2pass
-void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
+void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model,
+ const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num){
// online
- InitAsr(en_model, de_model, am_cmvn, am_config, token_file, thread_num);
+ InitAsr(en_model, de_model, am_cmvn, am_config, online_token_file, thread_num);
// offline
try {
@@ -144,28 +145,6 @@
exit(-1);
}
- // string strName;
- // GetInputName(m_session_.get(), strName);
- // m_strInputNames.push_back(strName.c_str());
- // GetInputName(m_session_.get(), strName,1);
- // m_strInputNames.push_back(strName);
-
- // if (use_hotword) {
- // GetInputName(m_session_.get(), strName, 2);
- // m_strInputNames.push_back(strName);
- // }
-
- // // support time stamp
- // size_t numOutputNodes = m_session_->GetOutputCount();
- // for(int index=0; index<numOutputNodes; index++){
- // GetOutputName(m_session_.get(), strName, index);
- // m_strOutputNames.push_back(strName);
- // }
-
- // for (auto& item : m_strInputNames)
- // m_szInputNames.push_back(item.c_str());
- // for (auto& item : m_strOutputNames)
- // m_szOutputNames.push_back(item.c_str());
GetInputNames(m_session_.get(), m_strInputNames, m_szInputNames);
GetOutputNames(m_session_.get(), m_strOutputNames, m_szOutputNames);
}
diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h
index 571b2ba..41e71f5 100644
--- a/runtime/onnxruntime/src/paraformer.h
+++ b/runtime/onnxruntime/src/paraformer.h
@@ -46,7 +46,8 @@
// online
void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
// 2pass
- void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
+ void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn,
+ const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num);
void InitHwCompiler(const std::string &hw_model, int thread_num);
void InitSegDict(const std::string &seg_dict_model);
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
diff --git a/runtime/onnxruntime/src/sensevoice-small.cpp b/runtime/onnxruntime/src/sensevoice-small.cpp
index 9fa72a0..5cb1042 100644
--- a/runtime/onnxruntime/src/sensevoice-small.cpp
+++ b/runtime/onnxruntime/src/sensevoice-small.cpp
@@ -48,6 +48,145 @@
LoadCmvn(am_cmvn.c_str());
}
+// online
+void SenseVoiceSmall::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
+
+ LoadOnlineConfigFromYaml(am_config.c_str());
+ // knf options
+ fbank_opts_.frame_opts.dither = 0;
+ fbank_opts_.mel_opts.num_bins = n_mels;
+ fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
+ fbank_opts_.frame_opts.window_type = window_type;
+ fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
+ fbank_opts_.frame_opts.frame_length_ms = frame_length;
+ fbank_opts_.energy_floor = 0;
+ fbank_opts_.mel_opts.debug_mel = false;
+
+ // session_options_.SetInterOpNumThreads(1);
+ session_options_.SetIntraOpNumThreads(thread_num);
+ session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
+ // DisableCpuMemArena can improve performance
+ session_options_.DisableCpuMemArena();
+
+ try {
+ encoder_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(en_model).c_str(), session_options_);
+ LOG(INFO) << "Successfully load model from " << en_model;
+ } catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load am encoder model: " << e.what();
+ exit(-1);
+ }
+
+ try {
+ decoder_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(de_model).c_str(), session_options_);
+ LOG(INFO) << "Successfully load model from " << de_model;
+ } catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load am decoder model: " << e.what();
+ exit(-1);
+ }
+
+ // encoder
+ string strName;
+ GetInputName(encoder_session_.get(), strName);
+ en_strInputNames.push_back(strName.c_str());
+ GetInputName(encoder_session_.get(), strName,1);
+ en_strInputNames.push_back(strName);
+
+ GetOutputName(encoder_session_.get(), strName);
+ en_strOutputNames.push_back(strName);
+ GetOutputName(encoder_session_.get(), strName,1);
+ en_strOutputNames.push_back(strName);
+ GetOutputName(encoder_session_.get(), strName,2);
+ en_strOutputNames.push_back(strName);
+
+ for (auto& item : en_strInputNames)
+ en_szInputNames_.push_back(item.c_str());
+ for (auto& item : en_strOutputNames)
+ en_szOutputNames_.push_back(item.c_str());
+
+ // decoder
+ int de_input_len = 4 + fsmn_layers;
+ int de_out_len = 2 + fsmn_layers;
+ for(int i=0;i<de_input_len; i++){
+ GetInputName(decoder_session_.get(), strName, i);
+ de_strInputNames.push_back(strName.c_str());
+ }
+
+ for(int i=0;i<de_out_len; i++){
+ GetOutputName(decoder_session_.get(), strName,i);
+ de_strOutputNames.push_back(strName);
+ }
+
+ for (auto& item : de_strInputNames)
+ de_szInputNames_.push_back(item.c_str());
+ for (auto& item : de_strOutputNames)
+ de_szOutputNames_.push_back(item.c_str());
+
+ online_vocab = new Vocab(token_file.c_str());
+ phone_set_ = new PhoneSet(token_file.c_str());
+ LoadCmvn(am_cmvn.c_str());
+}
+
+// 2pass
+void SenseVoiceSmall::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model,
+ const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, const std::string &online_token_file, int thread_num){
+ // online
+ InitAsr(en_model, de_model, am_cmvn, am_config, online_token_file, thread_num);
+
+ // offline
+ try {
+ m_session_ = std::make_unique<Ort::Session>(env_, ORTSTRING(am_model).c_str(), session_options_);
+ LOG(INFO) << "Successfully load model from " << am_model;
+ } catch (std::exception const &e) {
+ LOG(ERROR) << "Error when load am onnx model: " << e.what();
+ exit(-1);
+ }
+
+ GetInputNames(m_session_.get(), m_strInputNames, m_szInputNames);
+ GetOutputNames(m_session_.get(), m_strOutputNames, m_szOutputNames);
+ vocab = new Vocab(token_file.c_str());
+}
+
+void SenseVoiceSmall::LoadOnlineConfigFromYaml(const char* filename){
+
+ YAML::Node config;
+ try{
+ config = YAML::LoadFile(filename);
+ }catch(exception const &e){
+ LOG(ERROR) << "Error loading file, yaml file error or not exist.";
+ exit(-1);
+ }
+
+ try{
+ YAML::Node frontend_conf = config["frontend_conf"];
+ YAML::Node encoder_conf = config["encoder_conf"];
+ YAML::Node decoder_conf = config["decoder_conf"];
+ YAML::Node predictor_conf = config["predictor_conf"];
+
+ this->window_type = frontend_conf["window"].as<string>();
+ this->n_mels = frontend_conf["n_mels"].as<int>();
+ this->frame_length = frontend_conf["frame_length"].as<int>();
+ this->frame_shift = frontend_conf["frame_shift"].as<int>();
+ this->lfr_m = frontend_conf["lfr_m"].as<int>();
+ this->lfr_n = frontend_conf["lfr_n"].as<int>();
+
+ this->encoder_size = encoder_conf["output_size"].as<int>();
+ this->fsmn_dims = encoder_conf["output_size"].as<int>();
+
+ this->fsmn_layers = decoder_conf["num_blocks"].as<int>();
+ this->fsmn_lorder = decoder_conf["kernel_size"].as<int>()-1;
+
+ this->cif_threshold = predictor_conf["threshold"].as<double>();
+ this->tail_alphas = predictor_conf["tail_threshold"].as<double>();
+
+ this->asr_sample_rate = frontend_conf["fs"].as<int>();
+
+
+ }catch(exception const &e){
+ LOG(ERROR) << "Error when load argument from vad config YAML.";
+ exit(-1);
+ }
+}
+
void SenseVoiceSmall::LoadConfigFromYaml(const char* filename){
YAML::Node config;
@@ -83,6 +222,9 @@
{
if(vocab){
delete vocab;
+ }
+ if(online_vocab){
+ delete online_vocab;
}
if(lm_vocab){
delete lm_vocab;
@@ -212,6 +354,30 @@
return str_lang + str_emo + str_event + " " + text;
}
+string SenseVoiceSmall::GreedySearch(float * in, int n_len, int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
+{
+ vector<int> hyps;
+ int Tmax = n_len;
+ for (int i = 0; i < Tmax; i++) {
+ int max_idx;
+ float max_val;
+ FindMax(in + i * token_nums, token_nums, max_val, max_idx);
+ hyps.push_back(max_idx);
+ }
+ if(!is_stamp){
+ return online_vocab->Vector2StringV2(hyps, language);
+ }else{
+ std::vector<string> char_list;
+ std::vector<std::vector<float>> timestamp_list;
+ std::string res_str;
+ online_vocab->Vector2String(hyps, char_list);
+ std::vector<string> raw_char(char_list);
+ TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
+
+ return PostProcess(raw_char, timestamp_list);
+ }
+}
+
void SenseVoiceSmall::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
std::vector<std::vector<float>> out_feats;
diff --git a/runtime/onnxruntime/src/sensevoice-small.h b/runtime/onnxruntime/src/sensevoice-small.h
index f987f38..75cbc92 100644
--- a/runtime/onnxruntime/src/sensevoice-small.h
+++ b/runtime/onnxruntime/src/sensevoice-small.h
@@ -12,12 +12,14 @@
class SenseVoiceSmall : public Model {
private:
Vocab* vocab = nullptr;
+ Vocab* online_vocab = nullptr;
Vocab* lm_vocab = nullptr;
SegDict* seg_dict = nullptr;
PhoneSet* phone_set_ = nullptr;
const float scale = 1.0;
void LoadConfigFromYaml(const char* filename);
+ void LoadOnlineConfigFromYaml(const char* filename);
void LoadCmvn(const char *filename);
void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
@@ -34,9 +36,10 @@
~SenseVoiceSmall();
void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
// online
- // void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
+ void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
// 2pass
- // void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
+ void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config,
+ const std::string &token_file, const std::string &online_token_file, int thread_num);
// void InitHwCompiler(const std::string &hw_model, int thread_num);
// void InitSegDict(const std::string &seg_dict_model);
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
@@ -44,7 +47,8 @@
void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, std::string svs_lang="auto", bool svs_itn=true, int batch_in=1);
string CTCSearch( float * in, std::vector<int32_t> paraformer_length, std::vector<int64_t> outputShape);
-
+ string GreedySearch( float* in, int n_len, int64_t token_nums,
+ bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
string Rescoring();
string GetLang(){return language;};
int GetAsrSampleRate() { return asr_sample_rate; };
@@ -100,6 +104,8 @@
int asr_sample_rate = MODEL_SAMPLE_RATE;
int batch_size_ = 1;
int blank_id = 0;
+ float cif_threshold = 1.0;
+ float tail_alphas = 0.45;
//dict
std::map<std::string, int> lid_map = {
{"auto", 0},
diff --git a/runtime/onnxruntime/src/tpass-online-stream.cpp b/runtime/onnxruntime/src/tpass-online-stream.cpp
index 7788e0b..338bb2b 100644
--- a/runtime/onnxruntime/src/tpass-online-stream.cpp
+++ b/runtime/onnxruntime/src/tpass-online-stream.cpp
@@ -11,7 +11,7 @@
}
if(tpass_obj->asr_handle){
- asr_online_handle = make_unique<ParaformerOnline>((Paraformer*)(tpass_obj->asr_handle).get(), chunk_size);
+ asr_online_handle = make_unique<ParaformerOnline>((tpass_obj->asr_handle).get(), chunk_size, tpass_stream->GetModelType());
}else{
LOG(ERROR)<<"asr_handle is null";
exit(-1);
diff --git a/runtime/onnxruntime/src/tpass-stream.cpp b/runtime/onnxruntime/src/tpass-stream.cpp
index 7681a4d..ff502de 100644
--- a/runtime/onnxruntime/src/tpass-stream.cpp
+++ b/runtime/onnxruntime/src/tpass-stream.cpp
@@ -36,10 +36,17 @@
string am_cmvn_path;
string am_config_path;
string token_path;
+ string online_token_path;
string hw_compile_model_path;
string seg_dict_path;
- asr_handle = make_unique<Paraformer>();
+ if (model_path.at(MODEL_DIR).find(MODEL_SVS) != std::string::npos)
+ {
+ asr_handle = make_unique<SenseVoiceSmall>();
+ model_type = MODEL_SVS;
+ }else{
+ asr_handle = make_unique<Paraformer>();
+ }
bool enable_hotword = false;
hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
@@ -54,6 +61,7 @@
am_model_path = PathAppend(model_path.at(OFFLINE_MODEL_DIR), MODEL_NAME);
en_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), ENCODER_NAME);
de_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), DECODER_NAME);
+ online_token_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), TOKEN_PATH);
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
am_model_path = PathAppend(model_path.at(OFFLINE_MODEL_DIR), QUANT_MODEL_NAME);
en_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), QUANT_ENCODER_NAME);
@@ -63,7 +71,7 @@
am_config_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CONFIG_NAME);
token_path = PathAppend(model_path.at(MODEL_DIR), TOKEN_PATH);
- asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, thread_num);
+ asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, token_path, online_token_path, thread_num);
}else{
LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
exit(-1);
diff --git a/runtime/websocket/bin/funasr-wss-client-2pass.cpp b/runtime/websocket/bin/funasr-wss-client-2pass.cpp
index 6c3a4dd..e8bbfc1 100644
--- a/runtime/websocket/bin/funasr-wss-client-2pass.cpp
+++ b/runtime/websocket/bin/funasr-wss-client-2pass.cpp
@@ -124,7 +124,7 @@
void run(const std::string& uri, const std::vector<string>& wav_list,
const std::vector<string>& wav_ids, int audio_fs, std::string asr_mode,
std::vector<int> chunk_size, const std::unordered_map<std::string, int>& hws_map,
- bool is_record=false, int use_itn=1) {
+ bool is_record=false, int use_itn=1, int svs_itn=1) {
// Create a new connection to the given URI
websocketpp::lib::error_code ec;
typename websocketpp::client<T>::connection_ptr con =
@@ -146,9 +146,9 @@
websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
&m_client);
if(is_record){
- send_rec_data(asr_mode, chunk_size, hws_map, use_itn);
+ send_rec_data(asr_mode, chunk_size, hws_map, use_itn, svs_itn);
}else{
- send_wav_data(wav_list[0], wav_ids[0], audio_fs, asr_mode, chunk_size, hws_map, use_itn);
+ send_wav_data(wav_list[0], wav_ids[0], audio_fs, asr_mode, chunk_size, hws_map, use_itn, svs_itn);
}
WaitABit();
@@ -185,7 +185,7 @@
// send wav to server
void send_wav_data(string wav_path, string wav_id, int audio_fs, std::string asr_mode,
std::vector<int> chunk_vector, const std::unordered_map<std::string, int>& hws_map,
- int use_itn) {
+ int use_itn, int svs_itn) {
uint64_t count = 0;
std::stringstream val;
@@ -241,8 +241,12 @@
jsonbegin["audio_fs"] = sampling_rate;
jsonbegin["is_speaking"] = true;
jsonbegin["itn"] = true;
+ jsonbegin["svs_itn"] = true;
if(use_itn == 0){
jsonbegin["itn"] = false;
+ }
+ if(svs_itn == 0){
+ jsonbegin["svs_itn"] = false;
}
if(!hws_map.empty()){
LOG(INFO) << "hotwords: ";
@@ -335,7 +339,7 @@
}
void send_rec_data(std::string asr_mode, std::vector<int> chunk_vector,
- const std::unordered_map<std::string, int>& hws_map, int use_itn) {
+ const std::unordered_map<std::string, int>& hws_map, int use_itn, int svs_itn) {
// first message
bool wait = false;
while (1) {
@@ -374,8 +378,12 @@
jsonbegin["audio_fs"] = sample_rate;
jsonbegin["is_speaking"] = true;
jsonbegin["itn"] = true;
+ jsonbegin["svs_itn"] = true;
if(use_itn == 0){
jsonbegin["itn"] = false;
+ }
+ if(svs_itn == 0){
+ jsonbegin["svs_itn"] = false;
}
if(!hws_map.empty()){
LOG(INFO) << "hotwords: ";
@@ -513,6 +521,9 @@
"", "use-itn",
"use-itn is 1 means use itn, 0 means not use itn", false, 1,
"int");
+ TCLAP::ValueArg<int> svs_itn_(
+ "", "svs-itn",
+ "svs-itn is 1 means use itn and punc, 0 means not use", false, 1, "int");
TCLAP::ValueArg<std::string> hotword_("", HOTWORD,
"the hotword file, one hotword perline, Format: Hotword Weight (could be: 闃块噷宸村反 20)", false, "", "string");
@@ -526,6 +537,7 @@
cmd.add(thread_num_);
cmd.add(is_ssl_);
cmd.add(use_itn_);
+ cmd.add(svs_itn_);
cmd.add(hotword_);
cmd.parse(argc, argv);
@@ -535,6 +547,7 @@
std::string asr_mode = asr_mode_.getValue();
std::string chunk_size_str = chunk_size_.getValue();
int use_itn = use_itn_.getValue();
+ int svs_itn = svs_itn_.getValue();
// get chunk_size
std::vector<int> chunk_size;
std::stringstream ss(chunk_size_str);
@@ -577,11 +590,11 @@
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
- c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn);
+ c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn, svs_itn);
} else {
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
- c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn);
+ c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn, svs_itn);
}
}else{
@@ -622,17 +635,17 @@
tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
client_threads.emplace_back(
- [uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, is_ssl, hws_map, use_itn]() {
+ [uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, is_ssl, hws_map, use_itn, svs_itn]() {
if (is_ssl == 1) {
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
- c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn);
+ c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn, svs_itn);
} else {
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
- c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn);
+ c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn, svs_itn);
}
});
}
diff --git a/runtime/websocket/bin/funasr-wss-server-2pass.cpp b/runtime/websocket/bin/funasr-wss-server-2pass.cpp
index 9c59254..edf614a 100644
--- a/runtime/websocket/bin/funasr-wss-server-2pass.cpp
+++ b/runtime/websocket/bin/funasr-wss-server-2pass.cpp
@@ -276,6 +276,12 @@
s_itn_path="";
s_lm_path="";
}
+ found = s_offline_asr_path.find(MODEL_SVS);
+ if (found != std::string::npos) {
+ model_path["model-revision"]="v2.0.5";
+ s_lm_path="";
+ model_path[LM_DIR]="";
+ }
if (access(s_offline_asr_path.c_str(), F_OK) == 0) {
// local
diff --git a/runtime/websocket/bin/websocket-server-2pass.cpp b/runtime/websocket/bin/websocket-server-2pass.cpp
index 8c8cab4..ff23e9d 100644
--- a/runtime/websocket/bin/websocket-server-2pass.cpp
+++ b/runtime/websocket/bin/websocket-server-2pass.cpp
@@ -111,7 +111,9 @@
int audio_fs,
std::string wav_format,
FUNASR_HANDLE& tpass_online_handle,
- FUNASR_DEC_HANDLE& decoder_handle) {
+ FUNASR_DEC_HANDLE& decoder_handle,
+ std::string svs_lang,
+ bool sys_itn) {
// lock for each connection
if(!tpass_online_handle){
scoped_lock guard(thread_lock);
@@ -140,7 +142,8 @@
subvector.data(), subvector.size(),
punc_cache, false, audio_fs,
wav_format, (ASR_TYPE)asr_mode_,
- hotwords_embedding, itn, decoder_handle);
+ hotwords_embedding, itn, decoder_handle,
+ svs_lang, sys_itn);
} else {
scoped_lock guard(thread_lock);
@@ -177,7 +180,8 @@
buffer.data(), buffer.size(), punc_cache,
is_final, audio_fs,
wav_format, (ASR_TYPE)asr_mode_,
- hotwords_embedding, itn, decoder_handle);
+ hotwords_embedding, itn, decoder_handle,
+ svs_lang, sys_itn);
} else {
scoped_lock guard(thread_lock);
msg["access_num"]=(int)msg["access_num"]-1;
@@ -250,6 +254,8 @@
data_msg->msg["audio_fs"] = 16000; // default is 16k
data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly
data_msg->msg["is_eof"]=false; // if this connection is closed
+ data_msg->msg["svs_lang"]="auto";
+ data_msg->msg["svs_itn"]=true;
FUNASR_DEC_HANDLE decoder_handle =
FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, global_beam_, lattice_beam_, am_scale_);
data_msg->decoder_handle = decoder_handle;
@@ -475,6 +481,12 @@
if (jsonresult.contains("itn")) {
msg_data->msg["itn"] = jsonresult["itn"];
}
+ if (jsonresult.contains("svs_lang")) {
+ msg_data->msg["svs_lang"] = jsonresult["svs_lang"];
+ }
+ if (jsonresult.contains("svs_itn")) {
+ msg_data->msg["svs_itn"] = jsonresult["svs_itn"];
+ }
LOG(INFO) << "jsonresult=" << jsonresult
<< ", msg_data->msg=" << msg_data->msg;
if ((jsonresult["is_speaking"] == false ||
@@ -499,7 +511,9 @@
msg_data->msg["audio_fs"],
msg_data->msg["wav_format"],
std::ref(msg_data->tpass_online_handle),
- std::ref(msg_data->decoder_handle)));
+ std::ref(msg_data->decoder_handle),
+ msg_data->msg["svs_lang"],
+ msg_data->msg["svs_itn"]));
msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
}
catch (std::exception const &e)
@@ -547,7 +561,9 @@
msg_data->msg["audio_fs"],
msg_data->msg["wav_format"],
std::ref(msg_data->tpass_online_handle),
- std::ref(msg_data->decoder_handle)));
+ std::ref(msg_data->decoder_handle),
+ msg_data->msg["svs_lang"],
+ msg_data->msg["svs_itn"]));
msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
}
}
diff --git a/runtime/websocket/bin/websocket-server-2pass.h b/runtime/websocket/bin/websocket-server-2pass.h
index 7938f88..e61a93b 100644
--- a/runtime/websocket/bin/websocket-server-2pass.h
+++ b/runtime/websocket/bin/websocket-server-2pass.h
@@ -125,7 +125,9 @@
int audio_fs,
std::string wav_format,
FUNASR_HANDLE& tpass_online_handle,
- FUNASR_DEC_HANDLE& decoder_handle);
+ FUNASR_DEC_HANDLE& decoder_handle,
+ std::string svs_lang,
+ bool sys_itn);
void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
--
Gitblit v1.9.1