From b8825902d93d5017e44828316062dc8306b7ddcd Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Tue, 26 Dec 2023 10:51:00 +0800
Subject: [PATCH] support ngram and fst hotword for 2pass-offline (#1205)

---
 runtime/onnxruntime/include/funasrruntime.h         |    2 
 runtime/onnxruntime/bin/funasr-onnx-2pass.cpp       |   32 +++++++
 runtime/onnxruntime/src/paraformer.h                |    2 
 runtime/websocket/bin/websocket-server-2pass.h      |    6 +
 runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp |    1 
 runtime/websocket/bin/websocket-server-2pass.cpp    |   22 ++++-
 runtime/onnxruntime/src/audio.cpp                   |    6 
 runtime/onnxruntime/src/paraformer.cpp              |   11 ++
 runtime/websocket/bin/funasr-wss-client-2pass.cpp   |    5 +
 runtime/onnxruntime/bin/funasr-onnx-offline.cpp     |    2 
 runtime/onnxruntime/src/tpass-stream.cpp            |   14 +++
 runtime/onnxruntime/src/funasrruntime.cpp           |   17 +++
 runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp   |   37 ++++++++-
 runtime/websocket/bin/funasr-wss-server-2pass.cpp   |   82 +++++++++++++++++++-
 14 files changed, 208 insertions(+), 31 deletions(-)

diff --git a/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
index e591b32..d49ba72 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
@@ -44,12 +44,17 @@
 }
 
 void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<string> wav_list, vector<string> wav_ids, int audio_fs,
-            float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_) {
+            float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_,
+            float glob_beam, float lat_beam, float am_scale, int inc_bias, unordered_map<string, int> hws_map) {
     
     struct timeval start, end;
     long seconds = 0;
     float n_total_length = 0.0f;
     long n_total_time = 0;
+
+    FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_scale);
+    // load hotwords list and build graph
+    FunWfstDecoderLoadHwsRes(decoder_handle, inc_bias, hws_map);
        
     std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS);
     
@@ -90,7 +95,8 @@
                 } else {
                     is_final = false;
             }
-            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
+            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, 
+                                                        sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
             if (result)
             {
                 FunASRFreeResult(result);
@@ -139,7 +145,8 @@
                     is_final = false;
             }
             gettimeofday(&start, NULL);
-            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
+            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, 
+                                                        sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
             gettimeofday(&end, NULL);
             seconds = (end.tv_sec - start.tv_sec);
             long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@@ -197,6 +204,8 @@
             *total_time = n_total_time;
         }
     }
+    FunWfstDecoderUnloadHwsRes(decoder_handle);
+    FunASRWfstDecoderUninit(decoder_handle);
     FunTpassOnlineUninit(tpass_online_handle);
 }
 
@@ -215,6 +224,11 @@
     TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
     TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
     TCLAP::ValueArg<std::string>    itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
+    TCLAP::ValueArg<std::string>    lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
+    TCLAP::ValueArg<float>    global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
+    TCLAP::ValueArg<float>    lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
+    TCLAP::ValueArg<float>    am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
+    TCLAP::ValueArg<std::int32_t>   fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t");
 
     TCLAP::ValueArg<std::string>    asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
     TCLAP::ValueArg<std::int32_t>   onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
@@ -231,6 +245,11 @@
     cmd.add(punc_dir);
     cmd.add(punc_quant);
     cmd.add(itn_dir);
+    cmd.add(lm_dir);
+    cmd.add(global_beam);
+    cmd.add(lattice_beam);
+    cmd.add(am_scale);
+    cmd.add(fst_inc_wts);
     cmd.add(wav_path);
     cmd.add(audio_fs);
     cmd.add(asr_mode);
@@ -248,6 +267,7 @@
     GetValue(punc_dir, PUNC_DIR, model_path);
     GetValue(punc_quant, PUNC_QUANT, model_path);
     GetValue(itn_dir, ITN_DIR, model_path);
+    GetValue(lm_dir, LM_DIR, model_path);
     GetValue(wav_path, WAV_PATH, model_path);
     GetValue(asr_mode, ASR_MODE, model_path);
 
@@ -271,6 +291,14 @@
     {
         LOG(ERROR) << "FunTpassInit init failed";
         exit(-1);
+    }
+    float glob_beam = 3.0f;
+    float lat_beam = 3.0f;
+    float am_sc = 10.0f;
+    if (lm_dir.isSet()) {
+        glob_beam = global_beam.getValue();
+        lat_beam = lattice_beam.getValue();
+        am_sc = am_scale.getValue();
     }
 
     gettimeofday(&end, NULL);
