From 55708e7cebaedefc5f69d61f157993da41848b8f Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期日, 23 四月 2023 19:06:25 +0800
Subject: [PATCH] add offline punc for onnxruntime
---
funasr/runtime/onnxruntime/include/Model.h | 1
funasr/runtime/onnxruntime/src/punc_infer.h | 25 +
funasr/runtime/onnxruntime/src/tokenizer.cpp | 208 ++++++++++++++++
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp | 14
funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/online-feature.h | 142 ++++++++++
funasr/runtime/onnxruntime/src/commonfunc.h | 7
funasr/runtime/onnxruntime/src/punc_infer.cpp | 183 ++++++++++++++
funasr/runtime/onnxruntime/src/tokenizer.h | 27 ++
funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/feature-fbank.h | 134 ++++++++++
funasr/runtime/onnxruntime/src/libfunasrapi.cpp | 4
funasr/runtime/onnxruntime/include/ComDefine.h | 21 +
funasr/runtime/onnxruntime/src/paraformer_onnx.h | 3
funasr/runtime/onnxruntime/src/precomp.h | 2
13 files changed, 767 insertions(+), 4 deletions(-)
diff --git a/funasr/runtime/onnxruntime/include/ComDefine.h b/funasr/runtime/onnxruntime/include/ComDefine.h
index 6929e49..72a843d 100644
--- a/funasr/runtime/onnxruntime/include/ComDefine.h
+++ b/funasr/runtime/onnxruntime/include/ComDefine.h
@@ -12,6 +12,7 @@
#define MODEL_SAMPLE_RATE 16000
#endif
+// vad
#ifndef VAD_SILENCE_DYRATION
#define VAD_SILENCE_DYRATION 15000
#endif
@@ -24,5 +25,25 @@
#define VAD_SPEECH_NOISE_THRES 0.9
#endif
+// punc
+#define PUNC_MODEL_FILE "punc_model.onnx"
+#define PUNC_YAML_FILE "punc.yaml"
+
+#define UNK_CHAR "<unk>"
+
+#define INPUT_NUM 2
+#define INPUT_NAME1 "input"
+#define INPUT_NAME2 "text_lengths"
+#define OUTPUT_NAME "logits"
+#define TOKEN_LEN 20
+
+#define CANDIDATE_NUM 6
+#define UNKNOW_INDEX 0
+#define NOTPUNC_INDEX 1
+#define COMMA_INDEX 2
+#define PERIOD_INDEX 3
+#define QUESTION_INDEX 4
+#define DUN_INDEX 5
+#define CACHE_POP_TRIGGER_LIMIT 200
#endif
diff --git a/funasr/runtime/onnxruntime/include/Model.h b/funasr/runtime/onnxruntime/include/Model.h
index cd3b0a3..f92789f 100644
--- a/funasr/runtime/onnxruntime/include/Model.h
+++ b/funasr/runtime/onnxruntime/include/Model.h
@@ -12,6 +12,7 @@
virtual std::string forward(float *din, int len, int flag) = 0;
virtual std::string rescoring() = 0;
virtual std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data)=0;
+ virtual std::string AddPunc(const char* szInput)=0;
};
Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false);
diff --git a/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/feature-fbank.h b/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/feature-fbank.h
new file mode 100644
index 0000000..0786aad
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/feature-fbank.h
@@ -0,0 +1,134 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This file is copied/modified from kaldi/src/feat/feature-fbank.h
+
+#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
+#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "kaldi-native-fbank/csrc/feature-window.h"
+#include "kaldi-native-fbank/csrc/mel-computations.h"
+#include "kaldi-native-fbank/csrc/rfft.h"
+
+namespace knf {
+
+struct FbankOptions {
+ FrameExtractionOptions frame_opts;
+ MelBanksOptions mel_opts;
+ // append an extra dimension with energy to the filter banks
+ bool use_energy = false;
+ float energy_floor = 0.0f; // active iff use_energy==true
+
+ // If true, compute log_energy before preemphasis and windowing
+ // If false, compute log_energy after preemphasis ans windowing
+ bool raw_energy = true; // active iff use_energy==true
+
+ // If true, put energy last (if using energy)
+ // If false, put energy first
+ bool htk_compat = false; // active iff use_energy==true
+
+ // if true (default), produce log-filterbank, else linear
+ bool use_log_fbank = true;
+
+ // if true (default), use power in filterbank
+ // analysis, else magnitude.
+ bool use_power = true;
+
+ FbankOptions() { mel_opts.num_bins = 23; }
+
+ std::string ToString() const {
+ std::ostringstream os;
+ os << "frame_opts: \n";
+ os << frame_opts << "\n";
+ os << "\n";
+
+ os << "mel_opts: \n";
+ os << mel_opts << "\n";
+
+ os << "use_energy: " << use_energy << "\n";
+ os << "energy_floor: " << energy_floor << "\n";
+ os << "raw_energy: " << raw_energy << "\n";
+ os << "htk_compat: " << htk_compat << "\n";
+ os << "use_log_fbank: " << use_log_fbank << "\n";
+ os << "use_power: " << use_power << "\n";
+ return os.str();
+ }
+};
+
+std::ostream &operator<<(std::ostream &os, const FbankOptions &opts);
+
+class FbankComputer {
+ public:
+ using Options = FbankOptions;
+
+ explicit FbankComputer(const FbankOptions &opts);
+ ~FbankComputer();
+
+ int32_t Dim() const {
+ return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
+ }
+
+ // if true, compute log_energy_pre_window but after dithering and dc removal
+ bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
+
+ const FrameExtractionOptions &GetFrameOptions() const {
+ return opts_.frame_opts;
+ }
+
+ const FbankOptions &GetOptions() const { return opts_; }
+
+ /**
+ Function that computes one frame of features from
+ one frame of signal.
+
+ @param [in] signal_raw_log_energy The log-energy of the frame of the signal
+ prior to windowing and pre-emphasis, or
+ log(numeric_limits<float>::min()), whichever is greater. Must be
+ ignored by this function if this class returns false from
+ this->NeedsRawLogEnergy().
+ @param [in] vtln_warp The VTLN warping factor that the user wants
+ to be applied when computing features for this utterance. Will
+ normally be 1.0, meaning no warping is to be done. The value will
+ be ignored for feature types that don't support VLTN, such as
+ spectrogram features.
+ @param [in] signal_frame One frame of the signal,
+ as extracted using the function ExtractWindow() using the options
+ returned by this->GetFrameOptions(). The function will use the
+ vector as a workspace, which is why it's a non-const pointer.
+ @param [out] feature Pointer to a vector of size this->Dim(), to which
+ the computed feature will be written. It should be pre-allocated.
+ */
+ void Compute(float signal_raw_log_energy, float vtln_warp,
+ std::vector<float> *signal_frame, float *feature);
+
+ private:
+ const MelBanks *GetMelBanks(float vtln_warp);
+
+ FbankOptions opts_;
+ float log_energy_floor_;
+ std::map<float, MelBanks *> mel_banks_; // float is VTLN coefficient.
+ Rfft rfft_;
+};
+
+} // namespace knf
+
+#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
diff --git a/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/online-feature.h b/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/online-feature.h
new file mode 100644
index 0000000..5ca5511
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/online-feature.h
@@ -0,0 +1,142 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// The content in this file is copied/modified from
+// This file is copied/modified from kaldi/src/feat/online-feature.h
+#ifndef KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_
+#define KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_
+
+#include <cstdint>
+#include <deque>
+#include <vector>
+
+#include "kaldi-native-fbank/csrc/feature-fbank.h"
+
+namespace knf {
+
+/// This class serves as a storage for feature vectors with an option to limit
+/// the memory usage by removing old elements. The deleted frames indices are
+/// "remembered" so that regardless of the MAX_ITEMS setting, the user always
+/// provides the indices as if no deletion was being performed.
+/// This is useful when processing very long recordings which would otherwise
+/// cause the memory to eventually blow up when the features are not being
+/// removed.
+class RecyclingVector {
+ public:
+ /// By default it does not remove any elements.
+ explicit RecyclingVector(int32_t items_to_hold = -1);
+
+ ~RecyclingVector() = default;
+ RecyclingVector(const RecyclingVector &) = delete;
+ RecyclingVector &operator=(const RecyclingVector &) = delete;
+
+ // The pointer is owned by RecyclingVector
+ // Users should not free it
+ const float *At(int32_t index) const;
+
+ void PushBack(std::vector<float> item);
+
+ /// This method returns the size as if no "recycling" had happened,
+ /// i.e. equivalent to the number of times the PushBack method has been
+ /// called.
+ int32_t Size() const;
+
+ private:
+ std::deque<std::vector<float>> items_;
+ int32_t items_to_hold_;
+ int32_t first_available_index_;
+};
+
+/// This is a templated class for online feature extraction;
+/// it's templated on a class like MfccComputer or PlpComputer
+/// that does the basic feature extraction.
+template <class C>
+class OnlineGenericBaseFeature {
+ public:
+ // Constructor from options class
+ explicit OnlineGenericBaseFeature(const typename C::Options &opts);
+
+ int32_t Dim() const { return computer_.Dim(); }
+
+ float FrameShiftInSeconds() const {
+ return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
+ }
+
+ int32_t NumFramesReady() const { return features_.Size(); }
+
+ // Note: IsLastFrame() will only ever return true if you have called
+ // InputFinished() (and this frame is the last frame).
+ bool IsLastFrame(int32_t frame) const {
+ return input_finished_ && frame == NumFramesReady() - 1;
+ }
+
+ const float *GetFrame(int32_t frame) const { return features_.At(frame); }
+
+ // This would be called from the application, when you get
+ // more wave data. Note: the sampling_rate is only provided so
+ // the code can assert that it matches the sampling rate
+ // expected in the options.
+ //
+ // @param sampling_rate The sampling_rate of the input waveform
+ // @param waveform Pointer to a 1-D array of size n
+ // @param n Number of entries in waveform
+ void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
+
+ // InputFinished() tells the class you won't be providing any
+ // more waveform. This will help flush out the last frame or two
+ // of features, in the case where snip-edges == false; it also
+ // affects the return value of IsLastFrame().
+ void InputFinished();
+
+ private:
+ // This function computes any additional feature frames that it is possible to
+ // compute from 'waveform_remainder_', which at this point may contain more
+ // than just a remainder-sized quantity (because AcceptWaveform() appends to
+ // waveform_remainder_ before calling this function). It adds these feature
+ // frames to features_, and shifts off any now-unneeded samples of input from
+ // waveform_remainder_ while incrementing waveform_offset_ by the same amount.
+ void ComputeFeatures();
+
+ C computer_; // class that does the MFCC or PLP or filterbank computation
+
+ FeatureWindowFunction window_function_;
+
+ // features_ is the Mfcc or Plp or Fbank features that we have already
+ // computed.
+
+ RecyclingVector features_;
+
+ // True if the user has called "InputFinished()"
+ bool input_finished_;
+
+ // waveform_offset_ is the number of samples of waveform that we have
+ // already discarded, i.e. that were prior to 'waveform_remainder_'.
+ int64_t waveform_offset_;
+
+ // waveform_remainder_ is a short piece of waveform that we may need to keep
+ // after extracting all the whole frames we can (whatever length of feature
+ // will be required for the next phase of computation).
+ // It is a 1-D tensor
+ std::vector<float> waveform_remainder_;
+};
+
+using OnlineFbank = OnlineGenericBaseFeature<FbankComputer>;
+
+} // namespace knf
+
+#endif // KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index 8d1a97c..cae1bd7 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -1,5 +1,5 @@
#pragma once
-
+#include <algorithm>
typedef struct
{
std::string msg;
@@ -49,3 +49,8 @@
}
}
}
+
+template <class ForwardIterator>
+inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
+ return std::distance(first, std::max_element(first, last));
+}
diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
index f15e86f..0adef89 100644
--- a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
+++ b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -133,6 +133,10 @@
if (fnCallback)
fnCallback(nStep, nTotal);
}
+ if(true){
+ string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+ pResult->msg = punc_res;
+ }
return pResult;
}
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 1a86da6..69d1554 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -17,6 +17,11 @@
vadHandle->init_vad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
}
+ // PUNC model
+ if(true){
+ puncHandle = make_unique<CTTransformer>(path, nNumThread);
+ }
+
if(quantize)
{
model_path = pathAppend(path, "model_quant.onnx");
@@ -50,6 +55,7 @@
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
+ vector<string> m_strInputNames, m_strOutputNames;
string strName;
getInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
@@ -81,6 +87,10 @@
vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){
return vadHandle->infer(pcm_data);
+}
+
+string ModelImp::AddPunc(const char* szInput){
+ return puncHandle->AddPunc(szInput);
}
vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) {
@@ -231,9 +241,9 @@
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
result = greedy_search(floatData, *encoder_out_lens);
}
- catch (...)
+ catch (std::exception const &e)
{
- result = "";
+ printf(e.what());
}
return result;
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.h b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
index b0712b4..cde2937 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.h
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -14,6 +14,7 @@
knf::FbankOptions fbank_opts;
std::unique_ptr<FsmnVad> vadHandle;
+ std::unique_ptr<CTTransformer> puncHandle;
Vocab* vocab;
vector<float> means_list;
@@ -32,7 +33,6 @@
Ort::Env env_;
Ort::SessionOptions sessionOptions;
- vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
@@ -45,6 +45,7 @@
string forward(float* din, int len, int flag);
string rescoring();
std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data);
+ string AddPunc(const char* szInput);
};
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index 74b6be3..40d8928 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -28,6 +28,8 @@
#include "ComDefine.h"
#include "commonfunc.h"
#include "predefine_coe.h"
+#include "tokenizer.h"
+#include "punc_infer.h"
#include "FsmnVad.h"
#include "e2e_vad.h"
#include "Vocab.h"
diff --git a/funasr/runtime/onnxruntime/src/punc_infer.cpp b/funasr/runtime/onnxruntime/src/punc_infer.cpp
new file mode 100644
index 0000000..8dbb49d
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/punc_infer.cpp
@@ -0,0 +1,183 @@
+#include "precomp.h"
+
+CTTransformer::CTTransformer(const char* sz_model_dir, int thread_num)
+:env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{}
+{
+ session_options.SetIntraOpNumThreads(thread_num);
+ session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
+ session_options.DisableCpuMemArena();
+
+ string strModelPath = pathAppend(sz_model_dir, PUNC_MODEL_FILE);
+ string strYamlPath = pathAppend(sz_model_dir, PUNC_YAML_FILE);
+
+#ifdef _WIN32
+ std::wstring detPath = strToWstr(strModelPath);
+ m_session = std::make_unique<Ort::Session>(env_, detPath.c_str(), session_options);
+#else
+ m_session = std::make_unique<Ort::Session>(env_, strModelPath.c_str(), session_options);
+#endif
+ // read inputnames outputnames
+ vector<string> m_strInputNames, m_strOutputNames;
+ string strName;
+ getInputName(m_session.get(), strName);
+ m_strInputNames.push_back(strName.c_str());
+ getInputName(m_session.get(), strName, 1);
+ m_strInputNames.push_back(strName);
+
+ getOutputName(m_session.get(), strName);
+ m_strOutputNames.push_back(strName);
+
+ for (auto& item : m_strInputNames)
+ m_szInputNames.push_back(item.c_str());
+ for (auto& item : m_strOutputNames)
+ m_szOutputNames.push_back(item.c_str());
+
+ m_tokenizer.OpenYaml(strYamlPath.c_str());
+}
+
+CTTransformer::~CTTransformer()
+{
+}
+
+string CTTransformer::AddPunc(const char* sz_input)
+{
+ string strResult;
+ vector<string> strOut;
+ vector<int> InputData;
+ m_tokenizer.Tokenize(sz_input, strOut, InputData);
+
+ int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
+ int nCurBatch = -1;
+ int nSentEnd = -1, nLastCommaIndex = -1;
+ vector<int64_t> RemainIDs; //
+ vector<string> RemainStr; //
+ vector<int> NewPunctuation; //
+ vector<string> NewString; //
+ vector<string> NewSentenceOut;
+ vector<int> NewPuncOut;
+ int nDiff = 0;
+ for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
+ {
+ nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
+ vector<int64_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
+ vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
+ InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs;
+ InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr;
+
+ auto Punction = Infer(InputIDs);
+ nCurBatch = i / TOKEN_LEN;
+ if (nCurBatch < nTotalBatch - 1) // not the last minisetence
+ {
+ nSentEnd = -1;
+ nLastCommaIndex = -1;
+ for (int nIndex = Punction.size() - 2; nIndex > 0; nIndex--)
+ {
+ if (m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(PERIOD_INDEX) || m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(QUESTION_INDEX))
+ {
+ nSentEnd = nIndex;
+ break;
+ }
+ if (nLastCommaIndex < 0 && m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(COMMA_INDEX))
+ {
+ nLastCommaIndex = nIndex;
+ }
+ }
+ if (nSentEnd < 0 && InputStr.size() > CACHE_POP_TRIGGER_LIMIT && nLastCommaIndex > 0)
+ {
+ nSentEnd = nLastCommaIndex;
+ Punction[nSentEnd] = PERIOD_INDEX;
+ }
+ RemainStr.assign(InputStr.begin() + nSentEnd + 1, InputStr.end());
+ RemainIDs.assign(InputIDs.begin() + nSentEnd + 1, InputIDs.end());
+ InputStr.assign(InputStr.begin(), InputStr.begin() + nSentEnd + 1); // minit_sentence
+ Punction.assign(Punction.begin(), Punction.begin() + nSentEnd + 1);
+ }
+
+ NewPunctuation.insert(NewPunctuation.end(), Punction.begin(), Punction.end());
+ vector<string> WordWithPunc;
+ for (int i = 0; i < InputStr.size(); i++)
+ {
+ if (i > 0 && !(InputStr[i][0] & 0x80) && (i + 1) <InputStr.size() && !(InputStr[i+1][0] & 0x80))// current and next tokens are both ASCII/English: insert a separating space
+ {
+ InputStr[i] = InputStr[i]+ " ";
+ }
+ WordWithPunc.push_back(InputStr[i]);
+
+ if (Punction[i] != NOTPUNC_INDEX) // a punctuation mark was predicted here: append it after the word
+ {
+ WordWithPunc.push_back(m_tokenizer.ID2Punc(Punction[i]));
+ }
+ }
+
+ NewString.insert(NewString.end(), WordWithPunc.begin(), WordWithPunc.end()); // new_mini_sentence += "".join(words_with_punc)
+ NewSentenceOut = NewString;
+ NewPuncOut = NewPunctuation;
+ // last mini sentence
+ if(nCurBatch == nTotalBatch - 1)
+ {
+ if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(COMMA_INDEX) || NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(DUN_INDEX))
+ {
+ NewSentenceOut.assign(NewString.begin(), NewString.end() - 1);
+ NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX));
+ NewPuncOut.assign(NewPunctuation.begin(), NewPunctuation.end() - 1);
+ NewPuncOut.push_back(PERIOD_INDEX);
+ }
+ else if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(PERIOD_INDEX) && NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(QUESTION_INDEX))
+ {
+ NewSentenceOut = NewString;
+ NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX));
+ NewPuncOut = NewPunctuation;
+ NewPuncOut.push_back(PERIOD_INDEX);
+ }
+ }
+ }
+ for (auto& item : NewSentenceOut)
+ strResult += item;
+ return strResult;
+}
+
+vector<int> CTTransformer::Infer(vector<int64_t> input_data)
+{
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+ vector<int> punction;
+ std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
+ Ort::Value onnx_input = Ort::Value::CreateTensor<int64_t>(m_memoryInfo,
+ input_data.data(),
+ input_data.size(),
+ input_shape_.data(),
+ input_shape_.size());
+
+ std::array<int32_t,1> text_lengths{ (int32_t)input_data.size() };
+ std::array<int64_t,1> text_lengths_dim{ 1 };
+ Ort::Value onnx_text_lengths = Ort::Value::CreateTensor(
+ m_memoryInfo,
+ text_lengths.data(),
+ text_lengths.size() * sizeof(int32_t),
+ text_lengths_dim.data(),
+ text_lengths_dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
+ std::vector<Ort::Value> input_onnx;
+ input_onnx.emplace_back(std::move(onnx_input));
+ input_onnx.emplace_back(std::move(onnx_text_lengths));
+
+ try {
+ auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
+ std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
+
+ int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
+ float * floatData = outputTensor[0].GetTensorMutableData<float>();
+
+ for (int i = 0; i < outputCount; i += CANDIDATE_NUM)
+ {
+ int index = argmax(floatData + i, floatData + i + CANDIDATE_NUM-1);
+ punction.push_back(index);
+ }
+ }
+ catch (std::exception const &e)
+ {
+ printf(e.what());
+ }
+ return punction;
+}
+
+
+
diff --git a/funasr/runtime/onnxruntime/src/punc_infer.h b/funasr/runtime/onnxruntime/src/punc_infer.h
new file mode 100644
index 0000000..e4ef0aa
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/punc_infer.h
@@ -0,0 +1,25 @@
+#pragma once
+
+class CTTransformer {
+/**
+ * Author: Speech Lab of DAMO Academy, Alibaba Group
+ * CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ * https://arxiv.org/pdf/2003.01309.pdf
+*/
+
+private:
+
+ CTokenizer m_tokenizer;
+ vector<const char*> m_szInputNames;
+ vector<const char*> m_szOutputNames;
+
+ std::shared_ptr<Ort::Session> m_session;
+ Ort::Env env_;
+ Ort::SessionOptions session_options;
+public:
+
+ CTTransformer(const char* sz_model_dir, int thread_num);
+ ~CTTransformer();
+ vector<int> Infer(vector<int64_t> input_data);
+ string AddPunc(const char* sz_input);
+};
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.cpp b/funasr/runtime/onnxruntime/src/tokenizer.cpp
new file mode 100644
index 0000000..324def7
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -0,0 +1,208 @@
+ #include "precomp.h"
+
+CTokenizer::CTokenizer(const char* szYmlFile):m_Ready(false)
+{
+ OpenYaml(szYmlFile);
+}
+
+CTokenizer::CTokenizer():m_Ready(false)
+{
+}
+
+void CTokenizer::read_yml(const YAML::Node& node)
+{
+ if (node.IsMap())
+ {// map node: recurse into each value
+ for (auto it = node.begin(); it != node.end(); ++it)
+ {
+ read_yml(it->second);
+ }
+ }
+ if (node.IsSequence()) {// sequence node: recurse into each element
+ for (size_t i = 0; i < node.size(); ++i) {
+ read_yml(node[i]);
+ }
+ }
+ if (node.IsScalar()) {// scalar (leaf) node
+ cout << node.as<string>() << endl;
+ }
+}
+
+bool CTokenizer::OpenYaml(const char* szYmlFile)
+{
+ YAML::Node m_Config = YAML::LoadFile(szYmlFile);
+ if (m_Config.IsNull())
+ return false;
+ try
+ {
+ auto Tokens = m_Config["token_list"];
+ if (Tokens.IsSequence())
+ {
+ for (size_t i = 0; i < Tokens.size(); ++i)
+ {
+ if (Tokens[i].IsScalar())
+ {
+ m_ID2Token.push_back(Tokens[i].as<string>());
+ m_Token2ID.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
+ }
+ }
+ }
+ auto Puncs = m_Config["punc_list"];
+ if (Puncs.IsSequence())
+ {
+ for (size_t i = 0; i < Puncs.size(); ++i)
+ {
+ if (Puncs[i].IsScalar())
+ {
+ m_ID2Punc.push_back(Puncs[i].as<string>());
+ m_Punc2ID.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
+ }
+ }
+ }
+ }
+ catch (YAML::BadFile& e) {
+ std::cout << "read error!" << std::endl;
+ return false;
+ }
+ m_Ready = true;
+ return m_Ready;
+}
+
+vector<string> CTokenizer::ID2String(vector<int> Input)
+{
+ vector<string> result;
+ for (auto& item : Input)
+ {
+ result.push_back(m_ID2Token[item]);
+ }
+ return result;
+}
+
+int CTokenizer::String2ID(string Input)
+{
+ int nID = 0; // <blank>
+ if (m_Token2ID.find(Input) != m_Token2ID.end())
+ nID=(m_Token2ID[Input]);
+ else
+ nID=(m_Token2ID[UNK_CHAR]);
+ return nID;
+}
+
+vector<int> CTokenizer::String2IDs(vector<string> Input)
+{
+ vector<int> result;
+ for (auto& item : Input)
+ {
+ transform(item.begin(), item.end(), item.begin(), ::tolower);
+ if (m_Token2ID.find(item) != m_Token2ID.end())
+ result.push_back(m_Token2ID[item]);
+ else
+ result.push_back(m_Token2ID[UNK_CHAR]);
+ }
+ return result;
+}
+
+vector<string> CTokenizer::ID2Punc(vector<int> Input)
+{
+ vector<string> result;
+ for (auto& item : Input)
+ {
+ result.push_back(m_ID2Punc[item]);
+ }
+ return result;
+}
+
+string CTokenizer::ID2Punc(int nPuncID)
+{
+ return m_ID2Punc[nPuncID];
+}
+
+vector<int> CTokenizer::Punc2IDs(vector<string> Input)
+{
+ vector<int> result;
+ for (auto& item : Input)
+ {
+ result.push_back(m_Punc2ID[item]);
+ }
+ return result;
+}
+
+vector<string> CTokenizer::SplitChineseString(const string & strInfo)
+{
+ vector<string> list;
+ int strSize = strInfo.size();
+ int i = 0;
+
+ while (i < strSize) {
+ int len = 1;
+ for (int j = 0; j < 6 && (strInfo[i] & (0x80 >> j)); j++) {
+ len = j + 1;
+ }
+ list.push_back(strInfo.substr(i, len));
+ i += len;
+ }
+ return list;
+}
+
+void CTokenizer::strSplit(const string& str, const char split, vector<string>& res)
+{
+ if (str == "")
+ {
+ return;
+ }
+ string&& strs = str + split;
+ size_t pos = strs.find(split);
+
+ while (pos != string::npos)
+ {
+ res.emplace_back(strs.substr(0, pos));
+ strs = move(strs.substr(pos + 1, strs.size()));
+ pos = strs.find(split);
+ }
+}
+
+ void CTokenizer::Tokenize(const char* strInfo, vector<string> & strOut, vector<int> & IDOut)
+{
+ vector<string> strList;
+ strSplit(strInfo,' ', strList);
+ string current_eng,current_chinese;
+ for (auto& item : strList)
+ {
+ current_eng = "";
+ current_chinese = "";
+ for (auto& ch : item)
+ {
+ if (!(ch& 0x80))
+ { // ASCII byte (English character)
+ if (current_chinese.size() > 0)
+ {
+ // for utf-8 chinese
+ auto chineseList = SplitChineseString(current_chinese);
+ strOut.insert(strOut.end(), chineseList.begin(),chineseList.end());
+ current_chinese = "";
+ }
+ current_eng += ch;
+ }
+ else
+ {
+ if (current_eng.size() > 0)
+ {
+ strOut.push_back(current_eng);
+ current_eng = "";
+ }
+ current_chinese += ch;
+ }
+ }
+ if (current_chinese.size() > 0)
+ {
+ auto chineseList = SplitChineseString(current_chinese);
+ strOut.insert(strOut.end(), chineseList.begin(), chineseList.end());
+ current_chinese = "";
+ }
+ if (current_eng.size() > 0)
+ {
+ strOut.push_back(current_eng);
+ }
+ }
+ IDOut= String2IDs(strOut);
+}
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.h b/funasr/runtime/onnxruntime/src/tokenizer.h
new file mode 100644
index 0000000..d8424a2
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/tokenizer.h
@@ -0,0 +1,27 @@
+#pragma once
+#include "yaml-cpp/yaml.h"
+
+class CTokenizer {
+private:
+
+ bool m_Ready = false;
+ vector<string> m_ID2Token,m_ID2Punc;
+ map<string, int> m_Token2ID,m_Punc2ID;
+
+public:
+
+ CTokenizer(const char* szYmlFile);
+ CTokenizer();
+ bool OpenYaml(const char* szYmlFile);
+ void read_yml(const YAML::Node& node);
+ vector<string> ID2String(vector<int> Input);
+ vector<int> String2IDs(vector<string> Input);
+ int String2ID(string Input);
+ vector<string> ID2Punc(vector<int> Input);
+ string ID2Punc(int nPuncID);
+ vector<int> Punc2IDs(vector<string> Input);
+ vector<string> SplitChineseString(const string& strInfo);
+ void strSplit(const string& str, const char split, vector<string>& res);
+ void Tokenize(const char* strInfo, vector<string>& strOut, vector<int>& IDOut);
+
+};
--
Gitblit v1.9.1