Add offline punctuation (punc) restoration support for onnxruntime.
| | |
| | | #define MODEL_SAMPLE_RATE 16000 |
| | | #endif |
| | | |
| | | // vad |
| | | #ifndef VAD_SILENCE_DYRATION |
| | | #define VAD_SILENCE_DYRATION 15000 |
| | | #endif |
| | |
| | | #define VAD_SPEECH_NOISE_THRES 0.9 |
| | | #endif |
| | | |
| | | // punc |
| | | #define PUNC_MODEL_FILE "punc_model.onnx" |
| | | #define PUNC_YAML_FILE "punc.yaml" |
| | | |
| | | #define UNK_CHAR "<unk>" |
| | | |
| | | #define INPUT_NUM 2 |
| | | #define INPUT_NAME1 "input" |
| | | #define INPUT_NAME2 "text_lengths" |
| | | #define OUTPUT_NAME "logits" |
| | | #define TOKEN_LEN 20 |
| | | |
| | | #define CANDIDATE_NUM 6 |
| | | #define UNKNOW_INDEX 0 |
| | | #define NOTPUNC_INDEX 1 |
| | | #define COMMA_INDEX 2 |
| | | #define PERIOD_INDEX 3 |
| | | #define QUESTION_INDEX 4 |
| | | #define DUN_INDEX 5 |
| | | #define CACHE_POP_TRIGGER_LIMIT 200 |
| | | |
| | | #endif |
| | |
| | | virtual std::string forward(float *din, int len, int flag) = 0; |
| | | virtual std::string rescoring() = 0; |
| | | virtual std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data)=0; |
| | | virtual std::string AddPunc(const char* szInput)=0; |
| | | }; |
| | | |
| | | Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false); |
| New file |
| | |
| | | /** |
| | | * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) |
| | | * |
| | | * See LICENSE for clarification regarding multiple authors |
| | | * |
| | | * Licensed under the Apache License, Version 2.0 (the "License"); |
| | | * you may not use this file except in compliance with the License. |
| | | * You may obtain a copy of the License at |
| | | * |
| | | * http://www.apache.org/licenses/LICENSE-2.0 |
| | | * |
| | | * Unless required by applicable law or agreed to in writing, software |
| | | * distributed under the License is distributed on an "AS IS" BASIS, |
| | | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| | | * See the License for the specific language governing permissions and |
| | | * limitations under the License. |
| | | */ |
| | | |
| | | // This file is copied/modified from kaldi/src/feat/feature-fbank.h |
| | | |
| | | #ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ |
| | | #define KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ |
| | | |
#include <cstdint>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include "kaldi-native-fbank/csrc/feature-window.h"
#include "kaldi-native-fbank/csrc/mel-computations.h"
#include "kaldi-native-fbank/csrc/rfft.h"
| | | |
| | | namespace knf { |
| | | |
// Options for computing log Mel filterbank (fbank) features.
struct FbankOptions {
  FrameExtractionOptions frame_opts;  // framing/windowing options
  MelBanksOptions mel_opts;           // Mel filterbank options (num_bins etc.)
  // append an extra dimension with energy to the filter banks
  bool use_energy = false;
  float energy_floor = 0.0f;  // active iff use_energy==true

  // If true, compute log_energy before preemphasis and windowing
  // If false, compute log_energy after preemphasis ans windowing
  bool raw_energy = true;  // active iff use_energy==true

  // If true, put energy last (if using energy)
  // If false, put energy first
  bool htk_compat = false;  // active iff use_energy==true

  // if true (default), produce log-filterbank, else linear
  bool use_log_fbank = true;

  // if true (default), use power in filterbank
  // analysis, else magnitude.
  bool use_power = true;

  FbankOptions() { mel_opts.num_bins = 23; }

  // Dumps all options in a human-readable form (debugging/logging aid).
  std::string ToString() const {
    std::ostringstream os;
    os << "frame_opts: \n";
    os << frame_opts << "\n";
    os << "\n";

    os << "mel_opts: \n";
    os << mel_opts << "\n";

    os << "use_energy: " << use_energy << "\n";
    os << "energy_floor: " << energy_floor << "\n";
    os << "raw_energy: " << raw_energy << "\n";
    os << "htk_compat: " << htk_compat << "\n";
    os << "use_log_fbank: " << use_log_fbank << "\n";
    os << "use_power: " << use_power << "\n";
    return os.str();
  }
};
| | | |
| | | std::ostream &operator<<(std::ostream &os, const FbankOptions &opts); |
| | | |
// Computes fbank features one frame at a time (implementation in the
// corresponding .cc file). Mel banks are cached per VTLN warp factor.
class FbankComputer {
 public:
  using Options = FbankOptions;

  explicit FbankComputer(const FbankOptions &opts);
  ~FbankComputer();

  // Feature dimension: number of Mel bins, plus one if energy is appended.
  int32_t Dim() const {
    return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
  }

  // if true, compute log_energy_pre_window but after dithering and dc removal
  bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }

  const FrameExtractionOptions &GetFrameOptions() const {
    return opts_.frame_opts;
  }

  const FbankOptions &GetOptions() const { return opts_; }

  /**
     Function that computes one frame of features from
     one frame of signal.

     @param [in] signal_raw_log_energy The log-energy of the frame of the signal
         prior to windowing and pre-emphasis, or
         log(numeric_limits<float>::min()), whichever is greater. Must be
         ignored by this function if this class returns false from
         this->NeedsRawLogEnergy().
     @param [in] vtln_warp The VTLN warping factor that the user wants
         to be applied when computing features for this utterance. Will
         normally be 1.0, meaning no warping is to be done. The value will
         be ignored for feature types that don't support VLTN, such as
         spectrogram features.
     @param [in] signal_frame One frame of the signal,
       as extracted using the function ExtractWindow() using the options
       returned by this->GetFrameOptions(). The function will use the
       vector as a workspace, which is why it's a non-const pointer.
     @param [out] feature Pointer to a vector of size this->Dim(), to which
         the computed feature will be written. It should be pre-allocated.
  */
  void Compute(float signal_raw_log_energy, float vtln_warp,
               std::vector<float> *signal_frame, float *feature);

 private:
  const MelBanks *GetMelBanks(float vtln_warp);

  FbankOptions opts_;
  float log_energy_floor_;                // floor applied when use_energy is set
  std::map<float, MelBanks *> mel_banks_;  // float is VTLN coefficient.
  Rfft rfft_;                             // FFT helper for the power spectrum
};
| | | |
| | | } // namespace knf |
| | | |
| | | #endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_ |
| New file |
| | |
| | | /** |
| | | * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) |
| | | * |
| | | * See LICENSE for clarification regarding multiple authors |
| | | * |
| | | * Licensed under the Apache License, Version 2.0 (the "License"); |
| | | * you may not use this file except in compliance with the License. |
| | | * You may obtain a copy of the License at |
| | | * |
| | | * http://www.apache.org/licenses/LICENSE-2.0 |
| | | * |
| | | * Unless required by applicable law or agreed to in writing, software |
| | | * distributed under the License is distributed on an "AS IS" BASIS, |
| | | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| | | * See the License for the specific language governing permissions and |
| | | * limitations under the License. |
| | | */ |
| | | |
| | | // The content in this file is copied/modified from |
| | | // This file is copied/modified from kaldi/src/feat/online-feature.h |
| | | #ifndef KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_ |
| | | #define KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_ |
| | | |
| | | #include <cstdint> |
| | | #include <deque> |
| | | #include <vector> |
| | | |
| | | #include "kaldi-native-fbank/csrc/feature-fbank.h" |
| | | |
| | | namespace knf { |
| | | |
| | | /// This class serves as a storage for feature vectors with an option to limit |
| | | /// the memory usage by removing old elements. The deleted frames indices are |
| | | /// "remembered" so that regardless of the MAX_ITEMS setting, the user always |
| | | /// provides the indices as if no deletion was being performed. |
| | | /// This is useful when processing very long recordings which would otherwise |
| | | /// cause the memory to eventually blow up when the features are not being |
| | | /// removed. |
class RecyclingVector {
 public:
  /// By default it does not remove any elements.
  explicit RecyclingVector(int32_t items_to_hold = -1);

  ~RecyclingVector() = default;
  RecyclingVector(const RecyclingVector &) = delete;
  RecyclingVector &operator=(const RecyclingVector &) = delete;

  // The pointer is owned by RecyclingVector
  // Users should not free it
  const float *At(int32_t index) const;

  void PushBack(std::vector<float> item);

  /// This method returns the size as if no "recycling" had happened,
  /// i.e. equivalent to the number of times the PushBack method has been
  /// called.
  int32_t Size() const;

 private:
  std::deque<std::vector<float>> items_;  // frames currently retained
  int32_t items_to_hold_;   // retention limit (ctor default -1: keep all)
  int32_t first_available_index_;  // logical index of the oldest retained frame
};
| | | |
| | | /// This is a templated class for online feature extraction; |
| | | /// it's templated on a class like MfccComputer or PlpComputer |
| | | /// that does the basic feature extraction. |
template <class C>
class OnlineGenericBaseFeature {
 public:
  // Constructor from options class
  explicit OnlineGenericBaseFeature(const typename C::Options &opts);

  // Feature dimension, delegated to the underlying computer.
  int32_t Dim() const { return computer_.Dim(); }

  float FrameShiftInSeconds() const {
    return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
  }

  // Number of feature frames computed so far from the accepted waveform.
  int32_t NumFramesReady() const { return features_.Size(); }

  // Note: IsLastFrame() will only ever return true if you have called
  // InputFinished() (and this frame is the last frame).
  bool IsLastFrame(int32_t frame) const {
    return input_finished_ && frame == NumFramesReady() - 1;
  }

  // Returns a pointer (owned by this class) to frame number `frame`.
  const float *GetFrame(int32_t frame) const { return features_.At(frame); }

  // This would be called from the application, when you get
  // more wave data. Note: the sampling_rate is only provided so
  // the code can assert that it matches the sampling rate
  // expected in the options.
  //
  // @param sampling_rate The sampling_rate of the input waveform
  // @param waveform Pointer to a 1-D array of size n
  // @param n Number of entries in waveform
  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);

  // InputFinished() tells the class you won't be providing any
  // more waveform. This will help flush out the last frame or two
  // of features, in the case where snip-edges == false; it also
  // affects the return value of IsLastFrame().
  void InputFinished();

 private:
  // This function computes any additional feature frames that it is possible to
  // compute from 'waveform_remainder_', which at this point may contain more
  // than just a remainder-sized quantity (because AcceptWaveform() appends to
  // waveform_remainder_ before calling this function). It adds these feature
  // frames to features_, and shifts off any now-unneeded samples of input from
  // waveform_remainder_ while incrementing waveform_offset_ by the same amount.
  void ComputeFeatures();

  C computer_;  // class that does the MFCC or PLP or filterbank computation

  FeatureWindowFunction window_function_;

  // features_ is the Mfcc or Plp or Fbank features that we have already
  // computed.

  RecyclingVector features_;

  // True if the user has called "InputFinished()"
  bool input_finished_;

  // waveform_offset_ is the number of samples of waveform that we have
  // already discarded, i.e. that were prior to 'waveform_remainder_'.
  int64_t waveform_offset_;

  // waveform_remainder_ is a short piece of waveform that we may need to keep
  // after extracting all the whole frames we can (whatever length of feature
  // will be required for the next phase of computation).
  // It is a 1-D tensor
  std::vector<float> waveform_remainder_;
};
| | | |
| | | using OnlineFbank = OnlineGenericBaseFeature<FbankComputer>; |
| | | |
| | | } // namespace knf |
| | | |
| | | #endif // KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_ |
| | |
| | | #pragma once |
| | | |
| | | #include <algorithm> |
| | | typedef struct |
| | | { |
| | | std::string msg; |
| | |
| | | } |
| | | } |
| | | } |
| | | |
// Returns the offset of the largest element in the half-open range
// [first, last). Ties resolve to the earliest maximum, exactly as
// std::max_element does.
template <class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
  const auto max_it = std::max_element(first, last);
  return static_cast<size_t>(std::distance(first, max_it));
}
| | |
| | | if (fnCallback) |
| | | fnCallback(nStep, nTotal); |
| | | } |
| | | if(true){ |
| | | string punc_res = pRecogObj->AddPunc((pResult->msg).c_str()); |
| | | pResult->msg = punc_res; |
| | | } |
| | | |
| | | return pResult; |
| | | } |
| | |
| | | vadHandle->init_vad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES); |
| | | } |
| | | |
| | | // PUNC model |
| | | if(true){ |
| | | puncHandle = make_unique<CTTransformer>(path, nNumThread); |
| | | } |
| | | |
| | | if(quantize) |
| | | { |
| | | model_path = pathAppend(path, "model_quant.onnx"); |
| | |
| | | m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions); |
| | | #endif |
| | | |
| | | vector<string> m_strInputNames, m_strOutputNames; |
| | | string strName; |
| | | getInputName(m_session.get(), strName); |
| | | m_strInputNames.push_back(strName.c_str()); |
| | |
| | | |
// Voice-activity detection: delegates to the FSMN VAD handle.
// NOTE(review): the segment layout is whatever FsmnVad::infer() produces —
// presumably [start, end] pairs per speech segment; confirm against FsmnVad.
vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){
    return vadHandle->infer(pcm_data);
}
| | | |
// Punctuation restoration: delegates to the CT-Transformer punctuation model.
string ModelImp::AddPunc(const char* szInput){
    return puncHandle->AddPunc(szInput);
}
| | | |
| | | vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) { |
| | |
| | | auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>(); |
| | | result = greedy_search(floatData, *encoder_out_lens); |
| | | } |
| | | catch (...) |
| | | catch (std::exception const &e) |
| | | { |
| | | result = ""; |
| | | printf(e.what()); |
| | | } |
| | | |
| | | return result; |
| | |
| | | knf::FbankOptions fbank_opts; |
| | | |
| | | std::unique_ptr<FsmnVad> vadHandle; |
| | | std::unique_ptr<CTTransformer> puncHandle; |
| | | |
| | | Vocab* vocab; |
| | | vector<float> means_list; |
| | |
| | | Ort::Env env_; |
| | | Ort::SessionOptions sessionOptions; |
| | | |
| | | vector<string> m_strInputNames, m_strOutputNames; |
| | | vector<const char*> m_szInputNames; |
| | | vector<const char*> m_szOutputNames; |
| | | |
| | |
| | | string forward(float* din, int len, int flag); |
| | | string rescoring(); |
| | | std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data); |
| | | string AddPunc(const char* szInput); |
| | | |
| | | }; |
| | | |
| | |
| | | #include "ComDefine.h" |
| | | #include "commonfunc.h" |
| | | #include "predefine_coe.h" |
| | | #include "tokenizer.h" |
| | | #include "punc_infer.h" |
| | | #include "FsmnVad.h" |
| | | #include "e2e_vad.h" |
| | | #include "Vocab.h" |
| New file |
| | |
| | | #include "precomp.h" |
| | | |
| | | CTTransformer::CTTransformer(const char* sz_model_dir, int thread_num) |
| | | :env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{} |
| | | { |
| | | session_options.SetIntraOpNumThreads(thread_num); |
| | | session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL); |
| | | session_options.DisableCpuMemArena(); |
| | | |
| | | string strModelPath = pathAppend(sz_model_dir, PUNC_MODEL_FILE); |
| | | string strYamlPath = pathAppend(sz_model_dir, PUNC_YAML_FILE); |
| | | |
| | | #ifdef _WIN32 |
| | | std::wstring detPath = strToWstr(strModelPath); |
| | | m_session = std::make_unique<Ort::Session>(env_, detPath.c_str(), session_options); |
| | | #else |
| | | m_session = std::make_unique<Ort::Session>(env_, strModelPath.c_str(), session_options); |
| | | #endif |
| | | // read inputnames outputnames |
| | | vector<string> m_strInputNames, m_strOutputNames; |
| | | string strName; |
| | | getInputName(m_session.get(), strName); |
| | | m_strInputNames.push_back(strName.c_str()); |
| | | getInputName(m_session.get(), strName, 1); |
| | | m_strInputNames.push_back(strName); |
| | | |
| | | getOutputName(m_session.get(), strName); |
| | | m_strOutputNames.push_back(strName); |
| | | |
| | | for (auto& item : m_strInputNames) |
| | | m_szInputNames.push_back(item.c_str()); |
| | | for (auto& item : m_strOutputNames) |
| | | m_szOutputNames.push_back(item.c_str()); |
| | | |
| | | m_tokenizer.OpenYaml(strYamlPath.c_str()); |
| | | } |
| | | |
| | | CTTransformer::~CTTransformer() |
| | | { |
| | | } |
| | | |
| | | string CTTransformer::AddPunc(const char* sz_input) |
| | | { |
| | | string strResult; |
| | | vector<string> strOut; |
| | | vector<int> InputData; |
| | | m_tokenizer.Tokenize(sz_input, strOut, InputData); |
| | | |
| | | int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN); |
| | | int nCurBatch = -1; |
| | | int nSentEnd = -1, nLastCommaIndex = -1; |
| | | vector<int64_t> RemainIDs; // |
| | | vector<string> RemainStr; // |
| | | vector<int> NewPunctuation; // |
| | | vector<string> NewString; // |
| | | vector<string> NewSentenceOut; |
| | | vector<int> NewPuncOut; |
| | | int nDiff = 0; |
| | | for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN) |
| | | { |
| | | nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size()); |
| | | vector<int64_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff); |
| | | vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff); |
| | | InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs; |
| | | InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr; |
| | | |
| | | auto Punction = Infer(InputIDs); |
| | | nCurBatch = i / TOKEN_LEN; |
| | | if (nCurBatch < nTotalBatch - 1) // not the last minisetence |
| | | { |
| | | nSentEnd = -1; |
| | | nLastCommaIndex = -1; |
| | | for (int nIndex = Punction.size() - 2; nIndex > 0; nIndex--) |
| | | { |
| | | if (m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(PERIOD_INDEX) || m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(QUESTION_INDEX)) |
| | | { |
| | | nSentEnd = nIndex; |
| | | break; |
| | | } |
| | | if (nLastCommaIndex < 0 && m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(COMMA_INDEX)) |
| | | { |
| | | nLastCommaIndex = nIndex; |
| | | } |
| | | } |
| | | if (nSentEnd < 0 && InputStr.size() > CACHE_POP_TRIGGER_LIMIT && nLastCommaIndex > 0) |
| | | { |
| | | nSentEnd = nLastCommaIndex; |
| | | Punction[nSentEnd] = PERIOD_INDEX; |
| | | } |
| | | RemainStr.assign(InputStr.begin() + nSentEnd + 1, InputStr.end()); |
| | | RemainIDs.assign(InputIDs.begin() + nSentEnd + 1, InputIDs.end()); |
| | | InputStr.assign(InputStr.begin(), InputStr.begin() + nSentEnd + 1); // minit_sentence |
| | | Punction.assign(Punction.begin(), Punction.begin() + nSentEnd + 1); |
| | | } |
| | | |
| | | NewPunctuation.insert(NewPunctuation.end(), Punction.begin(), Punction.end()); |
| | | vector<string> WordWithPunc; |
| | | for (int i = 0; i < InputStr.size(); i++) |
| | | { |
| | | if (i > 0 && !(InputStr[i][0] & 0x80) && (i + 1) <InputStr.size() && !(InputStr[i+1][0] & 0x80))// �м��Ӣ�ģ� |
| | | { |
| | | InputStr[i] = InputStr[i]+ " "; |
| | | } |
| | | WordWithPunc.push_back(InputStr[i]); |
| | | |
| | | if (Punction[i] != NOTPUNC_INDEX) // �»��� |
| | | { |
| | | WordWithPunc.push_back(m_tokenizer.ID2Punc(Punction[i])); |
| | | } |
| | | } |
| | | |
| | | NewString.insert(NewString.end(), WordWithPunc.begin(), WordWithPunc.end()); // new_mini_sentence += "".join(words_with_punc) |
| | | NewSentenceOut = NewString; |
| | | NewPuncOut = NewPunctuation; |
| | | // last mini sentence |
| | | if(nCurBatch == nTotalBatch - 1) |
| | | { |
| | | if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(COMMA_INDEX) || NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(DUN_INDEX)) |
| | | { |
| | | NewSentenceOut.assign(NewString.begin(), NewString.end() - 1); |
| | | NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX)); |
| | | NewPuncOut.assign(NewPunctuation.begin(), NewPunctuation.end() - 1); |
| | | NewPuncOut.push_back(PERIOD_INDEX); |
| | | } |
| | | else if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(PERIOD_INDEX) && NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(QUESTION_INDEX)) |
| | | { |
| | | NewSentenceOut = NewString; |
| | | NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX)); |
| | | NewPuncOut = NewPunctuation; |
| | | NewPuncOut.push_back(PERIOD_INDEX); |
| | | } |
| | | } |
| | | } |
| | | for (auto& item : NewSentenceOut) |
| | | strResult += item; |
| | | return strResult; |
| | | } |
| | | |
| | | vector<int> CTTransformer::Infer(vector<int64_t> input_data) |
| | | { |
| | | Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); |
| | | vector<int> punction; |
| | | std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()}; |
| | | Ort::Value onnx_input = Ort::Value::CreateTensor<int64_t>(m_memoryInfo, |
| | | input_data.data(), |
| | | input_data.size(), |
| | | input_shape_.data(), |
| | | input_shape_.size()); |
| | | |
| | | std::array<int32_t,1> text_lengths{ (int32_t)input_data.size() }; |
| | | std::array<int64_t,1> text_lengths_dim{ 1 }; |
| | | Ort::Value onnx_text_lengths = Ort::Value::CreateTensor( |
| | | m_memoryInfo, |
| | | text_lengths.data(), |
| | | text_lengths.size() * sizeof(int32_t), |
| | | text_lengths_dim.data(), |
| | | text_lengths_dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32); |
| | | std::vector<Ort::Value> input_onnx; |
| | | input_onnx.emplace_back(std::move(onnx_input)); |
| | | input_onnx.emplace_back(std::move(onnx_text_lengths)); |
| | | |
| | | try { |
| | | auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size()); |
| | | std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape(); |
| | | |
| | | int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>()); |
| | | float * floatData = outputTensor[0].GetTensorMutableData<float>(); |
| | | |
| | | for (int i = 0; i < outputCount; i += CANDIDATE_NUM) |
| | | { |
| | | int index = argmax(floatData + i, floatData + i + CANDIDATE_NUM-1); |
| | | punction.push_back(index); |
| | | } |
| | | } |
| | | catch (std::exception const &e) |
| | | { |
| | | printf(e.what()); |
| | | } |
| | | return punction; |
| | | } |
| | | |
| | | |
| | | |
| New file |
| | |
| | | #pragma once |
| | | |
| | | class CTTransformer { |
| | | /** |
| | | * Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | * CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection |
| | | * https://arxiv.org/pdf/2003.01309.pdf |
| | | */ |
| | | |
| | | private: |
| | | |
| | | CTokenizer m_tokenizer; |
| | | vector<const char*> m_szInputNames; |
| | | vector<const char*> m_szOutputNames; |
| | | |
| | | std::shared_ptr<Ort::Session> m_session; |
| | | Ort::Env env_; |
| | | Ort::SessionOptions session_options; |
| | | public: |
| | | |
| | | CTTransformer(const char* sz_model_dir, int thread_num); |
| | | ~CTTransformer(); |
| | | vector<int> Infer(vector<int64_t> input_data); |
| | | string AddPunc(const char* sz_input); |
| | | }; |
| New file |
| | |
| | | #include "precomp.h" |
| | | |
| | | CTokenizer::CTokenizer(const char* szYmlFile):m_Ready(false) |
| | | { |
| | | OpenYaml(szYmlFile); |
| | | } |
| | | |
| | | CTokenizer::CTokenizer():m_Ready(false) |
| | | { |
| | | } |
| | | |
| | | void CTokenizer::read_yml(const YAML::Node& node) |
| | | { |
| | | if (node.IsMap()) |
| | | {//��map�� |
| | | for (auto it = node.begin(); it != node.end(); ++it) |
| | | { |
| | | read_yml(it->second); |
| | | } |
| | | } |
| | | if (node.IsSequence()) {//�������� |
| | | for (size_t i = 0; i < node.size(); ++i) { |
| | | read_yml(node[i]); |
| | | } |
| | | } |
| | | if (node.IsScalar()) {//�DZ����� |
| | | cout << node.as<string>() << endl; |
| | | } |
| | | } |
| | | |
| | | bool CTokenizer::OpenYaml(const char* szYmlFile) |
| | | { |
| | | YAML::Node m_Config = YAML::LoadFile(szYmlFile); |
| | | if (m_Config.IsNull()) |
| | | return false; |
| | | try |
| | | { |
| | | auto Tokens = m_Config["token_list"]; |
| | | if (Tokens.IsSequence()) |
| | | { |
| | | for (size_t i = 0; i < Tokens.size(); ++i) |
| | | { |
| | | if (Tokens[i].IsScalar()) |
| | | { |
| | | m_ID2Token.push_back(Tokens[i].as<string>()); |
| | | m_Token2ID.insert(make_pair<string, int>(Tokens[i].as<string>(), i)); |
| | | } |
| | | } |
| | | } |
| | | auto Puncs = m_Config["punc_list"]; |
| | | if (Puncs.IsSequence()) |
| | | { |
| | | for (size_t i = 0; i < Puncs.size(); ++i) |
| | | { |
| | | if (Puncs[i].IsScalar()) |
| | | { |
| | | m_ID2Punc.push_back(Puncs[i].as<string>()); |
| | | m_Punc2ID.insert(make_pair<string, int>(Puncs[i].as<string>(), i)); |
| | | } |
| | | } |
| | | } |
| | | } |
| | | catch (YAML::BadFile& e) { |
| | | std::cout << "read error!" << std::endl; |
| | | return false; |
| | | } |
| | | m_Ready = true; |
| | | return m_Ready; |
| | | } |
| | | |
| | | vector<string> CTokenizer::ID2String(vector<int> Input) |
| | | { |
| | | vector<string> result; |
| | | for (auto& item : Input) |
| | | { |
| | | result.push_back(m_ID2Token[item]); |
| | | } |
| | | return result; |
| | | } |
| | | |
| | | int CTokenizer::String2ID(string Input) |
| | | { |
| | | int nID = 0; // <blank> |
| | | if (m_Token2ID.find(Input) != m_Token2ID.end()) |
| | | nID=(m_Token2ID[Input]); |
| | | else |
| | | nID=(m_Token2ID[UNK_CHAR]); |
| | | return nID; |
| | | } |
| | | |
| | | vector<int> CTokenizer::String2IDs(vector<string> Input) |
| | | { |
| | | vector<int> result; |
| | | for (auto& item : Input) |
| | | { |
| | | transform(item.begin(), item.end(), item.begin(), ::tolower); |
| | | if (m_Token2ID.find(item) != m_Token2ID.end()) |
| | | result.push_back(m_Token2ID[item]); |
| | | else |
| | | result.push_back(m_Token2ID[UNK_CHAR]); |
| | | } |
| | | return result; |
| | | } |
| | | |
| | | vector<string> CTokenizer::ID2Punc(vector<int> Input) |
| | | { |
| | | vector<string> result; |
| | | for (auto& item : Input) |
| | | { |
| | | result.push_back(m_ID2Punc[item]); |
| | | } |
| | | return result; |
| | | } |
| | | |
| | | string CTokenizer::ID2Punc(int nPuncID) |
| | | { |
| | | return m_ID2Punc[nPuncID]; |
| | | } |
| | | |
| | | vector<int> CTokenizer::Punc2IDs(vector<string> Input) |
| | | { |
| | | vector<int> result; |
| | | for (auto& item : Input) |
| | | { |
| | | result.push_back(m_Punc2ID[item]); |
| | | } |
| | | return result; |
| | | } |
| | | |
| | | vector<string> CTokenizer::SplitChineseString(const string & strInfo) |
| | | { |
| | | vector<string> list; |
| | | int strSize = strInfo.size(); |
| | | int i = 0; |
| | | |
| | | while (i < strSize) { |
| | | int len = 1; |
| | | for (int j = 0; j < 6 && (strInfo[i] & (0x80 >> j)); j++) { |
| | | len = j + 1; |
| | | } |
| | | list.push_back(strInfo.substr(i, len)); |
| | | i += len; |
| | | } |
| | | return list; |
| | | } |
| | | |
| | | void CTokenizer::strSplit(const string& str, const char split, vector<string>& res) |
| | | { |
| | | if (str == "") |
| | | { |
| | | return; |
| | | } |
| | | string&& strs = str + split; |
| | | size_t pos = strs.find(split); |
| | | |
| | | while (pos != string::npos) |
| | | { |
| | | res.emplace_back(strs.substr(0, pos)); |
| | | strs = move(strs.substr(pos + 1, strs.size())); |
| | | pos = strs.find(split); |
| | | } |
| | | } |
| | | |
| | | void CTokenizer::Tokenize(const char* strInfo, vector<string> & strOut, vector<int> & IDOut) |
| | | { |
| | | vector<string> strList; |
| | | strSplit(strInfo,' ', strList); |
| | | string current_eng,current_chinese; |
| | | for (auto& item : strList) |
| | | { |
| | | current_eng = ""; |
| | | current_chinese = ""; |
| | | for (auto& ch : item) |
| | | { |
| | | if (!(ch& 0x80)) |
| | | { // Ӣ�� |
| | | if (current_chinese.size() > 0) |
| | | { |
| | | // for utf-8 chinese |
| | | auto chineseList = SplitChineseString(current_chinese); |
| | | strOut.insert(strOut.end(), chineseList.begin(),chineseList.end()); |
| | | current_chinese = ""; |
| | | } |
| | | current_eng += ch; |
| | | } |
| | | else |
| | | { |
| | | if (current_eng.size() > 0) |
| | | { |
| | | strOut.push_back(current_eng); |
| | | current_eng = ""; |
| | | } |
| | | current_chinese += ch; |
| | | } |
| | | } |
| | | if (current_chinese.size() > 0) |
| | | { |
| | | auto chineseList = SplitChineseString(current_chinese); |
| | | strOut.insert(strOut.end(), chineseList.begin(), chineseList.end()); |
| | | current_chinese = ""; |
| | | } |
| | | if (current_eng.size() > 0) |
| | | { |
| | | strOut.push_back(current_eng); |
| | | } |
| | | } |
| | | IDOut= String2IDs(strOut); |
| | | } |
| New file |
| | |
| | | #pragma once |
| | | #include "yaml-cpp/yaml.h" |
| | | |
// Tokenizer for the CT-Transformer punctuation model: loads token_list and
// punc_list from punc.yaml and converts between strings and integer ids in
// both directions.
class CTokenizer {
private:

	bool m_Ready = false;                    // true once OpenYaml() has succeeded
	vector<string> m_ID2Token,m_ID2Punc;     // id -> token / punctuation string
	map<string, int> m_Token2ID,m_Punc2ID;   // token / punctuation string -> id

public:

	CTokenizer(const char* szYmlFile);  // constructs and loads the YAML file
	CTokenizer();                       // empty tokenizer; call OpenYaml() first
	// Loads token_list/punc_list from the YAML file; returns success.
	bool OpenYaml(const char* szYmlFile);
	// Debug helper: recursively prints all scalars of a YAML node.
	void read_yml(const YAML::Node& node);
	vector<string> ID2String(vector<int> Input);
	vector<int> String2IDs(vector<string> Input);
	int String2ID(string Input);
	vector<string> ID2Punc(vector<int> Input);
	string ID2Punc(int nPuncID);
	vector<int> Punc2IDs(vector<string> Input);
	// Splits a UTF-8 byte string into individual characters.
	vector<string> SplitChineseString(const string& strInfo);
	// Splits `str` on `split`, appending pieces to `res`.
	void strSplit(const string& str, const char split, vector<string>& res);
	// Full tokenization: space-split, English/Chinese separation, id mapping.
	void Tokenize(const char* strInfo, vector<string>& strOut, vector<int>& IDOut);

};