@@ -321,7 +349,8 @@
     int rtf_threds = thread_num_.getValue();
     for (int i = 0; i < rtf_threds; i++)
     {
-        threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_));
+        threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_,
+                                    glob_beam, lat_beam, am_sc, fst_inc_wts.getValue(), hws_map));
     }
 
     for (auto& thread : threads)
diff --git a/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp b/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp
index b210927..abcc4b2 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp
@@ -51,6 +51,11 @@
     TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
     TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
     TCLAP::ValueArg<std::string>    itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string");
+    TCLAP::ValueArg<std::string>    lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
+    TCLAP::ValueArg<float>    global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
+    TCLAP::ValueArg<float>    lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
+    TCLAP::ValueArg<float>    am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
+    TCLAP::ValueArg<std::int32_t>   fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t");
     TCLAP::ValueArg<std::string>    asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
     TCLAP::ValueArg<std::int32_t>   onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
 
@@ -65,6 +70,11 @@
     cmd.add(vad_quant);
     cmd.add(punc_dir);
     cmd.add(punc_quant);
+    cmd.add(lm_dir);
+    cmd.add(global_beam);
+    cmd.add(lattice_beam);
+    cmd.add(am_scale);
+    cmd.add(fst_inc_wts);
     cmd.add(itn_dir);
     cmd.add(wav_path);
     cmd.add(audio_fs);
@@ -81,6 +91,7 @@
     GetValue(vad_quant, VAD_QUANT, model_path);
     GetValue(punc_dir, PUNC_DIR, model_path);
     GetValue(punc_quant, PUNC_QUANT, model_path);
+    GetValue(lm_dir, LM_DIR, model_path);
     GetValue(itn_dir, ITN_DIR, model_path);
     GetValue(wav_path, WAV_PATH, model_path);
     GetValue(asr_mode, ASR_MODE, model_path);
@@ -106,6 +117,16 @@
         LOG(ERROR) << "FunTpassInit init failed";
         exit(-1);
     }
+    float glob_beam = 3.0f;
+    float lat_beam = 3.0f;
+    float am_sc = 10.0f;
+    if (lm_dir.isSet()) {
+        glob_beam = global_beam.getValue();
+        lat_beam = lattice_beam.getValue();
+        am_sc = am_scale.getValue();
+    }
+    // init wfst decoder
+    FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_sc);
 
     gettimeofday(&end, NULL);
     long seconds = (end.tv_sec - start.tv_sec);
@@ -145,6 +166,9 @@
         wav_list.emplace_back(wav_path_);
         wav_ids.emplace_back(default_id);
     }
+
+    // load hotwords list and build graph
+    FunWfstDecoderLoadHwsRes(decoder_handle, fst_inc_wts.getValue(), hws_map);
 
     std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS);
     // init online features
@@ -191,7 +215,9 @@
                     is_final = false;
             }
             gettimeofday(&start, NULL);
-            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding);
+            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, 
+                speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", 
+                (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle);
             gettimeofday(&end, NULL);
             seconds = (end.tv_sec - start.tv_sec);
             taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
@@ -235,10 +261,12 @@
             }
         }
     }
- 
+
+    FunWfstDecoderUnloadHwsRes(decoder_handle);
     LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
     LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
     LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
+    FunASRWfstDecoderUninit(decoder_handle);
     FunTpassOnlineUninit(tpass_online_handle);
     FunTpassUninit(tpass_handle);
     return 0;
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index b1a7c87..83d7e79 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -54,7 +54,6 @@
     // warm up
     for (size_t i = 0; i < 1; i++)
     {
-        FunOfflineReset(asr_handle, decoder_handle);
         FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, audio_fs, true, decoder_handle);
         if(result){
             FunASRFreeResult(result);
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
index 55eda93..4aaa002 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -50,7 +50,7 @@
     TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
     TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
     TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
-    TCLAP::ValueArg<std::string>    lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml ", false, "", "string");
+    TCLAP::ValueArg<std::string>    lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string");
     TCLAP::ValueArg<float>    global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
     TCLAP::ValueArg<float>    lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
     TCLAP::ValueArg<float>    am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
diff --git a/runtime/onnxruntime/include/funasrruntime.h b/runtime/onnxruntime/include/funasrruntime.h
index 27ee6c6..cff617f 100644
--- a/runtime/onnxruntime/include/funasrruntime.h
+++ b/runtime/onnxruntime/include/funasrruntime.h
@@ -119,7 +119,7 @@
 _FUNASRAPI FUNASR_RESULT	FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, 
 												int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true, 
 												int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS, 
-												const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true);
+												const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
 _FUNASRAPI void				FunTpassUninit(FUNASR_HANDLE handle);
 _FUNASRAPI void				FunTpassOnlineUninit(FUNASR_HANDLE handle);
 
diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp
index 559e3dd..c471329 100644
--- a/runtime/onnxruntime/src/audio.cpp
+++ b/runtime/onnxruntime/src/audio.cpp
@@ -254,9 +254,9 @@
 void Audio::WavResample(int32_t sampling_rate, const float *waveform,
                           int32_t n)
 {
-    LOG(INFO) << "Creating a resampler:\n"
-              << "   in_sample_rate: "<< sampling_rate << "\n"
-              << "   output_sample_rate: " << static_cast<int32_t>(dest_sample_rate);
+    LOG(INFO) << "Creating a resampler: "
+              << " in_sample_rate: "<< sampling_rate
+              << " output_sample_rate: " << static_cast<int32_t>(dest_sample_rate);
     float min_freq =
         std::min<int32_t>(sampling_rate, dest_sample_rate);
     float lowpass_cutoff = 0.99 * 0.5 * min_freq;
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index ccd0412..c4cb9d9 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -437,7 +437,7 @@
 	_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, 
 												 int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, 
 												 int sampling_rate, std::string wav_format, ASR_TYPE mode, 
-												 const std::vector<std::vector<float>> &hw_emb, bool itn)
+												 const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle)
 	{
 		funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
 		funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
@@ -511,7 +511,12 @@
 		// timestamp
 		std::string cur_stamp = "[";		
 		while(audio->FetchTpass(frame) > 0){
-			string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb);
+			// dec reset
+			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
+			if (wfst_decoder){
+				wfst_decoder->StartUtterance();
+			}
+			string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
 
 			std::vector<std::string> msg_vec = funasr::split(msg, '|');  // split with timestamp
 			if(msg_vec.size()==0){
@@ -762,8 +767,14 @@
 			funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
 			funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
 			if (paraformer->lm_)
+				mm = new funasr::WfstDecoder(paraformer->lm_.get(),
+					paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+		} else if (asr_type == ASR_TWO_PASS){
+			funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+			funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
+			if (paraformer->lm_)
 				mm = new funasr::WfstDecoder(paraformer->lm_.get(), 
-					paraformer->GetPhoneSet(), paraformer->GetVocab(), glob_beam, lat_beam, am_scale);
+					paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
 		}
 		return mm;
 	}
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index bb15ac7..c56421c 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -193,8 +193,7 @@
         lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
             fst::Fst<fst::StdArc>::Read(lm_file));
         if (lm_){
-            if (vocab) { delete vocab; }
-            vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
+            lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
             LOG(INFO) << "Successfully load lm file " << lm_file;
         }else{
             LOG(ERROR) << "Failed to load lm file " << lm_file;
@@ -309,6 +308,9 @@
 {
     if(vocab){
         delete vocab;
+    }
+    if(lm_vocab){
+        delete lm_vocab;
     }
     if(seg_dict){
         delete seg_dict;
@@ -687,6 +689,11 @@
     return vocab;
 }
 
+Vocab* Paraformer::GetLmVocab()
+{
+    return lm_vocab;  // may be nullptr: the member defaults to nullptr and is only assigned after an LM file loads successfully
+}
+
 PhoneSet* Paraformer::GetPhoneSet()
 {
     return phone_set_;
diff --git a/runtime/onnxruntime/src/paraformer.h b/runtime/onnxruntime/src/paraformer.h
index de05657..5bb9477 100644
--- a/runtime/onnxruntime/src/paraformer.h
+++ b/runtime/onnxruntime/src/paraformer.h
@@ -20,6 +20,7 @@
     */
     private:
         Vocab* vocab = nullptr;
+        Vocab* lm_vocab = nullptr;
         SegDict* seg_dict = nullptr;
         PhoneSet* phone_set_ = nullptr;
         //const float scale = 22.6274169979695;
@@ -65,6 +66,7 @@
         string FinalizeDecode(WfstDecoder* &wfst_decoder,
                           bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
         Vocab* GetVocab();
+        Vocab* GetLmVocab();
         PhoneSet* GetPhoneSet();
 		
         knf::FbankOptions fbank_opts_;
diff --git a/runtime/onnxruntime/src/tpass-stream.cpp b/runtime/onnxruntime/src/tpass-stream.cpp
index a3e1b0e..b723e0f 100644
--- a/runtime/onnxruntime/src/tpass-stream.cpp
+++ b/runtime/onnxruntime/src/tpass-stream.cpp
@@ -66,6 +66,20 @@
         LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
         exit(-1);
     }
+    
+    // Lm resource
+    if (model_path.find(LM_DIR) != model_path.end() && model_path.at(LM_DIR) != "") {
+        string fst_path, lm_config_path, lex_path;
+        fst_path = PathAppend(model_path.at(LM_DIR), LM_FST_RES);
+        lm_config_path = PathAppend(model_path.at(LM_DIR), LM_CONFIG_NAME);
+        lex_path = PathAppend(model_path.at(LM_DIR), LEX_PATH);
+        if (access(lex_path.c_str(), F_OK) != 0 )
+        {
+            LOG(ERROR) << "Lexicon.txt file is not exist, please use the latest version. Skip load LM model.";
+        }else{
+            asr_handle->InitLm(fst_path, lm_config_path, lex_path);
+        }
+    }
 
     // PUNC model
     if(model_path.find(PUNC_DIR) != model_path.end()){
diff --git a/runtime/websocket/bin/funasr-wss-client-2pass.cpp b/runtime/websocket/bin/funasr-wss-client-2pass.cpp
index 6533dd5..0cbd10e 100644
--- a/runtime/websocket/bin/funasr-wss-client-2pass.cpp
+++ b/runtime/websocket/bin/funasr-wss-client-2pass.cpp
@@ -192,7 +192,10 @@
     funasr::Audio audio(1);
     int32_t sampling_rate = audio_fs;
     std::string wav_format = "pcm";
-    if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) {
+    if (funasr::IsTargetFile(wav_path.c_str(), "wav")) {
+      if (!audio.LoadWav(wav_path.c_str(), &sampling_rate, false)) 
+        return;
+    } else if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) {
       if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate, false)) return;
     } else {
       wav_format = "others";
diff --git a/runtime/websocket/bin/funasr-wss-server-2pass.cpp b/runtime/websocket/bin/funasr-wss-server-2pass.cpp
index 965f2a8..ef27d5b 100644
--- a/runtime/websocket/bin/funasr-wss-server-2pass.cpp
+++ b/runtime/websocket/bin/funasr-wss-server-2pass.cpp
@@ -16,6 +16,7 @@
 // hotwords
 std::unordered_map<std::string, int> hws_map_;
 int fst_inc_wts_=20;
+float global_beam_, lattice_beam_, am_scale_;
 
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
@@ -120,6 +121,14 @@
         "connection",
         false, "../../../ssl_key/server.key", "string");
 
+    TCLAP::ValueArg<float>    global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float");
+    TCLAP::ValueArg<float>    lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float");
+    TCLAP::ValueArg<float>    am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float");
+
+    TCLAP::ValueArg<std::string> lm_dir("", LM_DIR,  // LM directory option; the 2pass stream loader skips the LM when lexicon.txt is missing from it
+        "the LM model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "damo/speech_ngram_lm_zh-cn-ai-wesp-fst", "string");
+    TCLAP::ValueArg<std::string> lm_revision(
+        "", "lm-revision", "LM model revision", false, "v1.0.2", "string");
     TCLAP::ValueArg<std::string> hotword("", HOTWORD,
         "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", 
         false, "/workspace/resources/hotwords.txt", "string");
@@ -128,6 +137,10 @@
 
     // add file
     cmd.add(hotword);
+    cmd.add(fst_inc_wts);
+    cmd.add(global_beam);
+    cmd.add(lattice_beam);
+    cmd.add(am_scale);
 
     cmd.add(certfile);
     cmd.add(keyfile);
@@ -146,6 +159,8 @@
     cmd.add(punc_quant);
     cmd.add(itn_dir);
     cmd.add(itn_revision);
+    cmd.add(lm_dir);
+    cmd.add(lm_revision);
 
     cmd.add(listen_ip);
     cmd.add(port);
@@ -163,6 +178,7 @@
     GetValue(punc_dir, PUNC_DIR, model_path);
     GetValue(punc_quant, PUNC_QUANT, model_path);
     GetValue(itn_dir, ITN_DIR, model_path);
+    GetValue(lm_dir, LM_DIR, model_path);
     GetValue(hotword, HOTWORD, model_path);
 
     GetValue(offline_model_revision, "offline-model-revision", model_path);
@@ -170,6 +186,11 @@
     GetValue(vad_revision, "vad-revision", model_path);
     GetValue(punc_revision, "punc-revision", model_path);
     GetValue(itn_revision, "itn-revision", model_path);
+    GetValue(lm_revision, "lm-revision", model_path);
+
+    global_beam_ = global_beam.getValue();
+    lattice_beam_ = lattice_beam.getValue();
+    am_scale_ = am_scale.getValue();
 
     // Download model form Modelscope
     try {
@@ -183,6 +204,7 @@
       std::string s_punc_path = model_path[PUNC_DIR];
       std::string s_punc_quant = model_path[PUNC_QUANT];
       std::string s_itn_path = model_path[ITN_DIR];
+      std::string s_lm_path = model_path[LM_DIR];
 
       std::string python_cmd =
           "python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True ";
@@ -241,11 +263,18 @@
         size_t found = s_offline_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
         if (found != std::string::npos) {
             model_path["offline-model-revision"]="v1.2.4";
-        } else{
-            found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
-            if (found != std::string::npos) {
-                model_path["offline-model-revision"]="v1.0.5";
-            }
+        }
+
+        found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
+        if (found != std::string::npos) {
+            model_path["offline-model-revision"]="v1.0.5";
+        }
+
+        found = s_offline_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");  // English offline model
+        if (found != std::string::npos) {
+            model_path["offline-model-revision"]="v1.0.0";  // same key as the sibling branches above; NOTE(review): "model-revision" is not a key registered in the visible code
+            s_itn_path="";  // cleared for the English model; presumably skips ITN setup - confirm against the ITN section
+            s_lm_path="";   // an empty LM path takes the "LM model is not set" branch, leaving LM_DIR empty
         }
 
         if (access(s_offline_asr_path.c_str(), F_OK) == 0) {
@@ -332,6 +361,49 @@
         LOG(INFO) << "ASR online model is not set, use default.";
       }
 
+      if (!s_lm_path.empty() && s_lm_path != "NONE" && s_lm_path != "none") {
+          std::string python_cmd_lm;
+          std::string down_lm_path;
+          std::string down_lm_model;
+
+          if (access(s_lm_path.c_str(), F_OK) == 0) {
+              // local
+              python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
+                                  " --export-dir ./ " + " --model_revision " +
+                                  model_path["lm-revision"] + " --export False ";
+              down_lm_path = s_lm_path;
+          } else {
+              // modelscope
+              LOG(INFO) << "Download model: " << s_lm_path
+                          << " from modelscope : "; 
+              python_cmd_lm = python_cmd + " --model-name " +
+                      s_lm_path +
+                      " --export-dir " + s_download_model_dir +
+                      " --model_revision " + model_path["lm-revision"]
+                      + " --export False "; 
+              down_lm_path  =
+                      s_download_model_dir +
+                      "/" + s_lm_path;
+          }
+
+          int ret = system(python_cmd_lm.c_str());
+          if (ret != 0) {
+              LOG(INFO) << "Failed to download model from modelscope. If you set local lm model path, you can ignore the errors.";
+          }
+          down_lm_model = down_lm_path + "/TLG.fst";
+
+          if (access(down_lm_model.c_str(), F_OK) != 0) {
+              LOG(ERROR) << down_lm_model << " do not exists.";
+              exit(-1);
+          } else {
+              model_path[LM_DIR] = down_lm_path;
+              LOG(INFO) << "Set " << LM_DIR << " : " << model_path[LM_DIR];
+          }
+      } else {
+          LOG(INFO) << "LM model is not set, not executed.";
+          model_path[LM_DIR] = "";
+      }
+
       if (!s_punc_path.empty()) {
         std::string python_cmd_punc;
         std::string down_punc_path;
diff --git a/runtime/websocket/bin/websocket-server-2pass.cpp b/runtime/websocket/bin/websocket-server-2pass.cpp
index 44dd82e..0269e5f 100644
--- a/runtime/websocket/bin/websocket-server-2pass.cpp
+++ b/runtime/websocket/bin/websocket-server-2pass.cpp
@@ -18,6 +18,7 @@
 
 extern std::unordered_map<std::string, int> hws_map_;
 extern int fst_inc_wts_;
+extern float global_beam_, lattice_beam_, am_scale_;
 
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
@@ -102,7 +103,8 @@
     bool itn,
     int audio_fs,
     std::string wav_format,
-    FUNASR_HANDLE& tpass_online_handle) {
+    FUNASR_HANDLE& tpass_online_handle,
+    FUNASR_DEC_HANDLE& decoder_handle) {
   // lock for each connection
   if(!tpass_online_handle){
     scoped_lock guard(thread_lock);
@@ -131,7 +133,7 @@
                                        subvector.data(), subvector.size(),
                                        punc_cache, false, audio_fs,
                                        wav_format, (ASR_TYPE)asr_mode_,
-                                       hotwords_embedding, itn);
+                                       hotwords_embedding, itn, decoder_handle);
 
         } else {
           scoped_lock guard(thread_lock);
@@ -168,7 +170,7 @@
                                        buffer.data(), buffer.size(), punc_cache,
                                        is_final, audio_fs,
                                        wav_format, (ASR_TYPE)asr_mode_,
-                                       hotwords_embedding, itn);
+                                       hotwords_embedding, itn, decoder_handle);
         } else {
           scoped_lock guard(thread_lock);
           msg["access_num"]=(int)msg["access_num"]-1;	 
@@ -241,6 +243,9 @@
     data_msg->msg["audio_fs"] = 16000; // default is 16k
     data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly
     data_msg->msg["is_eof"]=false; // if this connection is closed
+    FUNASR_DEC_HANDLE decoder_handle =
+      FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, global_beam_, lattice_beam_, am_scale_);
+    data_msg->decoder_handle = decoder_handle;
     data_msg->punc_cache =
         std::make_shared<std::vector<std::vector<std::string>>>(2);
   	data_msg->strand_ =	std::make_shared<asio::io_context::strand>(io_decoder_);
