From 8912e0696af069de47646fdb8a9d9c4e086e88b3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 14 一月 2024 23:42:11 +0800
Subject: [PATCH] Resolve merge conflict
---
runtime/onnxruntime/src/funasrruntime.cpp | 19 ++++++++++++++++---
1 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index ccd0412..fdaf69d 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -146,6 +146,7 @@
funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
p_result->snippet_time = audio.GetTimeLen();
if(p_result->snippet_time == 0){
+ p_result->segments = new vector<std::vector<int>>();
return p_result;
}
@@ -178,6 +179,7 @@
funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
p_result->snippet_time = audio.GetTimeLen();
if(p_result->snippet_time == 0){
+ p_result->segments = new vector<std::vector<int>>();
return p_result;
}
@@ -437,7 +439,7 @@
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished,
int sampling_rate, std::string wav_format, ASR_TYPE mode,
- const std::vector<std::vector<float>> &hw_emb, bool itn)
+ const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle)
{
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
@@ -511,7 +513,12 @@
// timestamp
std::string cur_stamp = "[";
while(audio->FetchTpass(frame) > 0){
- string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb);
+ // dec reset
+ funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
+ if (wfst_decoder){
+ wfst_decoder->StartUtterance();
+ }
+ string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
std::vector<std::string> msg_vec = funasr::split(msg, '|'); // split with timestamp
if(msg_vec.size()==0){
@@ -762,8 +769,14 @@
funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
if (paraformer->lm_)
+ mm = new funasr::WfstDecoder(paraformer->lm_.get(),
+ paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+ } else if (asr_type == ASR_TWO_PASS){
+ funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+ funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
+ if (paraformer->lm_)
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
- paraformer->GetPhoneSet(), paraformer->GetVocab(), glob_beam, lat_beam, am_scale);
+ paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
--
Gitblit v1.9.1