| | |
| | | session_options.DisableCpuMemArena(); |
| | | |
| | | try{ |
| | | m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options); |
| | | m_session = std::make_unique<Ort::Session>(env_, ORTSTRING(punc_model).c_str(), session_options); |
| | | LOG(INFO) << "Successfully load model from " << punc_model; |
| | | } |
| | | catch (std::exception const &e) { |
| | |
| | | { |
| | | } |
| | | |
| | | string CTTransformer::AddPunc(const char* sz_input) |
| | | string CTTransformer::AddPunc(const char* sz_input, std::string language) |
| | | { |
| | | string strResult; |
| | | vector<string> strOut; |
| | |
| | | for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN) |
| | | { |
| | | nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size()); |
| | | vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff); |
| | | vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff); |
| | | vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + (TOKEN_LEN - nDiff)); |
| | | vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + (TOKEN_LEN - nDiff)); |
| | | InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs; |
| | | InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr; |
| | | |
| | |
| | | nSentEnd = nLastCommaIndex; |
| | | Punction[nSentEnd] = PERIOD_INDEX; |
| | | } |
| | | RemainStr.assign(InputStr.begin() + nSentEnd + 1, InputStr.end()); |
| | | RemainIDs.assign(InputIDs.begin() + nSentEnd + 1, InputIDs.end()); |
| | | InputStr.assign(InputStr.begin(), InputStr.begin() + nSentEnd + 1); // minit_sentence |
| | | Punction.assign(Punction.begin(), Punction.begin() + nSentEnd + 1); |
| | | RemainStr.assign(InputStr.begin() + (nSentEnd + 1), InputStr.end()); |
| | | RemainIDs.assign(InputIDs.begin() + (nSentEnd + 1), InputIDs.end()); |
| | | InputStr.assign(InputStr.begin(), InputStr.begin() + (nSentEnd + 1)); // minit_sentence |
| | | Punction.assign(Punction.begin(), Punction.begin() + (nSentEnd + 1)); |
| | | } |
| | | |
| | | NewPunctuation.insert(NewPunctuation.end(), Punction.begin(), Punction.end()); |
| | |
| | | } |
| | | } |
| | | } |
| | | for (auto& item : NewSentenceOut) |
| | | |
| | | for (auto& item : NewSentenceOut){ |
| | | strResult += item; |
| | | } |
| | | |
| | | if(language == "en-bpe"){ |
| | | std::vector<std::string> chineseSymbols; |
| | | chineseSymbols.push_back(","); |
| | | chineseSymbols.push_back("。"); |
| | | chineseSymbols.push_back("、"); |
| | | chineseSymbols.push_back("?"); |
| | | |
| | | std::string englishSymbols = ",.,?"; |
| | | for (size_t i = 0; i < chineseSymbols.size(); i++) { |
| | | size_t pos = 0; |
| | | while ((pos = strResult.find(chineseSymbols[i], pos)) != std::string::npos) { |
| | | strResult.replace(pos, 3, 1, englishSymbols[i]); |
| | | pos++; |
| | | } |
| | | } |
| | | } |
| | | |
| | | return strResult; |
| | | } |
| | | |