@@ -267,6 +272,9 @@
   // finished and avoid access freed tpass_online_handle
   unique_lock guard_decoder(*(data_msg->thread_lock));
   if (data_msg->msg["access_num"]==0 && data_msg->msg["is_eof"]==true) {
+    FunWfstDecoderUnloadHwsRes(data_msg->decoder_handle);
+    FunASRWfstDecoderUninit(data_msg->decoder_handle);
+    data_msg->decoder_handle = nullptr;
     FunTpassOnlineUninit(data_msg->tpass_online_handle);
     data_msg->tpass_online_handle = nullptr;
 	  data_map.erase(hdl);
@@ -431,7 +439,7 @@
             nn_hotwords += " " + pair.first;
             LOG(INFO) << pair.first << " : " << pair.second;
         }
-        // FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map);
+        FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map);
 
         // nn
         std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords, ASR_TWO_PASS);
@@ -483,7 +491,8 @@
                         msg_data->msg["itn"],
                         msg_data->msg["audio_fs"],
                         msg_data->msg["wav_format"],
-                        std::ref(msg_data->tpass_online_handle)));
+                        std::ref(msg_data->tpass_online_handle),
+                        std::ref(msg_data->decoder_handle)));
 		      msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
         }
         catch (std::exception const &e)
