游雁
2023-10-19 b9bcf1f093c3053fdc4e2cf4a1d38e27bbf429fb
funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -1,40 +1,95 @@
 #include "precomp.h"
 /**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
CTokenizer::CTokenizer(const char* szYmlFile):m_Ready(false)
#include "precomp.h"
namespace funasr {
CTokenizer::CTokenizer(const char* sz_yamlfile):m_ready(false)
{
   OpenYaml(szYmlFile);
   OpenYaml(sz_yamlfile);
}
CTokenizer::CTokenizer():m_Ready(false)
CTokenizer::CTokenizer():m_ready(false)
{
}
void CTokenizer::read_yml(const YAML::Node& node)
CTokenizer::~CTokenizer()
{
   delete jieba_dict_trie_;
    delete jieba_model_;
}
void CTokenizer::SetJiebaRes(cppjieba::DictTrie *dict, cppjieba::HMMModel *hmm) {
   jieba_processor_.SetJiebaRes(dict, hmm);
}
void CTokenizer::JiebaInit(std::string punc_config){
    if (seg_jieba){
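        // punc_config ends with the punc config file name (PUNC_CONFIG_NAME);
        // dropping that suffix yields the model directory that holds the jieba
        // dict, HMM model and user-dict files resolved below.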
        std::string model_path = punc_config.substr(0, punc_config.length() - (sizeof(PUNC_CONFIG_NAME)-1));
        std::string jieba_dict_file = PathAppend(model_path, JIEBA_DICT);
        std::string jieba_hmm_file = PathAppend(model_path, JIEBA_HMM_MODEL);
        std::string jieba_userdict_file = PathAppend(model_path, JIEBA_USERDICT);
      try{
           jieba_dict_trie_ = new cppjieba::DictTrie(jieba_dict_file, jieba_userdict_file);
         LOG(INFO) << "Successfully load file from " << jieba_dict_file << ", " << jieba_userdict_file;
      }catch(exception const &e){
         LOG(ERROR) << "Error loading file, Jieba dict file error or not exist.";
         exit(-1);
      }
      try{
           jieba_model_ = new cppjieba::HMMModel(jieba_hmm_file);
         LOG(INFO) << "Successfully load model from " << jieba_hmm_file;
      }catch(exception const &e){
         LOG(ERROR) << "Error loading file, Jieba hmm file error or not exist.";
         exit(-1);
      }
        SetJiebaRes(jieba_dict_trie_, jieba_model_);
    }else {
        jieba_dict_trie_ = NULL;
        jieba_model_ = NULL;
    }
}
void CTokenizer::ReadYaml(const YAML::Node& node)
{
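   // Recursively walk the YAML node: descend into maps and sequences and log
   // every scalar leaf; nothing is stored here.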
   if (node.IsMap()) 
   { // map node
      for (auto it = node.begin(); it != node.end(); ++it) 
      {
         read_yml(it->second);
         ReadYaml(it->second);
      }
   }
   if (node.IsSequence()) { // sequence node
      for (size_t i = 0; i < node.size(); ++i) {
         read_yml(node[i]);
         ReadYaml(node[i]);
      }
   }
   if (node.IsScalar()) { // scalar node
      cout << node.as<string>() << endl;
      LOG(INFO) << node.as<string>();
   }
}
bool CTokenizer::OpenYaml(const char* szYmlFile)
bool CTokenizer::OpenYaml(const char* sz_yamlfile)
{
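   // Load the YAML config, read the optional seg_jieba switch, then build the
   // token<->id and punctuation<->id lookup tables from the config's lists.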
   YAML::Node m_Config = YAML::LoadFile(szYmlFile);
   if (m_Config.IsNull())
      return false;
   YAML::Node m_Config;
   try{
      m_Config = YAML::LoadFile(sz_yamlfile);
   }catch(exception const &e){
        LOG(INFO) << "Error loading file, yaml file error or not exist.";
        exit(-1);
    }
   try
   {
      YAML::Node conf_seg_jieba = m_Config["seg_jieba"];
        if (conf_seg_jieba.IsDefined()){
            seg_jieba = conf_seg_jieba.as<bool>();
        }
      auto Tokens = m_Config["token_list"];
      if (Tokens.IsSequence())
      {
@@ -42,8 +97,8 @@
         {
            if (Tokens[i].IsScalar())
            {
               m_ID2Token.push_back(Tokens[i].as<string>());
               m_Token2ID.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
               m_id2token.push_back(Tokens[i].as<string>());
               m_token2id.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
            }
         }
      }
@@ -54,97 +109,113 @@
         {
            if (Puncs[i].IsScalar())
            { 
               m_ID2Punc.push_back(Puncs[i].as<string>());
               m_Punc2ID.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
               m_id2punc.push_back(Puncs[i].as<string>());
               m_punc2id.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
            }
         }
      }
   }
   catch (YAML::BadFile& e) {
      std::cout << "read error!" << std::endl;
      LOG(ERROR) << "Read error!";
      return  false;
   }
   m_Ready = true;
   return m_Ready;
   m_ready = true;
   return m_ready;
}
vector<string> CTokenizer::ID2String(vector<int> Input)
vector<string> CTokenizer::Id2String(vector<int> input)
{
   vector<string> result;
   for (auto& item : Input)
   for (auto& item : input)
   {
      result.push_back(m_ID2Token[item]);
      result.push_back(m_id2token[item]);
   }
   return result;
}
int CTokenizer::String2ID(string Input)
int CTokenizer::String2Id(string input)
{
   int nID = 0; // <blank>
   if (m_Token2ID.find(Input) != m_Token2ID.end())
      nID=(m_Token2ID[Input]);
   if (m_token2id.find(input) != m_token2id.end())
      nID=(m_token2id[input]);
   else
      nID=(m_Token2ID[UNK_CHAR]);
      nID=(m_token2id[UNK_CHAR]);
   return nID;
}
vector<int> CTokenizer::String2IDs(vector<string> Input)
vector<int> CTokenizer::String2Ids(vector<string> input)
{
   vector<int> result;
   for (auto& item : Input)
   for (auto& item : input)
   {   
      transform(item.begin(), item.end(), item.begin(), ::tolower);
      if (m_Token2ID.find(item) != m_Token2ID.end())
         result.push_back(m_Token2ID[item]);
      if (m_token2id.find(item) != m_token2id.end())
         result.push_back(m_token2id[item]);
      else
         result.push_back(m_Token2ID[UNK_CHAR]);
         result.push_back(m_token2id[UNK_CHAR]);
   }
   return result;
}
vector<string> CTokenizer::ID2Punc(vector<int> Input)
vector<string> CTokenizer::Id2Punc(vector<int> input)
{
   vector<string> result;
   for (auto& item : Input)
   for (auto& item : input)
   {
      result.push_back(m_ID2Punc[item]);
      result.push_back(m_id2punc[item]);
   }
   return result;
}
string CTokenizer::ID2Punc(int nPuncID)
string CTokenizer::Id2Punc(int n_punc_id)
{
   return m_ID2Punc[nPuncID];
   return m_id2punc[n_punc_id];
}
vector<int> CTokenizer::Punc2IDs(vector<string> Input)
vector<int> CTokenizer::Punc2Ids(vector<string> input)
{
   vector<int> result;
   for (auto& item : Input)
   for (auto& item : input)
   {
      result.push_back(m_Punc2ID[item]);
      result.push_back(m_punc2id[item]);
   }
   return result;
}
vector<string> CTokenizer::SplitChineseString(const string & strInfo)
bool CTokenizer::IsPunc(string& Punc)
{
   if (m_punc2id.find(Punc) != m_punc2id.end())
      return true;
   else
      return false;
}
vector<string> CTokenizer::SplitChineseString(const string & str_info)
{
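   // The number of leading 1-bits in the first byte of a UTF-8 sequence gives
   // its length in bytes (1 for ASCII, 2 or more for multi-byte), so each
   // multi-byte character is copied out as a single token.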
   vector<string> list;
   int strSize = strInfo.size();
   int strSize = str_info.size();
   int i = 0;
   while (i < strSize) {
      int len = 1;
      for (int j = 0; j < 6 && (strInfo[i] & (0x80 >> j)); j++) {
      for (int j = 0; j < 6 && (str_info[i] & (0x80 >> j)); j++) {
         len = j + 1;
      }
      list.push_back(strInfo.substr(i, len));
      list.push_back(str_info.substr(i, len));
      i += len;
   }
   return list;
}
void CTokenizer::strSplit(const string& str, const char split, vector<string>& res)
vector<string> CTokenizer::SplitChineseJieba(const string & str_info)
{
   vector<string> list;
   jieba_processor_.Cut(str_info, list, false);
   return list;
}
void CTokenizer::StrSplit(const string& str, const char split, vector<string>& res)
{
   if (str == "")
   {
@@ -161,10 +232,10 @@
   }
}
 void CTokenizer::Tokenize(const char* strInfo, vector<string> & strOut, vector<int> & IDOut)
void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
{
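   // Split the input on spaces; within each piece, ASCII characters are
   // collected into whole English tokens while multi-byte UTF-8 runs are split
   // per character (or by jieba when seg_jieba is enabled); the resulting
   // tokens are then mapped to ids.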
   vector<string>  strList;
   strSplit(strInfo,' ', strList);
   StrSplit(str_info,' ', strList);
   string current_eng,current_chinese;
   for (auto& item : strList)
   {
@@ -177,8 +248,13 @@
            if (current_chinese.size() > 0)
            {
               // for utf-8 chinese
               auto chineseList = SplitChineseString(current_chinese);
               strOut.insert(strOut.end(), chineseList.begin(),chineseList.end());
               vector<string> chineseList;
               if(seg_jieba){
                  chineseList = SplitChineseJieba(current_chinese);
               }else{
                  chineseList = SplitChineseString(current_chinese);
               }
               str_out.insert(str_out.end(), chineseList.begin(),chineseList.end());
               current_chinese = "";
            }
            current_eng += ch;
@@ -187,7 +263,7 @@
         {
            if (current_eng.size() > 0)
            {
               strOut.push_back(current_eng);
               str_out.push_back(current_eng);
               current_eng = "";
            }
            current_chinese += ch;
@@ -195,14 +271,22 @@
      }
      if (current_chinese.size() > 0)
      {
         auto chineseList = SplitChineseString(current_chinese);
         strOut.insert(strOut.end(), chineseList.begin(), chineseList.end());
         // for utf-8 chinese
         vector<string> chineseList;
         if(seg_jieba){
            chineseList = SplitChineseJieba(current_chinese);
         }else{
            chineseList = SplitChineseString(current_chinese);
         }
         str_out.insert(str_out.end(), chineseList.begin(), chineseList.end());
         current_chinese = "";
      }
      if (current_eng.size() > 0)
      {
         strOut.push_back(current_eng);
         str_out.push_back(current_eng);
      }
   }
   IDOut= String2IDs(strOut);
   id_out= String2Ids(str_out);
}
} // namespace funasr
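For reference, a minimal usage sketch of the tokenizer interface shown above (constructor, JiebaInit, Tokenize). The config path, the main() harness, and the assumption that "precomp.h" pulls in the funasr::CTokenizer declaration and the standard headers are illustrative only; the real file name behind PUNC_CONFIG_NAME and the jieba resource file names come from the surrounding project.

#include <iostream>
#include <string>
#include <vector>
#include "precomp.h"  // assumed to declare funasr::CTokenizer

int main() {
    // Hypothetical path: the file name must match PUNC_CONFIG_NAME so that
    // JiebaInit() can strip it and find the jieba resources in the same dir.
    const char* punc_config = "/path/to/punc_model/punc.yaml";

    funasr::CTokenizer tokenizer(punc_config);  // OpenYaml(): builds id<->token maps, reads seg_jieba
    tokenizer.JiebaInit(punc_config);           // loads jieba dict/HMM only when seg_jieba is true

    std::vector<std::string> tokens;
    std::vector<int> ids;
    tokenizer.Tokenize("hello 这是一个测试 world", tokens, ids);

    for (size_t i = 0; i < tokens.size(); ++i) {
        std::cout << tokens[i] << " -> " << ids[i] << std::endl;
    }
    return 0;
}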