From 55708e7cebaedefc5f69d61f157993da41848b8f Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期日, 23 四月 2023 19:06:25 +0800
Subject: [PATCH] add offline punc for onnxruntime
---
funasr/runtime/onnxruntime/include/Model.h | 1
funasr/runtime/onnxruntime/src/punc_infer.h | 25 +
funasr/runtime/onnxruntime/src/tokenizer.cpp | 208 ++++++++++++++++
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp | 14
funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/online-feature.h | 142 ++++++++++
funasr/runtime/onnxruntime/src/commonfunc.h | 7
funasr/runtime/onnxruntime/src/punc_infer.cpp | 183 ++++++++++++++
funasr/runtime/onnxruntime/src/tokenizer.h | 27 ++
funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/feature-fbank.h | 134 ++++++++++
funasr/runtime/onnxruntime/src/libfunasrapi.cpp | 4
funasr/runtime/onnxruntime/include/ComDefine.h | 21 +
funasr/runtime/onnxruntime/src/paraformer_onnx.h | 3
funasr/runtime/onnxruntime/src/precomp.h | 2
13 files changed, 767 insertions(+), 4 deletions(-)
diff --git a/funasr/runtime/onnxruntime/include/ComDefine.h b/funasr/runtime/onnxruntime/include/ComDefine.h
index 6929e49..72a843d 100644
--- a/funasr/runtime/onnxruntime/include/ComDefine.h
+++ b/funasr/runtime/onnxruntime/include/ComDefine.h
@@ -12,6 +12,7 @@
#define MODEL_SAMPLE_RATE 16000
#endif
+// vad
#ifndef VAD_SILENCE_DYRATION
#define VAD_SILENCE_DYRATION 15000
#endif
@@ -24,5 +25,25 @@
#define VAD_SPEECH_NOISE_THRES 0.9
#endif
+// punc
+#define PUNC_MODEL_FILE "punc_model.onnx"
+#define PUNC_YAML_FILE "punc.yaml"
+
+#define UNK_CHAR "<unk>"
+
+#define INPUT_NUM 2
+#define INPUT_NAME1 "input"
+#define INPUT_NAME2 "text_lengths"
+#define OUTPUT_NAME "logits"
+#define TOKEN_LEN 20
+
+#define CANDIDATE_NUM 6
+#define UNKNOW_INDEX 0
+#define NOTPUNC_INDEX 1
+#define COMMA_INDEX 2
+#define PERIOD_INDEX 3
+#define QUESTION_INDEX 4
+#define DUN_INDEX 5
+#define CACHE_POP_TRIGGER_LIMIT 200
#endif
diff --git a/funasr/runtime/onnxruntime/include/Model.h b/funasr/runtime/onnxruntime/include/Model.h
index cd3b0a3..f92789f 100644
--- a/funasr/runtime/onnxruntime/include/Model.h
+++ b/funasr/runtime/onnxruntime/include/Model.h
@@ -12,6 +12,7 @@
virtual std::string forward(float *din, int len, int flag) = 0;
virtual std::string rescoring() = 0;
virtual std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data)=0;
+ virtual std::string AddPunc(const char* szInput)=0;
};
Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false);
diff --git a/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/feature-fbank.h b/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/feature-fbank.h
new file mode 100644
index 0000000..0786aad
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/feature-fbank.h
@@ -0,0 +1,134 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This file is copied/modified from kaldi/src/feat/feature-fbank.h
+
+#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
+#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "kaldi-native-fbank/csrc/feature-window.h"
+#include "kaldi-native-fbank/csrc/mel-computations.h"
+#include "kaldi-native-fbank/csrc/rfft.h"
+
+namespace knf {
+
+struct FbankOptions {
+ FrameExtractionOptions frame_opts;
+ MelBanksOptions mel_opts;
+ // append an extra dimension with energy to the filter banks
+ bool use_energy = false;
+ float energy_floor = 0.0f; // active iff use_energy==true
+
+ // If true, compute log_energy before preemphasis and windowing
+ // If false, compute log_energy after preemphasis ans windowing
+ bool raw_energy = true; // active iff use_energy==true
+
+ // If true, put energy last (if using energy)
+ // If false, put energy first
+ bool htk_compat = false; // active iff use_energy==true
+
+ // if true (default), produce log-filterbank, else linear
+ bool use_log_fbank = true;
+
+ // if true (default), use power in filterbank
+ // analysis, else magnitude.
+ bool use_power = true;
+
+ FbankOptions() { mel_opts.num_bins = 23; }
+
+ std::string ToString() const {
+ std::ostringstream os;
+ os << "frame_opts: \n";
+ os << frame_opts << "\n";
+ os << "\n";
+
+ os << "mel_opts: \n";
+ os << mel_opts << "\n";
+
+ os << "use_energy: " << use_energy << "\n";
+ os << "energy_floor: " << energy_floor << "\n";
+ os << "raw_energy: " << raw_energy << "\n";
+ os << "htk_compat: " << htk_compat << "\n";
+ os << "use_log_fbank: " << use_log_fbank << "\n";
+ os << "use_power: " << use_power << "\n";
+ return os.str();
+ }
+};
+
+std::ostream &operator<<(std::ostream &os, const FbankOptions &opts);
+
+class FbankComputer {
+ public:
+ using Options = FbankOptions;
+
+ explicit FbankComputer(const FbankOptions &opts);
+ ~FbankComputer();
+
+ int32_t Dim() const {
+ return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
+ }
+
+ // if true, compute log_energy_pre_window but after dithering and dc removal
+ bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
+
+ const FrameExtractionOptions &GetFrameOptions() const {
+ return opts_.frame_opts;
+ }
+
+ const FbankOptions &GetOptions() const { return opts_; }
+
+ /**
+ Function that computes one frame of features from
+ one frame of signal.
+
+ @param [in] signal_raw_log_energy The log-energy of the frame of the signal
+ prior to windowing and pre-emphasis, or
+ log(numeric_limits<float>::min()), whichever is greater. Must be
+ ignored by this function if this class returns false from
+ this->NeedsRawLogEnergy().
+ @param [in] vtln_warp The VTLN warping factor that the user wants
+ to be applied when computing features for this utterance. Will
+ normally be 1.0, meaning no warping is to be done. The value will
+ be ignored for feature types that don't support VLTN, such as
+ spectrogram features.
+ @param [in] signal_frame One frame of the signal,
+ as extracted using the function ExtractWindow() using the options
+ returned by this->GetFrameOptions(). The function will use the
+ vector as a workspace, which is why it's a non-const pointer.
+ @param [out] feature Pointer to a vector of size this->Dim(), to which
+ the computed feature will be written. It should be pre-allocated.
+ */
+ void Compute(float signal_raw_log_energy, float vtln_warp,
+ std::vector<float> *signal_frame, float *feature);
+
+ private:
+ const MelBanks *GetMelBanks(float vtln_warp);
+
+ FbankOptions opts_;
+ float log_energy_floor_;
+ std::map<float, MelBanks *> mel_banks_; // float is VTLN coefficient.
+ Rfft rfft_;
+};
+
+} // namespace knf
+
+#endif // KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
diff --git a/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/online-feature.h b/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/online-feature.h
new file mode 100644
index 0000000..5ca5511
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/online-feature.h
@@ -0,0 +1,142 @@
+/**
+ * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// The content in this file is copied/modified from
+// This file is copied/modified from kaldi/src/feat/online-feature.h
+#ifndef KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_
+#define KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_
+
+#include <cstdint>
+#include <deque>
+#include <vector>
+
+#include "kaldi-native-fbank/csrc/feature-fbank.h"
+
+namespace knf {
+
+/// This class serves as a storage for feature vectors with an option to limit
+/// the memory usage by removing old elements. The deleted frames indices are
+/// "remembered" so that regardless of the MAX_ITEMS setting, the user always
+/// provides the indices as if no deletion was being performed.
+/// This is useful when processing very long recordings which would otherwise
+/// cause the memory to eventually blow up when the features are not being
+/// removed.
+class RecyclingVector {
+ public:
+ /// By default it does not remove any elements.
+ explicit RecyclingVector(int32_t items_to_hold = -1);
+
+ ~RecyclingVector() = default;
+ RecyclingVector(const RecyclingVector &) = delete;
+ RecyclingVector &operator=(const RecyclingVector &) = delete;
+
+ // The pointer is owned by RecyclingVector
+ // Users should not free it
+ const float *At(int32_t index) const;
+
+ void PushBack(std::vector<float> item);
+
+ /// This method returns the size as if no "recycling" had happened,
+ /// i.e. equivalent to the number of times the PushBack method has been
+ /// called.
+ int32_t Size() const;
+
+ private:
+ std::deque<std::vector<float>> items_;
+ int32_t items_to_hold_;
+ int32_t first_available_index_;
+};
+
+/// This is a templated class for online feature extraction;
+/// it's templated on a class like MfccComputer or PlpComputer
+/// that does the basic feature extraction.
+template <class C>
+class OnlineGenericBaseFeature {
+ public:
+ // Constructor from options class
+ explicit OnlineGenericBaseFeature(const typename C::Options &opts);
+
+ int32_t Dim() const { return computer_.Dim(); }
+
+ float FrameShiftInSeconds() const {
+ return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
+ }
+
+ int32_t NumFramesReady() const { return features_.Size(); }
+
+ // Note: IsLastFrame() will only ever return true if you have called
+ // InputFinished() (and this frame is the last frame).
+ bool IsLastFrame(int32_t frame) const {
+ return input_finished_ && frame == NumFramesReady() - 1;
+ }
+
+ const float *GetFrame(int32_t frame) const { return features_.At(frame); }
+
+ // This would be called from the application, when you get
+ // more wave data. Note: the sampling_rate is only provided so
+ // the code can assert that it matches the sampling rate
+ // expected in the options.
+ //
+ // @param sampling_rate The sampling_rate of the input waveform
+ // @param waveform Pointer to a 1-D array of size n
+ // @param n Number of entries in waveform
+ void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
+
+ // InputFinished() tells the class you won't be providing any
+ // more waveform. This will help flush out the last frame or two
+ // of features, in the case where snip-edges == false; it also
+ // affects the return value of IsLastFrame().
+ void InputFinished();
+
+ private:
+ // This function computes any additional feature frames that it is possible to
+ // compute from 'waveform_remainder_', which at this point may contain more
+ // than just a remainder-sized quantity (because AcceptWaveform() appends to
+ // waveform_remainder_ before calling this function). It adds these feature
+ // frames to features_, and shifts off any now-unneeded samples of input from
+ // waveform_remainder_ while incrementing waveform_offset_ by the same amount.
+ void ComputeFeatures();
+
+ C computer_; // class that does the MFCC or PLP or filterbank computation
+
+ FeatureWindowFunction window_function_;
+
+ // features_ is the Mfcc or Plp or Fbank features that we have already
+ // computed.
+
+ RecyclingVector features_;
+
+ // True if the user has called "InputFinished()"
+ bool input_finished_;
+
+ // waveform_offset_ is the number of samples of waveform that we have
+ // already discarded, i.e. that were prior to 'waveform_remainder_'.
+ int64_t waveform_offset_;
+
+ // waveform_remainder_ is a short piece of waveform that we may need to keep
+ // after extracting all the whole frames we can (whatever length of feature
+ // will be required for the next phase of computation).
+ // It is a 1-D tensor
+ std::vector<float> waveform_remainder_;
+};
+
+using OnlineFbank = OnlineGenericBaseFeature<FbankComputer>;
+
+} // namespace knf
+
+#endif // KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index 8d1a97c..cae1bd7 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -1,5 +1,5 @@
#pragma once
-
+#include <algorithm>
typedef struct
{
std::string msg;
@@ -49,3 +49,8 @@
}
}
}
+
+template <class ForwardIterator>
+inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
+ return std::distance(first, std::max_element(first, last));
+}
diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
index f15e86f..0adef89 100644
--- a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
+++ b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -133,6 +133,10 @@
if (fnCallback)
fnCallback(nStep, nTotal);
}
+ if(true){
+ string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+ pResult->msg = punc_res;
+ }
return pResult;
}
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 1a86da6..69d1554 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -17,6 +17,11 @@
vadHandle->init_vad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
}
+ // PUNC model
+ if(true){
+ puncHandle = make_unique<CTTransformer>(path, nNumThread);
+ }
+
if(quantize)
{
model_path = pathAppend(path, "model_quant.onnx");
@@ -50,6 +55,7 @@
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
+ vector<string> m_strInputNames, m_strOutputNames;
string strName;
getInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
@@ -81,6 +87,10 @@
vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){
return vadHandle->infer(pcm_data);
+}
+
+string ModelImp::AddPunc(const char* szInput){
+ return puncHandle->AddPunc(szInput);
}
vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) {
@@ -231,9 +241,9 @@
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
result = greedy_search(floatData, *encoder_out_lens);
}
- catch (...)
+ catch (std::exception const &e)
{
- result = "";
+ printf(e.what());
}
return result;
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.h b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
index b0712b4..cde2937 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.h
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -14,6 +14,7 @@
knf::FbankOptions fbank_opts;
std::unique_ptr<FsmnVad> vadHandle;
+ std::unique_ptr<CTTransformer> puncHandle;
Vocab* vocab;
vector<float> means_list;
@@ -32,7 +33,6 @@
Ort::Env env_;
Ort::SessionOptions sessionOptions;
- vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
@@ -45,6 +45,7 @@
string forward(float* din, int len, int flag);
string rescoring();
std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data);
+ string AddPunc(const char* szInput);
};
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index 74b6be3..40d8928 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -28,6 +28,8 @@
#include "ComDefine.h"
#include "commonfunc.h"
#include "predefine_coe.h"
+#include "tokenizer.h"
+#include "punc_infer.h"
#include "FsmnVad.h"
#include "e2e_vad.h"
#include "Vocab.h"
diff --git a/funasr/runtime/onnxruntime/src/punc_infer.cpp b/funasr/runtime/onnxruntime/src/punc_infer.cpp
new file mode 100644
index 0000000..8dbb49d
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/punc_infer.cpp
@@ -0,0 +1,183 @@
+#include "precomp.h"
+
+CTTransformer::CTTransformer(const char* sz_model_dir, int thread_num)
+:env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{}
+{
+ session_options.SetIntraOpNumThreads(thread_num);
+ session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
+ session_options.DisableCpuMemArena();
+
+ string strModelPath = pathAppend(sz_model_dir, PUNC_MODEL_FILE);
+ string strYamlPath = pathAppend(sz_model_dir, PUNC_YAML_FILE);
+
+#ifdef _WIN32
+ std::wstring detPath = strToWstr(strModelPath);
+ m_session = std::make_unique<Ort::Session>(env_, detPath.c_str(), session_options);
+#else
+ m_session = std::make_unique<Ort::Session>(env_, strModelPath.c_str(), session_options);
+#endif
+ // read inputnames outputnames
+ vector<string> m_strInputNames, m_strOutputNames;
+ string strName;
+ getInputName(m_session.get(), strName);
+ m_strInputNames.push_back(strName.c_str());
+ getInputName(m_session.get(), strName, 1);
+ m_strInputNames.push_back(strName);
+
+ getOutputName(m_session.get(), strName);
+ m_strOutputNames.push_back(strName);
+
+ for (auto& item : m_strInputNames)
+ m_szInputNames.push_back(item.c_str());
+ for (auto& item : m_strOutputNames)
+ m_szOutputNames.push_back(item.c_str());
+
+ m_tokenizer.OpenYaml(strYamlPath.c_str());
+}
+
+CTTransformer::~CTTransformer()
+{
+}
+
+string CTTransformer::AddPunc(const char* sz_input)
+{
+ string strResult;
+ vector<string> strOut;
+ vector<int> InputData;
+ m_tokenizer.Tokenize(sz_input, strOut, InputData);
+
+ int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
+ int nCurBatch = -1;
+ int nSentEnd = -1, nLastCommaIndex = -1;
+ vector<int64_t> RemainIDs; //
+ vector<string> RemainStr; //
+ vector<int> NewPunctuation; //
+ vector<string> NewString; //
+ vector<string> NewSentenceOut;
+ vector<int> NewPuncOut;
+ int nDiff = 0;
+ for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
+ {
+ nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
+ vector<int64_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
+ vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
+ InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs;
+ InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr;
+
+ auto Punction = Infer(InputIDs);
+ nCurBatch = i / TOKEN_LEN;
+ if (nCurBatch < nTotalBatch - 1) // not the last minisetence
+ {
+ nSentEnd = -1;
+ nLastCommaIndex = -1;
+ for (int nIndex = Punction.size() - 2; nIndex > 0; nIndex--)
+ {
+ if (m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(PERIOD_INDEX) || m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(QUESTION_INDEX))
+ {
+ nSentEnd = nIndex;
+ break;
+ }
+ if (nLastCommaIndex < 0 && m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(COMMA_INDEX))
+ {
+ nLastCommaIndex = nIndex;
+ }
+ }
+ if (nSentEnd < 0 && InputStr.size() > CACHE_POP_TRIGGER_LIMIT && nLastCommaIndex > 0)
+ {
+ nSentEnd = nLastCommaIndex;
+ Punction[nSentEnd] = PERIOD_INDEX;
+ }
+ RemainStr.assign(InputStr.begin() + nSentEnd + 1, InputStr.end());
+ RemainIDs.assign(InputIDs.begin() + nSentEnd + 1, InputIDs.end());
+ InputStr.assign(InputStr.begin(), InputStr.begin() + nSentEnd + 1); // minit_sentence
+ Punction.assign(Punction.begin(), Punction.begin() + nSentEnd + 1);
+ }
+
+ NewPunctuation.insert(NewPunctuation.end(), Punction.begin(), Punction.end());
+ vector<string> WordWithPunc;
+ for (int i = 0; i < InputStr.size(); i++)
+ {
+ if (i > 0 && !(InputStr[i][0] & 0x80) && (i + 1) <InputStr.size() && !(InputStr[i+1][0] & 0x80))// current and next tokens are both ASCII/English: insert a separating space
+ {
+ InputStr[i] = InputStr[i]+ " ";
+ }
+ WordWithPunc.push_back(InputStr[i]);
+
+ if (Punction[i] != NOTPUNC_INDEX) // a punctuation mark was predicted here: append it after the word
+ {
+ WordWithPunc.push_back(m_tokenizer.ID2Punc(Punction[i]));
+ }
+ }
+
+ NewString.insert(NewString.end(), WordWithPunc.begin(), WordWithPunc.end()); // new_mini_sentence += "".join(words_with_punc)
+ NewSentenceOut = NewString;
+ NewPuncOut = NewPunctuation;
+ // last mini sentence
+ if(nCurBatch == nTotalBatch - 1)
+ {
+ if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(COMMA_INDEX) || NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(DUN_INDEX))
+ {
+ NewSentenceOut.assign(NewString.begin(), NewString.end() - 1);
+ NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX));
+ NewPuncOut.assign(NewPunctuation.begin(), NewPunctuation.end() - 1);
+ NewPuncOut.push_back(PERIOD_INDEX);
+ }
+ else if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(PERIOD_INDEX) && NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(QUESTION_INDEX))
+ {
+ NewSentenceOut = NewString;
+ NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX));
+ NewPuncOut = NewPunctuation;
+ NewPuncOut.push_back(PERIOD_INDEX);
+ }
+ }
+ }
+ for (auto& item : NewSentenceOut)
+ strResult += item;
+ return strResult;
+}
+
+vector<int> CTTransformer::Infer(vector<int64_t> input_data)
+{
+ Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+ vector<int> punction;
+ std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
+ Ort::Value onnx_input = Ort::Value::CreateTensor<int64_t>(m_memoryInfo,
+ input_data.data(),
+ input_data.size(),
+ input_shape_.data(),
+ input_shape_.size());
+
+ std::array<int32_t,1> text_lengths{ (int32_t)input_data.size() };
+ std::array<int64_t,1> text_lengths_dim{ 1 };
+ Ort::Value onnx_text_lengths = Ort::Value::CreateTensor(
+ m_memoryInfo,
+ text_lengths.data(),
+ text_lengths.size() * sizeof(int32_t),
+ text_lengths_dim.data(),
+ text_lengths_dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
+ std::vector<Ort::Value> input_onnx;
+ input_onnx.emplace_back(std::move(onnx_input));
+ input_onnx.emplace_back(std::move(onnx_text_lengths));
+
+ try {
+ auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
+ std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
+
+ int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
+ float * floatData = outputTensor[0].GetTensorMutableData<float>();
+
+ for (int i = 0; i < outputCount; i += CANDIDATE_NUM)
+ {
+ int index = argmax(floatData + i, floatData + i + CANDIDATE_NUM-1);
+ punction.push_back(index);
+ }
+ }
+ catch (std::exception const &e)
+ {
+ printf(e.what());
+ }
+ return punction;
+}
+
+
+
diff --git a/funasr/runtime/onnxruntime/src/punc_infer.h b/funasr/runtime/onnxruntime/src/punc_infer.h
new file mode 100644
index 0000000..e4ef0aa
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/punc_infer.h
@@ -0,0 +1,25 @@
+#pragma once
+
+class CTTransformer {
+/**
+ * Author: Speech Lab of DAMO Academy, Alibaba Group
+ * CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+ * https://arxiv.org/pdf/2003.01309.pdf
+*/
+
+private:
+
+ CTokenizer m_tokenizer;
+ vector<const char*> m_szInputNames;
+ vector<const char*> m_szOutputNames;
+
+ std::shared_ptr<Ort::Session> m_session;
+ Ort::Env env_;
+ Ort::SessionOptions session_options;
+public:
+
+ CTTransformer(const char* sz_model_dir, int thread_num);
+ ~CTTransformer();
+ vector<int> Infer(vector<int64_t> input_data);
+ string AddPunc(const char* sz_input);
+};
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.cpp b/funasr/runtime/onnxruntime/src/tokenizer.cpp
new file mode 100644
index 0000000..324def7
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -0,0 +1,208 @@
+ #include "precomp.h"
+
+CTokenizer::CTokenizer(const char* szYmlFile):m_Ready(false)
+{
+ OpenYaml(szYmlFile);
+}
+
+CTokenizer::CTokenizer():m_Ready(false)
+{
+}
+
+void CTokenizer::read_yml(const YAML::Node& node)
+{
+ if (node.IsMap())
+ {// map node: recurse into each value
+ for (auto it = node.begin(); it != node.end(); ++it)
+ {
+ read_yml(it->second);
+ }
+ }
+ if (node.IsSequence()) {// sequence node: recurse into each element
+ for (size_t i = 0; i < node.size(); ++i) {
+ read_yml(node[i]);
+ }
+ }
+ if (node.IsScalar()) {// scalar (leaf) node
+ cout << node.as<string>() << endl;
+ }
+}
+
+bool CTokenizer::OpenYaml(const char* szYmlFile)
+{
+ YAML::Node m_Config = YAML::LoadFile(szYmlFile);
+ if (m_Config.IsNull())
+ return false;
+ try
+ {
+ auto Tokens = m_Config["token_list"];
+ if (Tokens.IsSequence())
+ {
+ for (size_t i = 0; i < Tokens.size(); ++i)
+ {
+ if (Tokens[i].IsScalar())
+ {
+ m_ID2Token.push_back(Tokens[i].as<string>());
+ m_Token2ID.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
+ }
+ }
+ }
+ auto Puncs = m_Config["punc_list"];
+ if (Puncs.IsSequence())
+ {
+ for (size_t i = 0; i < Puncs.size(); ++i)
+ {
+ if (Puncs[i].IsScalar())
+ {
+ m_ID2Punc.push_back(Puncs[i].as<string>());
+ m_Punc2ID.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
+ }
+ }
+ }
+ }
+ catch (YAML::BadFile& e) {
+ std::cout << "read error!" << std::endl;
+ return false;
+ }
+ m_Ready = true;
+ return m_Ready;
+}
+
+vector<string> CTokenizer::ID2String(vector<int> Input)
+{
+ vector<string> result;
+ for (auto& item : Input)
+ {
+ result.push_back(m_ID2Token[item]);
+ }
+ return result;
+}
+
+int CTokenizer::String2ID(string Input)
+{
+ int nID = 0; // <blank>
+ if (m_Token2ID.find(Input) != m_Token2ID.end())
+ nID=(m_Token2ID[Input]);
+ else
+ nID=(m_Token2ID[UNK_CHAR]);
+ return nID;
+}
+
+vector<int> CTokenizer::String2IDs(vector<string> Input)
+{
+ vector<int> result;
+ for (auto& item : Input)
+ {
+ transform(item.begin(), item.end(), item.begin(), ::tolower);
+ if (m_Token2ID.find(item) != m_Token2ID.end())
+ result.push_back(m_Token2ID[item]);
+ else
+ result.push_back(m_Token2ID[UNK_CHAR]);
+ }
+ return result;
+}
+
+vector<string> CTokenizer::ID2Punc(vector<int> Input)
+{
+ vector<string> result;
+ for (auto& item : Input)
+ {
+ result.push_back(m_ID2Punc[item]);
+ }
+ return result;
+}
+
+string CTokenizer::ID2Punc(int nPuncID)
+{
+ return m_ID2Punc[nPuncID];
+}
+
+vector<int> CTokenizer::Punc2IDs(vector<string> Input)
+{
+ vector<int> result;
+ for (auto& item : Input)
+ {
+ result.push_back(m_Punc2ID[item]);
+ }
+ return result;
+}
+
+vector<string> CTokenizer::SplitChineseString(const string & strInfo)
+{
+ vector<string> list;
+ int strSize = strInfo.size();
+ int i = 0;
+
+ while (i < strSize) {
+ int len = 1;
+ for (int j = 0; j < 6 && (strInfo[i] & (0x80 >> j)); j++) {
+ len = j + 1;
+ }
+ list.push_back(strInfo.substr(i, len));
+ i += len;
+ }
+ return list;
+}
+
+void CTokenizer::strSplit(const string& str, const char split, vector<string>& res)
+{
+ if (str == "")
+ {
+ return;
+ }
+ string&& strs = str + split;
+ size_t pos = strs.find(split);
+
+ while (pos != string::npos)
+ {
+ res.emplace_back(strs.substr(0, pos));
+ strs = move(strs.substr(pos + 1, strs.size()));
+ pos = strs.find(split);
+ }
+}
+
+ void CTokenizer::Tokenize(const char* strInfo, vector<string> & strOut, vector<int> & IDOut)
+{
+ vector<string> strList;
+ strSplit(strInfo,' ', strList);
+ string current_eng,current_chinese;
+ for (auto& item : strList)
+ {
+ current_eng = "";
+ current_chinese = "";
+ for (auto& ch : item)
+ {
+ if (!(ch& 0x80))
+ { // ASCII byte (English character)
+ if (current_chinese.size() > 0)
+ {
+ // for utf-8 chinese
+ auto chineseList = SplitChineseString(current_chinese);
+ strOut.insert(strOut.end(), chineseList.begin(),chineseList.end());
+ current_chinese = "";
+ }
+ current_eng += ch;
+ }
+ else
+ {
+ if (current_eng.size() > 0)
+ {
+ strOut.push_back(current_eng);
+ current_eng = "";
+ }
+ current_chinese += ch;
+ }
+ }
+ if (current_chinese.size() > 0)
+ {
+ auto chineseList = SplitChineseString(current_chinese);
+ strOut.insert(strOut.end(), chineseList.begin(), chineseList.end());
+ current_chinese = "";
+ }
+ if (current_eng.size() > 0)
+ {
+ strOut.push_back(current_eng);
+ }
+ }
+ IDOut= String2IDs(strOut);
+}
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.h b/funasr/runtime/onnxruntime/src/tokenizer.h
new file mode 100644
index 0000000..d8424a2
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/tokenizer.h
@@ -0,0 +1,27 @@
+#pragma once
+#include "yaml-cpp/yaml.h"
+
+class CTokenizer {
+private:
+
+ bool m_Ready = false;
+ vector<string> m_ID2Token,m_ID2Punc;
+ map<string, int> m_Token2ID,m_Punc2ID;
+
+public:
+
+ CTokenizer(const char* szYmlFile);
+ CTokenizer();
+ bool OpenYaml(const char* szYmlFile);
+ void read_yml(const YAML::Node& node);
+ vector<string> ID2String(vector<int> Input);
+ vector<int> String2IDs(vector<string> Input);
+ int String2ID(string Input);
+ vector<string> ID2Punc(vector<int> Input);
+ string ID2Punc(int nPuncID);
+ vector<int> Punc2IDs(vector<string> Input);
+ vector<string> SplitChineseString(const string& strInfo);
+ void strSplit(const string& str, const char split, vector<string>& res);
+ void Tokenize(const char* strInfo, vector<string>& strOut, vector<int>& IDOut);
+
+};
--
Gitblit v1.9.1