From b8825902d93d5017e44828316062dc8306b7ddcd Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期二, 26 十二月 2023 10:51:00 +0800
Subject: [PATCH] support ngram and fst hotword for 2pass-offline (#1205)
---
runtime/onnxruntime/src/funasrruntime.cpp | 17 ++++++++++++++---
1 files changed, 14 insertions(+), 3 deletions(-)
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index ccd0412..c4cb9d9 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -437,7 +437,7 @@
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished,
int sampling_rate, std::string wav_format, ASR_TYPE mode,
- const std::vector<std::vector<float>> &hw_emb, bool itn)
+ const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle)
{
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
@@ -511,7 +511,12 @@
// timestamp
std::string cur_stamp = "[";
while(audio->FetchTpass(frame) > 0){
- string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb);
+ // dec reset
+ funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
+ if (wfst_decoder){
+ wfst_decoder->StartUtterance();
+ }
+ string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
std::vector<std::string> msg_vec = funasr::split(msg, '|'); // split with timestamp
if(msg_vec.size()==0){
@@ -762,8 +767,14 @@
funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
if (paraformer->lm_)
+ mm = new funasr::WfstDecoder(paraformer->lm_.get(),
+ paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+ } else if (asr_type == ASR_TWO_PASS){
+ funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+ funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
+ if (paraformer->lm_)
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
- paraformer->GetPhoneSet(), paraformer->GetVocab(), glob_beam, lat_beam, am_scale);
+ paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
--
Gitblit v1.9.1