雾聪
2023-10-16 91231a03f5c16fff0d9d54f859c7a9aa02fd239c
funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -17,6 +17,41 @@
CTokenizer::~CTokenizer()
{
   // Release the Jieba resources allocated in JiebaInit().
   // Both pointers are null when seg_jieba was disabled, and
   // `delete nullptr` is a no-op, so no guard is needed.
   delete jieba_model_;
   delete jieba_dict_trie_;
}
// Hand the Jieba dictionary trie and HMM model to the internal segmenter.
// The processor only borrows the pointers; ownership appears to remain with
// this tokenizer (JiebaInit allocates them, the destructor deletes them) —
// NOTE(review): confirm jieba_processor_ does not also free them.
void CTokenizer::SetJiebaRes(cppjieba::DictTrie *dict, cppjieba::HMMModel *hmm) {
   jieba_processor_.SetJiebaRes(dict, hmm);
}
void CTokenizer::JiebaInit(std::string punc_config){
    if (seg_jieba){
        std::string model_path = punc_config.substr(0, punc_config.length() - (sizeof(PUNC_CONFIG_NAME)-1));
        std::string jieba_dict_file = PathAppend(model_path, JIEBA_DICT);
        std::string jieba_hmm_file = PathAppend(model_path, JIEBA_HMM_MODEL);
        std::string jieba_userdict_file = PathAppend(model_path, JIEBA_USERDICT);
      try{
           jieba_dict_trie_ = new cppjieba::DictTrie(jieba_dict_file, jieba_userdict_file);
         LOG(INFO) << "Successfully load file from " << jieba_dict_file << ", " << jieba_userdict_file;
      }catch(exception const &e){
         LOG(ERROR) << "Error loading file, Jieba dict file error or not exist.";
         exit(-1);
      }
      try{
           jieba_model_ = new cppjieba::HMMModel(jieba_hmm_file);
         LOG(INFO) << "Successfully load model from " << jieba_hmm_file;
      }catch(exception const &e){
         LOG(ERROR) << "Error loading file, Jieba hmm file error or not exist.";
         exit(-1);
      }
        SetJiebaRes(jieba_dict_trie_, jieba_model_);
    }else {
        jieba_dict_trie_ = NULL;
        jieba_model_ = NULL;
    }
}
void CTokenizer::ReadYaml(const YAML::Node& node) 
@@ -50,6 +85,11 @@
   try
   {
      YAML::Node conf_seg_jieba = m_Config["seg_jieba"];
        if (conf_seg_jieba.IsDefined()){
            seg_jieba = conf_seg_jieba.as<bool>();
        }
      auto Tokens = m_Config["token_list"];
      if (Tokens.IsSequence())
      {
@@ -167,6 +207,14 @@
   return list;
}
// Segment a UTF-8 Chinese string into words using the Jieba processor.
// The third argument to Cut() is false — NOTE(review): presumably this
// disables HMM-based new-word discovery; confirm against cppjieba's API.
vector<string> CTokenizer::SplitChineseJieba(const string & str_info)
{
   vector<string> words;
   jieba_processor_.Cut(str_info, words, false);
   return words;
}
void CTokenizer::StrSplit(const string& str, const char split, vector<string>& res)
{
   if (str == "")
@@ -184,7 +232,7 @@
   }
}
 void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
{
   vector<string>  strList;
   StrSplit(str_info,' ', strList);
@@ -200,7 +248,12 @@
            if (current_chinese.size() > 0)
            {
               // for utf-8 chinese
               auto chineseList = SplitChineseString(current_chinese);
               vector<string> chineseList;
               if(seg_jieba){
                  chineseList = SplitChineseJieba(current_chinese);
               }else{
                  chineseList = SplitChineseString(current_chinese);
               }
               str_out.insert(str_out.end(), chineseList.begin(),chineseList.end());
               current_chinese = "";
            }
@@ -218,7 +271,13 @@
      }
      if (current_chinese.size() > 0)
      {
         auto chineseList = SplitChineseString(current_chinese);
         // for utf-8 chinese
         vector<string> chineseList;
         if(seg_jieba){
            chineseList = SplitChineseJieba(current_chinese);
         }else{
            chineseList = SplitChineseString(current_chinese);
         }
         str_out.insert(str_out.end(), chineseList.begin(), chineseList.end());
         current_chinese = "";
      }