@@ -530,7 +539,8 @@
                                   msg_data->msg["itn"],
                                   msg_data->msg["audio_fs"],
                                   msg_data->msg["wav_format"],
-                                  std::ref(msg_data->tpass_online_handle)));
+                                  std::ref(msg_data->tpass_online_handle),
+                                  std::ref(msg_data->decoder_handle)));
               msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
             }
           }
diff --git a/runtime/websocket/bin/websocket-server-2pass.h b/runtime/websocket/bin/websocket-server-2pass.h
index 3e78a34..6b2ba32 100644
--- a/runtime/websocket/bin/websocket-server-2pass.h
+++ b/runtime/websocket/bin/websocket-server-2pass.h
@@ -60,7 +60,8 @@
   FUNASR_HANDLE tpass_online_handle=NULL;
   std::string online_res = "";
   std::string tpass_res = "";
-  std::shared_ptr<asio::io_context::strand>  strand_; // for data execute in order 
+  std::shared_ptr<asio::io_context::strand>  strand_; // for data execute in order
+  FUNASR_DEC_HANDLE decoder_handle=NULL; 
 } FUNASR_MESSAGE;
 
 // See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
@@ -123,7 +124,8 @@
                   bool itn,
                   int audio_fs,
                   std::string wav_format,
-                  FUNASR_HANDLE& tpass_online_handle);
+                  FUNASR_HANDLE& tpass_online_handle,
+                  FUNASR_DEC_HANDLE& decoder_handle);
 
   void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
   void on_message(websocketpp::connection_hdl hdl, message_ptr msg);

--
Gitblit v1.9.1