lyblsgo
2023-04-23 55708e7cebaedefc5f69d61f157993da41848b8f
add offline punc for onnxruntime
7个文件已修改
6个文件已添加
771 ■■■■■ 已修改文件
funasr/runtime/onnxruntime/include/ComDefine.h 21 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/Model.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/feature-fbank.h 134 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/online-feature.h 142 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/commonfunc.h 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/libfunasrapi.cpp 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer_onnx.h 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/precomp.h 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/punc_infer.cpp 183 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/punc_infer.h 25 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/tokenizer.cpp 208 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/tokenizer.h 27 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/ComDefine.h
@@ -12,6 +12,7 @@
#define MODEL_SAMPLE_RATE 16000
#endif
// vad
#ifndef VAD_SILENCE_DYRATION
#define VAD_SILENCE_DYRATION 15000
#endif
@@ -24,5 +25,25 @@
#define VAD_SPEECH_NOISE_THRES 0.9
#endif
// punc
#define PUNC_MODEL_FILE  "punc_model.onnx"
#define PUNC_YAML_FILE "punc.yaml"
#define UNK_CHAR "<unk>"
#define  INPUT_NUM  2
#define  INPUT_NAME1 "input"
#define  INPUT_NAME2 "text_lengths"
#define  OUTPUT_NAME "logits"
#define  TOKEN_LEN     20
#define  CANDIDATE_NUM   6
#define UNKNOW_INDEX 0
#define NOTPUNC_INDEX 1
#define COMMA_INDEX 2
#define PERIOD_INDEX 3
#define QUESTION_INDEX 4
#define DUN_INDEX 5
#define  CACHE_POP_TRIGGER_LIMIT   200
#endif
funasr/runtime/onnxruntime/include/Model.h
@@ -12,6 +12,7 @@
    virtual std::string forward(float *din, int len, int flag) = 0;
    virtual std::string rescoring() = 0;
    virtual std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data)=0;
    virtual std::string AddPunc(const char* szInput)=0;
};
Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false);
funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/feature-fbank.h
New file
@@ -0,0 +1,134 @@
/**
 * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// This file is copied/modified from kaldi/src/feat/feature-fbank.h
#ifndef KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
#define KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
#include <map>
#include <string>
#include <vector>
#include "kaldi-native-fbank/csrc/feature-window.h"
#include "kaldi-native-fbank/csrc/mel-computations.h"
#include "kaldi-native-fbank/csrc/rfft.h"
namespace knf {
struct FbankOptions {
  FrameExtractionOptions frame_opts;
  MelBanksOptions mel_opts;
  // append an extra dimension with energy to the filter banks
  bool use_energy = false;
  float energy_floor = 0.0f;  // active iff use_energy==true
  // If true, compute log_energy before preemphasis and windowing
  // If false, compute log_energy after preemphasis ans windowing
  bool raw_energy = true;  // active iff use_energy==true
  // If true, put energy last (if using energy)
  // If false, put energy first
  bool htk_compat = false;  // active iff use_energy==true
  // if true (default), produce log-filterbank, else linear
  bool use_log_fbank = true;
  // if true (default), use power in filterbank
  // analysis, else magnitude.
  bool use_power = true;
  FbankOptions() { mel_opts.num_bins = 23; }
  std::string ToString() const {
    std::ostringstream os;
    os << "frame_opts: \n";
    os << frame_opts << "\n";
    os << "\n";
    os << "mel_opts: \n";
    os << mel_opts << "\n";
    os << "use_energy: " << use_energy << "\n";
    os << "energy_floor: " << energy_floor << "\n";
    os << "raw_energy: " << raw_energy << "\n";
    os << "htk_compat: " << htk_compat << "\n";
    os << "use_log_fbank: " << use_log_fbank << "\n";
    os << "use_power: " << use_power << "\n";
    return os.str();
  }
};
std::ostream &operator<<(std::ostream &os, const FbankOptions &opts);
class FbankComputer {
 public:
  using Options = FbankOptions;
  explicit FbankComputer(const FbankOptions &opts);
  ~FbankComputer();
  int32_t Dim() const {
    return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
  }
  // if true, compute log_energy_pre_window but after dithering and dc removal
  bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
  const FrameExtractionOptions &GetFrameOptions() const {
    return opts_.frame_opts;
  }
  const FbankOptions &GetOptions() const { return opts_; }
  /**
     Function that computes one frame of features from
     one frame of signal.
     @param [in] signal_raw_log_energy The log-energy of the frame of the signal
         prior to windowing and pre-emphasis, or
         log(numeric_limits<float>::min()), whichever is greater.  Must be
         ignored by this function if this class returns false from
         this->NeedsRawLogEnergy().
     @param [in] vtln_warp  The VTLN warping factor that the user wants
         to be applied when computing features for this utterance.  Will
         normally be 1.0, meaning no warping is to be done.  The value will
         be ignored for feature types that don't support VLTN, such as
         spectrogram features.
     @param [in] signal_frame  One frame of the signal,
       as extracted using the function ExtractWindow() using the options
       returned by this->GetFrameOptions().  The function will use the
       vector as a workspace, which is why it's a non-const pointer.
     @param [out] feature  Pointer to a vector of size this->Dim(), to which
         the computed feature will be written. It should be pre-allocated.
  */
  void Compute(float signal_raw_log_energy, float vtln_warp,
               std::vector<float> *signal_frame, float *feature);
 private:
  const MelBanks *GetMelBanks(float vtln_warp);
  FbankOptions opts_;
  float log_energy_floor_;
  std::map<float, MelBanks *> mel_banks_;  // float is VTLN coefficient.
  Rfft rfft_;
};
}  // namespace knf
#endif  // KALDI_NATIVE_FBANK_CSRC_FEATURE_FBANK_H_
funasr/runtime/onnxruntime/include/kaldi-native-fbank/csrc/online-feature.h
New file
@@ -0,0 +1,142 @@
/**
 * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// The content in this file is copied/modified from
// This file is copied/modified from kaldi/src/feat/online-feature.h
#ifndef KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_
#define KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_
#include <cstdint>
#include <deque>
#include <vector>
#include "kaldi-native-fbank/csrc/feature-fbank.h"
namespace knf {
/// This class serves as a storage for feature vectors with an option to limit
/// the memory usage by removing old elements. The deleted frames indices are
/// "remembered" so that regardless of the MAX_ITEMS setting, the user always
/// provides the indices as if no deletion was being performed.
/// This is useful when processing very long recordings which would otherwise
/// cause the memory to eventually blow up when the features are not being
/// removed.
class RecyclingVector {
 public:
  /// By default it does not remove any elements.
  explicit RecyclingVector(int32_t items_to_hold = -1);
  ~RecyclingVector() = default;
  RecyclingVector(const RecyclingVector &) = delete;
  RecyclingVector &operator=(const RecyclingVector &) = delete;
  // The pointer is owned by RecyclingVector
  // Users should not free it
  const float *At(int32_t index) const;
  void PushBack(std::vector<float> item);
  /// This method returns the size as if no "recycling" had happened,
  /// i.e. equivalent to the number of times the PushBack method has been
  /// called.
  int32_t Size() const;
 private:
  std::deque<std::vector<float>> items_;
  int32_t items_to_hold_;
  int32_t first_available_index_;
};
/// This is a templated class for online feature extraction;
/// it's templated on a class like MfccComputer or PlpComputer
/// that does the basic feature extraction.
template <class C>
class OnlineGenericBaseFeature {
 public:
  // Constructor from options class
  explicit OnlineGenericBaseFeature(const typename C::Options &opts);
  int32_t Dim() const { return computer_.Dim(); }
  float FrameShiftInSeconds() const {
    return computer_.GetFrameOptions().frame_shift_ms / 1000.0f;
  }
  int32_t NumFramesReady() const { return features_.Size(); }
  // Note: IsLastFrame() will only ever return true if you have called
  // InputFinished() (and this frame is the last frame).
  bool IsLastFrame(int32_t frame) const {
    return input_finished_ && frame == NumFramesReady() - 1;
  }
  const float *GetFrame(int32_t frame) const { return features_.At(frame); }
  // This would be called from the application, when you get
  // more wave data.  Note: the sampling_rate is only provided so
  // the code can assert that it matches the sampling rate
  // expected in the options.
  //
  // @param sampling_rate The sampling_rate of the input waveform
  // @param waveform Pointer to a 1-D array of size n
  // @param n Number of entries in waveform
  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
  // InputFinished() tells the class you won't be providing any
  // more waveform.  This will help flush out the last frame or two
  // of features, in the case where snip-edges == false; it also
  // affects the return value of IsLastFrame().
  void InputFinished();
 private:
  // This function computes any additional feature frames that it is possible to
  // compute from 'waveform_remainder_', which at this point may contain more
  // than just a remainder-sized quantity (because AcceptWaveform() appends to
  // waveform_remainder_ before calling this function).  It adds these feature
  // frames to features_, and shifts off any now-unneeded samples of input from
  // waveform_remainder_ while incrementing waveform_offset_ by the same amount.
  void ComputeFeatures();
  C computer_;  // class that does the MFCC or PLP or filterbank computation
  FeatureWindowFunction window_function_;
  // features_ is the Mfcc or Plp or Fbank features that we have already
  // computed.
  RecyclingVector features_;
  // True if the user has called "InputFinished()"
  bool input_finished_;
  // waveform_offset_ is the number of samples of waveform that we have
  // already discarded, i.e. that were prior to 'waveform_remainder_'.
  int64_t waveform_offset_;
  // waveform_remainder_ is a short piece of waveform that we may need to keep
  // after extracting all the whole frames we can (whatever length of feature
  // will be required for the next phase of computation).
  // It is a 1-D tensor
  std::vector<float> waveform_remainder_;
};
using OnlineFbank = OnlineGenericBaseFeature<FbankComputer>;
}  // namespace knf
#endif  // KALDI_NATIVE_FBANK_CSRC_ONLINE_FEATURE_H_
funasr/runtime/onnxruntime/src/commonfunc.h
@@ -1,5 +1,5 @@
#pragma once 
#include <algorithm>
typedef struct
{
    std::string msg;
@@ -49,3 +49,8 @@
        }
    }
}
template <class ForwardIterator>
inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
    return std::distance(first, std::max_element(first, last));
}
funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -133,6 +133,10 @@
            if (fnCallback)
                fnCallback(nStep, nTotal);
        }
        if(true){
            string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
            pResult->msg = punc_res;
        }
    
        return pResult;
    }
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -17,6 +17,11 @@
        vadHandle->init_vad(vad_path, mvn_path, MODEL_SAMPLE_RATE, VAD_MAX_LEN, VAD_SILENCE_DYRATION, VAD_SPEECH_NOISE_THRES);
    }
    // PUNC model
    if(true){
        puncHandle = make_unique<CTTransformer>(path, nNumThread);
    }
    if(quantize)
    {
        model_path = pathAppend(path, "model_quant.onnx");
@@ -50,6 +55,7 @@
    m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
    vector<string> m_strInputNames, m_strOutputNames;
    string strName;
    getInputName(m_session.get(), strName);
    m_strInputNames.push_back(strName.c_str());
@@ -81,6 +87,10 @@
vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){
    return vadHandle->infer(pcm_data);
}
string ModelImp::AddPunc(const char* szInput){
    return puncHandle->AddPunc(szInput);
}
vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) {
@@ -231,9 +241,9 @@
        auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
        result = greedy_search(floatData, *encoder_out_lens);
    }
    catch (...)
    catch (std::exception const &e)
    {
        result = "";
        printf(e.what());
    }
    return result;
funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -14,6 +14,7 @@
        knf::FbankOptions fbank_opts;
        std::unique_ptr<FsmnVad> vadHandle;
        std::unique_ptr<CTTransformer> puncHandle;
        Vocab* vocab;
        vector<float> means_list;
@@ -32,7 +33,6 @@
        Ort::Env env_;
        Ort::SessionOptions sessionOptions;
        vector<string> m_strInputNames, m_strOutputNames;
        vector<const char*> m_szInputNames;
        vector<const char*> m_szOutputNames;
@@ -45,6 +45,7 @@
        string forward(float* din, int len, int flag);
        string rescoring();
        std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data);
        string AddPunc(const char* szInput);
    };
funasr/runtime/onnxruntime/src/precomp.h
@@ -28,6 +28,8 @@
#include "ComDefine.h"
#include "commonfunc.h"
#include "predefine_coe.h"
#include "tokenizer.h"
#include "punc_infer.h"
#include "FsmnVad.h"
#include "e2e_vad.h"
#include "Vocab.h"
funasr/runtime/onnxruntime/src/punc_infer.cpp
New file
@@ -0,0 +1,183 @@
#include "precomp.h"
CTTransformer::CTTransformer(const char* sz_model_dir, int thread_num)
:env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{}
{
    session_options.SetIntraOpNumThreads(thread_num);
    session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    session_options.DisableCpuMemArena();
    string strModelPath = pathAppend(sz_model_dir, PUNC_MODEL_FILE);
    string strYamlPath = pathAppend(sz_model_dir, PUNC_YAML_FILE);
#ifdef _WIN32
    std::wstring detPath = strToWstr(strModelPath);
    m_session = std::make_unique<Ort::Session>(env_, detPath.c_str(), session_options);
#else
    m_session = std::make_unique<Ort::Session>(env_, strModelPath.c_str(), session_options);
#endif
    // read inputnames outputnames
    vector<string> m_strInputNames, m_strOutputNames;
    string strName;
    getInputName(m_session.get(), strName);
    m_strInputNames.push_back(strName.c_str());
    getInputName(m_session.get(), strName, 1);
    m_strInputNames.push_back(strName);
    getOutputName(m_session.get(), strName);
    m_strOutputNames.push_back(strName);
    for (auto& item : m_strInputNames)
        m_szInputNames.push_back(item.c_str());
    for (auto& item : m_strOutputNames)
        m_szOutputNames.push_back(item.c_str());
    m_tokenizer.OpenYaml(strYamlPath.c_str());
}
CTTransformer::~CTTransformer()
{
}
string CTTransformer::AddPunc(const char* sz_input)
{
    string strResult;
    vector<string> strOut;
    vector<int> InputData;
    m_tokenizer.Tokenize(sz_input, strOut, InputData);
    int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
    int nCurBatch = -1;
    int nSentEnd = -1, nLastCommaIndex = -1;
    vector<int64_t> RemainIDs; //
    vector<string> RemainStr; //
    vector<int> NewPunctuation; //
    vector<string> NewString; //
    vector<string> NewSentenceOut;
    vector<int> NewPuncOut;
    int nDiff = 0;
    for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
    {
        nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
        vector<int64_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
        vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
        InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs;
        InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr;
        auto Punction = Infer(InputIDs);
        nCurBatch = i / TOKEN_LEN;
        if (nCurBatch < nTotalBatch - 1) // not the last minisetence
        {
            nSentEnd = -1;
            nLastCommaIndex = -1;
            for (int nIndex = Punction.size() - 2; nIndex > 0; nIndex--)
            {
                if (m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(PERIOD_INDEX) || m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(QUESTION_INDEX))
                {
                    nSentEnd = nIndex;
                    break;
                }
                if (nLastCommaIndex < 0 && m_tokenizer.ID2Punc(Punction[nIndex]) == m_tokenizer.ID2Punc(COMMA_INDEX))
                {
                    nLastCommaIndex = nIndex;
                }
            }
            if (nSentEnd < 0 && InputStr.size() > CACHE_POP_TRIGGER_LIMIT && nLastCommaIndex > 0)
            {
                nSentEnd = nLastCommaIndex;
                Punction[nSentEnd] = PERIOD_INDEX;
            }
            RemainStr.assign(InputStr.begin() + nSentEnd + 1, InputStr.end());
            RemainIDs.assign(InputIDs.begin() + nSentEnd + 1, InputIDs.end());
            InputStr.assign(InputStr.begin(), InputStr.begin() + nSentEnd + 1);  // minit_sentence
            Punction.assign(Punction.begin(), Punction.begin() + nSentEnd + 1);
        }
        NewPunctuation.insert(NewPunctuation.end(), Punction.begin(), Punction.end());
        vector<string> WordWithPunc;
        for (int i = 0; i < InputStr.size(); i++)
        {
            if (i > 0 && !(InputStr[i][0] & 0x80) && (i + 1) <InputStr.size() && !(InputStr[i+1][0] & 0x80))// �м��Ӣ�ģ�
            {
                InputStr[i] = InputStr[i]+ " ";
            }
            WordWithPunc.push_back(InputStr[i]);
            if (Punction[i] != NOTPUNC_INDEX) // �»���
            {
                WordWithPunc.push_back(m_tokenizer.ID2Punc(Punction[i]));
            }
        }
        NewString.insert(NewString.end(), WordWithPunc.begin(), WordWithPunc.end()); // new_mini_sentence += "".join(words_with_punc)
        NewSentenceOut = NewString;
        NewPuncOut = NewPunctuation;
        // last mini sentence
        if(nCurBatch == nTotalBatch - 1)
        {
            if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(COMMA_INDEX) || NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(DUN_INDEX))
            {
                NewSentenceOut.assign(NewString.begin(), NewString.end() - 1);
                NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX));
                NewPuncOut.assign(NewPunctuation.begin(), NewPunctuation.end() - 1);
                NewPuncOut.push_back(PERIOD_INDEX);
            }
            else if (NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(PERIOD_INDEX) && NewString[NewString.size() - 1] == m_tokenizer.ID2Punc(QUESTION_INDEX))
            {
                NewSentenceOut = NewString;
                NewSentenceOut.push_back(m_tokenizer.ID2Punc(PERIOD_INDEX));
                NewPuncOut = NewPunctuation;
                NewPuncOut.push_back(PERIOD_INDEX);
            }
        }
    }
    for (auto& item : NewSentenceOut)
        strResult += item;
    return strResult;
}
vector<int> CTTransformer::Infer(vector<int64_t> input_data)
{
    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    vector<int> punction;
    std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
    Ort::Value onnx_input = Ort::Value::CreateTensor<int64_t>(m_memoryInfo,
        input_data.data(),
        input_data.size(),
        input_shape_.data(),
        input_shape_.size());
    std::array<int32_t,1> text_lengths{ (int32_t)input_data.size() };
    std::array<int64_t,1> text_lengths_dim{ 1 };
    Ort::Value onnx_text_lengths = Ort::Value::CreateTensor(
        m_memoryInfo,
        text_lengths.data(),
        text_lengths.size() * sizeof(int32_t),
        text_lengths_dim.data(),
        text_lengths_dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
    std::vector<Ort::Value> input_onnx;
    input_onnx.emplace_back(std::move(onnx_input));
    input_onnx.emplace_back(std::move(onnx_text_lengths));
    try {
        auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
        std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
        int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
        float * floatData = outputTensor[0].GetTensorMutableData<float>();
        for (int i = 0; i < outputCount; i += CANDIDATE_NUM)
        {
            int index = argmax(floatData + i, floatData + i + CANDIDATE_NUM-1);
            punction.push_back(index);
        }
    }
    catch (std::exception const &e)
    {
        printf(e.what());
    }
    return punction;
}
funasr/runtime/onnxruntime/src/punc_infer.h
New file
@@ -0,0 +1,25 @@
#pragma once
class CTTransformer {
/**
 * Author: Speech Lab of DAMO Academy, Alibaba Group
 * CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
 * https://arxiv.org/pdf/2003.01309.pdf
*/
private:
    CTokenizer m_tokenizer;
    vector<const char*> m_szInputNames;
    vector<const char*> m_szOutputNames;
    std::shared_ptr<Ort::Session> m_session;
    Ort::Env env_;
    Ort::SessionOptions session_options;
