support ngram and fst hotword for 2pass-offline (#1205)
| | |
| | | } |
| | | |
| | | void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<string> wav_list, vector<string> wav_ids, int audio_fs, |
| | | float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_) { |
| | | float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_, string nn_hotwords_, |
| | | float glob_beam, float lat_beam, float am_scale, int inc_bias, unordered_map<string, int> hws_map) { |
| | | |
| | | struct timeval start, end; |
| | | long seconds = 0; |
| | | float n_total_length = 0.0f; |
| | | long n_total_time = 0; |
| | | |
| | | FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_scale); |
| | | // load hotwords list and build graph |
| | | FunWfstDecoderLoadHwsRes(decoder_handle, inc_bias, hws_map); |
| | | |
| | | std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS); |
| | | |
| | |
| | | } else { |
| | | is_final = false; |
| | | } |
| | | FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding); |
| | | FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, |
| | | sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle); |
| | | if (result) |
| | | { |
| | | FunASRFreeResult(result); |
| | |
| | | is_final = false; |
| | | } |
| | | gettimeofday(&start, NULL); |
| | | FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding); |
| | | FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, |
| | | sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle); |
| | | gettimeofday(&end, NULL); |
| | | seconds = (end.tv_sec - start.tv_sec); |
| | | long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); |
| | |
| | | *total_time = n_total_time; |
| | | } |
| | | } |
| | | FunWfstDecoderUnloadHwsRes(decoder_handle); |
| | | FunASRWfstDecoderUninit(decoder_handle); |
| | | FunTpassOnlineUninit(tpass_online_handle); |
| | | } |
| | | |
| | |
| | | TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string"); |
| | | TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float"); |
| | | TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float"); |
| | | TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float"); |
| | | TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t"); |
| | | |
| | | TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string"); |
| | | TCLAP::ValueArg<std::int32_t> onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t"); |
| | |
| | | cmd.add(punc_dir); |
| | | cmd.add(punc_quant); |
| | | cmd.add(itn_dir); |
| | | cmd.add(lm_dir); |
| | | cmd.add(global_beam); |
| | | cmd.add(lattice_beam); |
| | | cmd.add(am_scale); |
| | | cmd.add(fst_inc_wts); |
| | | cmd.add(wav_path); |
| | | cmd.add(audio_fs); |
| | | cmd.add(asr_mode); |
| | |
| | | GetValue(punc_dir, PUNC_DIR, model_path); |
| | | GetValue(punc_quant, PUNC_QUANT, model_path); |
| | | GetValue(itn_dir, ITN_DIR, model_path); |
| | | GetValue(lm_dir, LM_DIR, model_path); |
| | | GetValue(wav_path, WAV_PATH, model_path); |
| | | GetValue(asr_mode, ASR_MODE, model_path); |
| | | |
| | |
| | | { |
| | | LOG(ERROR) << "FunTpassInit init failed"; |
| | | exit(-1); |
| | | } |
| | | float glob_beam = 3.0f; |
| | | float lat_beam = 3.0f; |
| | | float am_sc = 10.0f; |
| | | if (lm_dir.isSet()) { |
| | | glob_beam = global_beam.getValue(); |
| | | lat_beam = lattice_beam.getValue(); |
| | | am_sc = am_scale.getValue(); |
| | | } |
| | | |
| | | gettimeofday(&end, NULL); |
| | |
| | | int rtf_threds = thread_num_.getValue(); |
| | | for (int i = 0; i < rtf_threds; i++) |
| | | { |
| | | threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_)); |
| | | threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, audio_fs.getValue(), &total_length, &total_time, i, (ASR_TYPE)asr_mode_, nn_hotwords_, |
| | | glob_beam, lat_beam, am_sc, fst_inc_wts.getValue(), hws_map)); |
| | | } |
| | | |
| | | for (auto& thread : threads) |
| | |
| | | TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> itn_dir("", ITN_DIR, "the itn model(fst) path, which contains zh_itn_tagger.fst and zh_itn_verbalizer.fst", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string"); |
| | | TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float"); |
| | | TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float"); |
| | | TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float"); |
| | | TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS, "the fst hotwords incremental bias", false, 20, "int32_t"); |
| | | TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string"); |
| | | TCLAP::ValueArg<std::int32_t> onnx_thread("", "model-thread-num", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t"); |
| | | |
| | |
| | | cmd.add(vad_quant); |
| | | cmd.add(punc_dir); |
| | | cmd.add(punc_quant); |
| | | cmd.add(lm_dir); |
| | | cmd.add(global_beam); |
| | | cmd.add(lattice_beam); |
| | | cmd.add(am_scale); |
| | | cmd.add(fst_inc_wts); |
| | | cmd.add(itn_dir); |
| | | cmd.add(wav_path); |
| | | cmd.add(audio_fs); |
| | |
| | | GetValue(vad_quant, VAD_QUANT, model_path); |
| | | GetValue(punc_dir, PUNC_DIR, model_path); |
| | | GetValue(punc_quant, PUNC_QUANT, model_path); |
| | | GetValue(lm_dir, LM_DIR, model_path); |
| | | GetValue(itn_dir, ITN_DIR, model_path); |
| | | GetValue(wav_path, WAV_PATH, model_path); |
| | | GetValue(asr_mode, ASR_MODE, model_path); |
| | |
| | | LOG(ERROR) << "FunTpassInit init failed"; |
| | | exit(-1); |
| | | } |
| | | float glob_beam = 3.0f; |
| | | float lat_beam = 3.0f; |
| | | float am_sc = 10.0f; |
| | | if (lm_dir.isSet()) { |
| | | glob_beam = global_beam.getValue(); |
| | | lat_beam = lattice_beam.getValue(); |
| | | am_sc = am_scale.getValue(); |
| | | } |
| | | // init wfst decoder |
| | | FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, glob_beam, lat_beam, am_sc); |
| | | |
| | | gettimeofday(&end, NULL); |
| | | long seconds = (end.tv_sec - start.tv_sec); |
| | |
| | | wav_list.emplace_back(wav_path_); |
| | | wav_ids.emplace_back(default_id); |
| | | } |
| | | |
| | | // load hotwords list and build graph |
| | | FunWfstDecoderLoadHwsRes(decoder_handle, fst_inc_wts.getValue(), hws_map); |
| | | |
| | | std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords_, ASR_TWO_PASS); |
| | | // init online features |
| | |
| | | is_final = false; |
| | | } |
| | | gettimeofday(&start, NULL); |
| | | FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_, hotwords_embedding); |
| | | FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, |
| | | speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", |
| | | (ASR_TYPE)asr_mode_, hotwords_embedding, true, decoder_handle); |
| | | gettimeofday(&end, NULL); |
| | | seconds = (end.tv_sec - start.tv_sec); |
| | | taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); |
| | |
| | | } |
| | | } |
| | | } |
| | | |
| | | |
| | | FunWfstDecoderUnloadHwsRes(decoder_handle); |
| | | LOG(INFO) << "Audio length: " << (double)snippet_time << " s"; |
| | | LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s"; |
| | | LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000); |
| | | FunASRWfstDecoderUninit(decoder_handle); |
| | | FunTpassOnlineUninit(tpass_online_handle); |
| | | FunTpassUninit(tpass_handle); |
| | | return 0; |
| | |
| | | // warm up |
| | | for (size_t i = 0; i < 1; i++) |
| | | { |
| | | FunOfflineReset(asr_handle, decoder_handle); |
| | | FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, audio_fs, true, decoder_handle); |
| | | if(result){ |
| | | FunASRFreeResult(result); |
| | |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml ", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, "the lm model path, which contains compiled models: TLG.fst, config.yaml, lexicon.txt ", false, "", "string"); |
| | | TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float"); |
| | | TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float"); |
| | | TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float"); |
| | |
| | | _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, |
| | | int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true, |
| | | int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS, |
| | | const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true); |
| | | const std::vector<std::vector<float>> &hw_emb={{0.0}}, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr); |
| | | _FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle); |
| | | _FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle); |
| | | |
| | |
| | | void Audio::WavResample(int32_t sampling_rate, const float *waveform, |
| | | int32_t n) |
| | | { |
| | | LOG(INFO) << "Creating a resampler:\n" |
| | | << " in_sample_rate: "<< sampling_rate << "\n" |
| | | << " output_sample_rate: " << static_cast<int32_t>(dest_sample_rate); |
| | | LOG(INFO) << "Creating a resampler: " |
| | | << " in_sample_rate: "<< sampling_rate |
| | | << " output_sample_rate: " << static_cast<int32_t>(dest_sample_rate); |
| | | float min_freq = |
| | | std::min<int32_t>(sampling_rate, dest_sample_rate); |
| | | float lowpass_cutoff = 0.99 * 0.5 * min_freq; |
| | |
| | | _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, |
| | | int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, |
| | | int sampling_rate, std::string wav_format, ASR_TYPE mode, |
| | | const std::vector<std::vector<float>> &hw_emb, bool itn) |
| | | const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle) |
| | | { |
| | | funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle; |
| | | funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle; |
| | |
| | | // timestamp |
| | | std::string cur_stamp = "["; |
| | | while(audio->FetchTpass(frame) > 0){ |
| | | string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb); |
| | | // dec reset |
| | | funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle; |
| | | if (wfst_decoder){ |
| | | wfst_decoder->StartUtterance(); |
| | | } |
| | | string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle); |
| | | |
| | | std::vector<std::string> msg_vec = funasr::split(msg, '|'); // split with timestamp |
| | | if(msg_vec.size()==0){ |
| | |
| | | funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle; |
| | | funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get(); |
| | | if (paraformer->lm_) |
| | | mm = new funasr::WfstDecoder(paraformer->lm_.get(), |
| | | paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale); |
| | | } else if (asr_type == ASR_TWO_PASS){ |
| | | funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle; |
| | | funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get(); |
| | | if (paraformer->lm_) |
| | | mm = new funasr::WfstDecoder(paraformer->lm_.get(), |
| | | paraformer->GetPhoneSet(), paraformer->GetVocab(), glob_beam, lat_beam, am_scale); |
| | | paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale); |
| | | } |
| | | return mm; |
| | | } |
| | |
| | | lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>( |
| | | fst::Fst<fst::StdArc>::Read(lm_file)); |
| | | if (lm_){ |
| | | if (vocab) { delete vocab; } |
| | | vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str()); |
| | | lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str()); |
| | | LOG(INFO) << "Successfully load lm file " << lm_file; |
| | | }else{ |
| | | LOG(ERROR) << "Failed to load lm file " << lm_file; |
| | |
| | | { |
| | | if(vocab){ |
| | | delete vocab; |
| | | } |
| | | if(lm_vocab){ |
| | | delete lm_vocab; |
| | | } |
| | | if(seg_dict){ |
| | | delete seg_dict; |
| | |
| | | return vocab; |
| | | } |
| | | |
| | | Vocab* Paraformer::GetLmVocab() |
| | | { |
| | | return lm_vocab; |
| | | } |
| | | |
| | | PhoneSet* Paraformer::GetPhoneSet() |
| | | { |
| | | return phone_set_; |
| | |
| | | */ |
| | | private: |
| | | Vocab* vocab = nullptr; |
| | | Vocab* lm_vocab = nullptr; |
| | | SegDict* seg_dict = nullptr; |
| | | PhoneSet* phone_set_ = nullptr; |
| | | //const float scale = 22.6274169979695; |
| | |
| | | string FinalizeDecode(WfstDecoder* &wfst_decoder, |
| | | bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0}); |
| | | Vocab* GetVocab(); |
| | | Vocab* GetLmVocab(); |
| | | PhoneSet* GetPhoneSet(); |
| | | |
| | | knf::FbankOptions fbank_opts_; |
| | |
| | | LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir"; |
| | | exit(-1); |
| | | } |
| | | |
| | | // Lm resource |
| | | if (model_path.find(LM_DIR) != model_path.end() && model_path.at(LM_DIR) != "") { |
| | | string fst_path, lm_config_path, lex_path; |
| | | fst_path = PathAppend(model_path.at(LM_DIR), LM_FST_RES); |
| | | lm_config_path = PathAppend(model_path.at(LM_DIR), LM_CONFIG_NAME); |
| | | lex_path = PathAppend(model_path.at(LM_DIR), LEX_PATH); |
| | | if (access(lex_path.c_str(), F_OK) != 0 ) |
| | | { |
| | | LOG(ERROR) << "Lexicon.txt file is not exist, please use the latest version. Skip load LM model."; |
| | | }else{ |
| | | asr_handle->InitLm(fst_path, lm_config_path, lex_path); |
| | | } |
| | | } |
| | | |
| | | // PUNC model |
| | | if(model_path.find(PUNC_DIR) != model_path.end()){ |
| | |
| | | funasr::Audio audio(1); |
| | | int32_t sampling_rate = audio_fs; |
| | | std::string wav_format = "pcm"; |
| | | if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) { |
| | | if (funasr::IsTargetFile(wav_path.c_str(), "wav")) { |
| | | if (!audio.LoadWav(wav_path.c_str(), &sampling_rate, false)) |
| | | return; |
| | | } else if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) { |
| | | if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate, false)) return; |
| | | } else { |
| | | wav_format = "others"; |
| | |
| | | // hotwords |
| | | std::unordered_map<std::string, int> hws_map_; |
| | | int fst_inc_wts_=20; |
| | | float global_beam_, lattice_beam_, am_scale_; |
| | | |
| | | using namespace std; |
| | | void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, |
| | |
| | | "connection", |
| | | false, "../../../ssl_key/server.key", "string"); |
| | | |
| | | TCLAP::ValueArg<float> global_beam("", GLOB_BEAM, "the decoding beam for beam searching ", false, 3.0, "float"); |
| | | TCLAP::ValueArg<float> lattice_beam("", LAT_BEAM, "the lattice generation beam for beam searching ", false, 3.0, "float"); |
| | | TCLAP::ValueArg<float> am_scale("", AM_SCALE, "the acoustic scale for beam searching ", false, 10.0, "float"); |
| | | |
| | | TCLAP::ValueArg<std::string> lm_dir("", LM_DIR, |
| | | "the LM model path, which contains compiled models: TLG.fst, config.yaml ", false, "damo/speech_ngram_lm_zh-cn-ai-wesp-fst", "string"); |
| | | TCLAP::ValueArg<std::string> lm_revision( |
| | | "", "lm-revision", "LM model revision", false, "v1.0.2", "string"); |
| | | TCLAP::ValueArg<std::string> hotword("", HOTWORD, |
| | | "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", |
| | | false, "/workspace/resources/hotwords.txt", "string"); |
| | |
| | | |
| | | // add file |
| | | cmd.add(hotword); |
| | | cmd.add(fst_inc_wts); |
| | | cmd.add(global_beam); |
| | | cmd.add(lattice_beam); |
| | | cmd.add(am_scale); |
| | | |
| | | cmd.add(certfile); |
| | | cmd.add(keyfile); |
| | |
| | | cmd.add(punc_quant); |
| | | cmd.add(itn_dir); |
| | | cmd.add(itn_revision); |
| | | cmd.add(lm_dir); |
| | | cmd.add(lm_revision); |
| | | |
| | | cmd.add(listen_ip); |
| | | cmd.add(port); |
| | |
| | | GetValue(punc_dir, PUNC_DIR, model_path); |
| | | GetValue(punc_quant, PUNC_QUANT, model_path); |
| | | GetValue(itn_dir, ITN_DIR, model_path); |
| | | GetValue(lm_dir, LM_DIR, model_path); |
| | | GetValue(hotword, HOTWORD, model_path); |
| | | |
| | | GetValue(offline_model_revision, "offline-model-revision", model_path); |
| | |
| | | GetValue(vad_revision, "vad-revision", model_path); |
| | | GetValue(punc_revision, "punc-revision", model_path); |
| | | GetValue(itn_revision, "itn-revision", model_path); |
| | | GetValue(lm_revision, "lm-revision", model_path); |
| | | |
| | | global_beam_ = global_beam.getValue(); |
| | | lattice_beam_ = lattice_beam.getValue(); |
| | | am_scale_ = am_scale.getValue(); |
| | | |
| | | // Download model form Modelscope |
| | | try { |
| | |
| | | std::string s_punc_path = model_path[PUNC_DIR]; |
| | | std::string s_punc_quant = model_path[PUNC_QUANT]; |
| | | std::string s_itn_path = model_path[ITN_DIR]; |
| | | std::string s_lm_path = model_path[LM_DIR]; |
| | | |
| | | std::string python_cmd = |
| | | "python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True "; |
| | |
| | | size_t found = s_offline_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404"); |
| | | if (found != std::string::npos) { |
| | | model_path["offline-model-revision"]="v1.2.4"; |
| | | } else{ |
| | | found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"); |
| | | if (found != std::string::npos) { |
| | | model_path["offline-model-revision"]="v1.0.5"; |
| | | } |
| | | } |
| | | |
| | | found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"); |
| | | if (found != std::string::npos) { |
| | | model_path["offline-model-revision"]="v1.0.5"; |
| | | } |
| | | |
| | | found = s_offline_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020"); |
| | | if (found != std::string::npos) { |
| | | model_path["model-revision"]="v1.0.0"; |
| | | s_itn_path=""; |
| | | s_lm_path=""; |
| | | } |
| | | |
| | | if (access(s_offline_asr_path.c_str(), F_OK) == 0) { |
| | |
| | | LOG(INFO) << "ASR online model is not set, use default."; |
| | | } |
| | | |
| | | if (!s_lm_path.empty() && s_lm_path != "NONE" && s_lm_path != "none") { |
| | | std::string python_cmd_lm; |
| | | std::string down_lm_path; |
| | | std::string down_lm_model; |
| | | |
| | | if (access(s_lm_path.c_str(), F_OK) == 0) { |
| | | // local |
| | | python_cmd_lm = python_cmd + " --model-name " + s_lm_path + |
| | | " --export-dir ./ " + " --model_revision " + |
| | | model_path["lm-revision"] + " --export False "; |
| | | down_lm_path = s_lm_path; |
| | | } else { |
| | | // modelscope |
| | | LOG(INFO) << "Download model: " << s_lm_path |
| | | << " from modelscope : "; |
| | | python_cmd_lm = python_cmd + " --model-name " + |
| | | s_lm_path + |
| | | " --export-dir " + s_download_model_dir + |
| | | " --model_revision " + model_path["lm-revision"] |
| | | + " --export False "; |
| | | down_lm_path = |
| | | s_download_model_dir + |
| | | "/" + s_lm_path; |
| | | } |
| | | |
| | | int ret = system(python_cmd_lm.c_str()); |
| | | if (ret != 0) { |
| | | LOG(INFO) << "Failed to download model from modelscope. If you set local lm model path, you can ignore the errors."; |
| | | } |
| | | down_lm_model = down_lm_path + "/TLG.fst"; |
| | | |
| | | if (access(down_lm_model.c_str(), F_OK) != 0) { |
| | | LOG(ERROR) << down_lm_model << " do not exists."; |
| | | exit(-1); |
| | | } else { |
| | | model_path[LM_DIR] = down_lm_path; |
| | | LOG(INFO) << "Set " << LM_DIR << " : " << model_path[LM_DIR]; |
| | | } |
| | | } else { |
| | | LOG(INFO) << "LM model is not set, not executed."; |
| | | model_path[LM_DIR] = ""; |
| | | } |
| | | |
| | | if (!s_punc_path.empty()) { |
| | | std::string python_cmd_punc; |
| | | std::string down_punc_path; |
| | |
| | | |
| | | extern std::unordered_map<std::string, int> hws_map_; |
| | | extern int fst_inc_wts_; |
| | | extern float global_beam_, lattice_beam_, am_scale_; |
| | | |
| | | context_ptr WebSocketServer::on_tls_init(tls_mode mode, |
| | | websocketpp::connection_hdl hdl, |
| | |
| | | bool itn, |
| | | int audio_fs, |
| | | std::string wav_format, |
| | | FUNASR_HANDLE& tpass_online_handle) { |
| | | FUNASR_HANDLE& tpass_online_handle, |
| | | FUNASR_DEC_HANDLE& decoder_handle) { |
| | | // lock for each connection |
| | | if(!tpass_online_handle){ |
| | | scoped_lock guard(thread_lock); |
| | |
| | | subvector.data(), subvector.size(), |
| | | punc_cache, false, audio_fs, |
| | | wav_format, (ASR_TYPE)asr_mode_, |
| | | hotwords_embedding, itn); |
| | | hotwords_embedding, itn, decoder_handle); |
| | | |
| | | } else { |
| | | scoped_lock guard(thread_lock); |
| | |
| | | buffer.data(), buffer.size(), punc_cache, |
| | | is_final, audio_fs, |
| | | wav_format, (ASR_TYPE)asr_mode_, |
| | | hotwords_embedding, itn); |
| | | hotwords_embedding, itn, decoder_handle); |
| | | } else { |
| | | scoped_lock guard(thread_lock); |
| | | msg["access_num"]=(int)msg["access_num"]-1; |
| | |
| | | data_msg->msg["audio_fs"] = 16000; // default is 16k |
| | | data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly |
| | | data_msg->msg["is_eof"]=false; // if this connection is closed |
| | | FUNASR_DEC_HANDLE decoder_handle = |
| | | FunASRWfstDecoderInit(tpass_handle, ASR_TWO_PASS, global_beam_, lattice_beam_, am_scale_); |
| | | data_msg->decoder_handle = decoder_handle; |
| | | data_msg->punc_cache = |
| | | std::make_shared<std::vector<std::vector<std::string>>>(2); |
| | | data_msg->strand_ = std::make_shared<asio::io_context::strand>(io_decoder_); |
| | |
| | | // finished and avoid access freed tpass_online_handle |
| | | unique_lock guard_decoder(*(data_msg->thread_lock)); |
| | | if (data_msg->msg["access_num"]==0 && data_msg->msg["is_eof"]==true) { |
| | | FunWfstDecoderUnloadHwsRes(data_msg->decoder_handle); |
| | | FunASRWfstDecoderUninit(data_msg->decoder_handle); |
| | | data_msg->decoder_handle = nullptr; |
| | | FunTpassOnlineUninit(data_msg->tpass_online_handle); |
| | | data_msg->tpass_online_handle = nullptr; |
| | | data_map.erase(hdl); |
| | |
| | | nn_hotwords += " " + pair.first; |
| | | LOG(INFO) << pair.first << " : " << pair.second; |
| | | } |
| | | // FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map); |
| | | FunWfstDecoderLoadHwsRes(msg_data->decoder_handle, fst_inc_wts_, merged_hws_map); |
| | | |
| | | // nn |
| | | std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, nn_hotwords, ASR_TWO_PASS); |
| | |
| | | msg_data->msg["itn"], |
| | | msg_data->msg["audio_fs"], |
| | | msg_data->msg["wav_format"], |
| | | std::ref(msg_data->tpass_online_handle))); |
| | | std::ref(msg_data->tpass_online_handle), |
| | | std::ref(msg_data->decoder_handle))); |
| | | msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1; |
| | | } |
| | | catch (std::exception const &e) |
| | |
| | | msg_data->msg["itn"], |
| | | msg_data->msg["audio_fs"], |
| | | msg_data->msg["wav_format"], |
| | | std::ref(msg_data->tpass_online_handle))); |
| | | std::ref(msg_data->tpass_online_handle), |
| | | std::ref(msg_data->decoder_handle))); |
| | | msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1; |
| | | } |
| | | } |
| | |
| | | FUNASR_HANDLE tpass_online_handle=NULL; |
| | | std::string online_res = ""; |
| | | std::string tpass_res = ""; |
| | | std::shared_ptr<asio::io_context::strand> strand_; // for data execute in order |
| | | std::shared_ptr<asio::io_context::strand> strand_; // for data execute in order |
| | | FUNASR_DEC_HANDLE decoder_handle=NULL; |
| | | } FUNASR_MESSAGE; |
| | | |
| | | // See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about |
| | |
| | | bool itn, |
| | | int audio_fs, |
| | | std::string wav_format, |
| | | FUNASR_HANDLE& tpass_online_handle); |
| | | FUNASR_HANDLE& tpass_online_handle, |
| | | FUNASR_DEC_HANDLE& decoder_handle); |
| | | |
| | | void initAsr(std::map<std::string, std::string>& model_path, int thread_num); |
| | | void on_message(websocketpp::connection_hdl hdl, message_ptr msg); |