| | |
| | | session_options.DisableCpuMemArena(); |
| | | |
| | | try{ |
| | | m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options); |
| | | m_session = std::make_unique<Ort::Session>(env_, ORTSTRING(punc_model).c_str(), session_options); |
| | | LOG(INFO) << "Successfully load model from " << punc_model; |
| | | } |
| | | catch (std::exception const &e) { |
| | | LOG(ERROR) << "Error when load punc onnx model: " << e.what(); |
| | | exit(0); |
| | | exit(-1); |
| | | } |
| | | // read input names and output names |
| | | string strName; |
| | |
| | | { |
| | | } |
| | | |
| | | string CTTransformerOnline::AddPunc(const char* sz_input, vector<string> &arr_cache) |
| | | string CTTransformerOnline::AddPunc(const char* sz_input, vector<string> &arr_cache, std::string language) |
| | | { |
| | | string strResult; |
| | | vector<string> strOut; |
| | |
| | | for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN) |
| | | { |
| | | nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size()); |
| | | vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff); |
| | | vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff); |
| | | vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + (TOKEN_LEN - nDiff)); |
| | | vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + (TOKEN_LEN - nDiff)); |
| | | InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs; |
| | | InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr; |
| | | |
| | |
| | | nSentEnd = nLastCommaIndex; |
| | | Punction[nSentEnd] = PERIOD_INDEX; |
| | | } |
| | | RemainStr.assign(InputStr.begin() + nSentEnd + 1, InputStr.end()); |
| | | RemainIDs.assign(InputIDs.begin() + nSentEnd + 1, InputIDs.end()); |
| | | InputStr.assign(InputStr.begin(), InputStr.begin() + nSentEnd + 1); // minit_sentence |
| | | Punction.assign(Punction.begin(), Punction.begin() + nSentEnd + 1); |
| | | RemainStr.assign(InputStr.begin() + (nSentEnd + 1), InputStr.end()); |
| | | RemainIDs.assign(InputIDs.begin() + (nSentEnd + 1), InputIDs.end()); |
| | | InputStr.assign(InputStr.begin(), InputStr.begin() + (nSentEnd + 1)); // minit_sentence |
| | | Punction.assign(Punction.begin(), Punction.begin() + (nSentEnd + 1)); |
| | | } |
| | | |
| | | for (auto& item : Punction) |
| | |
| | | break; |
| | | } |
| | | } |
| | | arr_cache.assign(sentence_words_list.begin() + nSentEnd + 1, sentence_words_list.end()); |
| | | arr_cache.assign(sentence_words_list.begin() + (nSentEnd + 1), sentence_words_list.end()); |
| | | |
| | | if (sentenceOut.size() > 0 && m_tokenizer.IsPunc(sentenceOut[sentenceOut.size() - 1])) |
| | | { |
| | |
| | | text_lengths_dim.size()); //, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32); |
| | | |
| | | //vad_mask |
| | | vector<float> arVadMask,arSubMask; |
| | | // vector<float> arVadMask,arSubMask; |
| | | vector<float> arVadMask; |
| | | int nTextLength = input_data.size(); |
| | | |
| | | VadMask(nTextLength, nCacheSize, arVadMask); |
| | | Triangle(nTextLength, arSubMask); |
| | | // Triangle(nTextLength, arSubMask); |
| | | std::array<int64_t, 4> VadMask_Dim{ 1,1, nTextLength ,nTextLength }; |
| | | Ort::Value onnx_vad_mask = Ort::Value::CreateTensor<float>( |
| | | m_memoryInfo, |
| | |
| | | std::array<int64_t, 4> SubMask_Dim{ 1,1, nTextLength ,nTextLength }; |
| | | Ort::Value onnx_sub_mask = Ort::Value::CreateTensor<float>( |
| | | m_memoryInfo, |
| | | arSubMask.data(), |
| | | arSubMask.size() , |
| | | arVadMask.data(), |
| | | arVadMask.size(), |
| | | SubMask_Dim.data(), |
| | | SubMask_Dim.size()); // , ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); |
| | | |
| | |
| | | catch (std::exception const &e) |
| | | { |
| | | LOG(ERROR) << "Error when run punc onnx forword: " << (e.what()); |
| | | exit(0); |
| | | } |
| | | return punction; |
| | | } |