public:
    CTTransformer(const char* sz_model_dir, int thread_num);
    ~CTTransformer();
    vector<int>  Infer(vector<int64_t> input_data);
    string AddPunc(const char* sz_input);
};
funasr/runtime/onnxruntime/src/tokenizer.cpp
New file
@@ -0,0 +1,208 @@
 #include "precomp.h"
CTokenizer::CTokenizer(const char* szYmlFile):m_Ready(false)
{
    OpenYaml(szYmlFile);
}
CTokenizer::CTokenizer():m_Ready(false)
{
}
void CTokenizer::read_yml(const YAML::Node& node)
{
    if (node.IsMap())
    {//��map��
        for (auto it = node.begin(); it != node.end(); ++it)
        {
            read_yml(it->second);
        }
    }
    if (node.IsSequence()) {//��������
        for (size_t i = 0; i < node.size(); ++i) {
            read_yml(node[i]);
        }
    }
    if (node.IsScalar()) {//�DZ�����
        cout << node.as<string>() << endl;
    }
}
bool CTokenizer::OpenYaml(const char* szYmlFile)
{
    YAML::Node m_Config = YAML::LoadFile(szYmlFile);
    if (m_Config.IsNull())
        return false;
    try
    {
        auto Tokens = m_Config["token_list"];
        if (Tokens.IsSequence())
        {
            for (size_t i = 0; i < Tokens.size(); ++i)
            {
                if (Tokens[i].IsScalar())
                {
                    m_ID2Token.push_back(Tokens[i].as<string>());
                    m_Token2ID.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
                }
            }
        }
        auto Puncs = m_Config["punc_list"];
        if (Puncs.IsSequence())
        {
            for (size_t i = 0; i < Puncs.size(); ++i)
            {
                if (Puncs[i].IsScalar())
                {
                    m_ID2Punc.push_back(Puncs[i].as<string>());
                    m_Punc2ID.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
                }
            }
        }
    }
    catch (YAML::BadFile& e) {
        std::cout << "read error!" << std::endl;
        return  false;
    }
    m_Ready = true;
    return m_Ready;
}
vector<string> CTokenizer::ID2String(vector<int> Input)
{
    vector<string> result;
    for (auto& item : Input)
    {
        result.push_back(m_ID2Token[item]);
    }
    return result;
}
int CTokenizer::String2ID(string Input)
{
    int nID = 0; // <blank>
    if (m_Token2ID.find(Input) != m_Token2ID.end())
        nID=(m_Token2ID[Input]);
    else
        nID=(m_Token2ID[UNK_CHAR]);
    return nID;
}
vector<int> CTokenizer::String2IDs(vector<string> Input)
{
    vector<int> result;
    for (auto& item : Input)
    {
        transform(item.begin(), item.end(), item.begin(), ::tolower);
        if (m_Token2ID.find(item) != m_Token2ID.end())
            result.push_back(m_Token2ID[item]);
        else
            result.push_back(m_Token2ID[UNK_CHAR]);
    }
    return result;
}
vector<string> CTokenizer::ID2Punc(vector<int> Input)
{
    vector<string> result;
    for (auto& item : Input)
    {
        result.push_back(m_ID2Punc[item]);
    }
    return result;
}
string CTokenizer::ID2Punc(int nPuncID)
{
    return m_ID2Punc[nPuncID];
}
vector<int> CTokenizer::Punc2IDs(vector<string> Input)
{
    vector<int> result;
    for (auto& item : Input)
    {
        result.push_back(m_Punc2ID[item]);
    }
    return result;
}
vector<string> CTokenizer::SplitChineseString(const string & strInfo)
{
    vector<string> list;
    int strSize = strInfo.size();
    int i = 0;
    while (i < strSize) {
        int len = 1;
        for (int j = 0; j < 6 && (strInfo[i] & (0x80 >> j)); j++) {
            len = j + 1;
        }
        list.push_back(strInfo.substr(i, len));
        i += len;
    }
    return list;
}
void CTokenizer::strSplit(const string& str, const char split, vector<string>& res)
{
    if (str == "")
    {
        return;
    }
    string&& strs = str + split;
    size_t pos = strs.find(split);
    while (pos != string::npos)
    {
        res.emplace_back(strs.substr(0, pos));
        strs = move(strs.substr(pos + 1, strs.size()));
        pos = strs.find(split);
    }
}
 void CTokenizer::Tokenize(const char* strInfo, vector<string> & strOut, vector<int> & IDOut)
{
    vector<string>  strList;
    strSplit(strInfo,' ', strList);
    string current_eng,current_chinese;
    for (auto& item : strList)
    {
        current_eng = "";
        current_chinese = "";
        for (auto& ch : item)
        {
            if (!(ch& 0x80))
            { // Ӣ��
                if (current_chinese.size() > 0)
                {
                    // for utf-8 chinese
                    auto chineseList = SplitChineseString(current_chinese);
                    strOut.insert(strOut.end(), chineseList.begin(),chineseList.end());
                    current_chinese = "";
                }
                current_eng += ch;
            }
            else
            {
                if (current_eng.size() > 0)
                {
                    strOut.push_back(current_eng);
                    current_eng = "";
                }
                current_chinese += ch;
            }
        }
        if (current_chinese.size() > 0)
        {
            auto chineseList = SplitChineseString(current_chinese);
            strOut.insert(strOut.end(), chineseList.begin(), chineseList.end());
            current_chinese = "";
        }
        if (current_eng.size() > 0)
        {
            strOut.push_back(current_eng);
        }
    }
    IDOut= String2IDs(strOut);
}
funasr/runtime/onnxruntime/src/tokenizer.h
New file
@@ -0,0 +1,27 @@
#pragma once
#include "yaml-cpp/yaml.h"
class CTokenizer {
private:
    bool  m_Ready = false;
    vector<string>   m_ID2Token,m_ID2Punc;
    map<string, int>  m_Token2ID,m_Punc2ID;
public:
    CTokenizer(const char* szYmlFile);
    CTokenizer();
    bool OpenYaml(const char* szYmlFile);
    void read_yml(const YAML::Node& node);
    vector<string> ID2String(vector<int> Input);
    vector<int> String2IDs(vector<string> Input);
    int String2ID(string Input);
    vector<string> ID2Punc(vector<int> Input);
    string ID2Punc(int nPuncID);
    vector<int> Punc2IDs(vector<string> Input);
    vector<string> SplitChineseString(const string& strInfo);
    void strSplit(const string& str, const char split, vector<string>& res);
    void Tokenize(const char* strInfo, vector<string>& strOut, vector<int>& IDOut);
};