From 91231a03f5c16fff0d9d54f859c7a9aa02fd239c Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Mon, 16 Oct 2023 14:47:17 +0800
Subject: [PATCH] add jieba for ct-transformer
---
funasr/runtime/onnxruntime/src/tokenizer.cpp | 65 +++++++++++++++++++++++++++++++-
1 file changed, 62 insertions(+), 3 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.cpp b/funasr/runtime/onnxruntime/src/tokenizer.cpp
index cd3f027..a111b91 100644
--- a/funasr/runtime/onnxruntime/src/tokenizer.cpp
+++ b/funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -17,6 +17,41 @@
CTokenizer::~CTokenizer()
{
+ // Free the jieba resources allocated in JiebaInit(). When seg_jieba is
+ // false JiebaInit() sets both members to NULL, so these deletes are no-ops.
+ // NOTE(review): if JiebaInit() is never called at all, these pointers may be
+ // uninitialized (UB on delete) — confirm they are null-initialized in the header.
+ delete jieba_dict_trie_;
+ delete jieba_model_;
+}
+
+// Hand the loaded jieba dictionary trie and HMM model to the internal
+// segmenter. Non-owning hand-off from this class's point of view: the same
+// pointers are stored as members and deleted in ~CTokenizer().
+void CTokenizer::SetJiebaRes(cppjieba::DictTrie *dict, cppjieba::HMMModel *hmm) {
+ jieba_processor_.SetJiebaRes(dict, hmm);
+}
+
+// Load the jieba segmentation resources (main dict + user dict + HMM model)
+// from the model directory derived from punc_config, then wire them into the
+// segmenter. Logs and terminates the process if any file fails to load.
+// punc_config: path ending in PUNC_CONFIG_NAME; the directory part is reused
+// as the location of the jieba data files.
+// When seg_jieba is false, both members are set to NULL so the destructor's
+// deletes stay safe.
+void CTokenizer::JiebaInit(std::string punc_config){
+ if (seg_jieba){
+ // Strip the trailing config file name; sizeof(literal)-1 == strlen(literal).
+ std::string model_path = punc_config.substr(0, punc_config.length() - (sizeof(PUNC_CONFIG_NAME)-1));
+ std::string jieba_dict_file = PathAppend(model_path, JIEBA_DICT);
+ std::string jieba_hmm_file = PathAppend(model_path, JIEBA_HMM_MODEL);
+ std::string jieba_userdict_file = PathAppend(model_path, JIEBA_USERDICT);
+ try{
+ jieba_dict_trie_ = new cppjieba::DictTrie(jieba_dict_file, jieba_userdict_file);
+ LOG(INFO) << "Successfully load file from " << jieba_dict_file << ", " << jieba_userdict_file;
+ }catch(exception const &e){
+ // NOTE(review): e.what() is swallowed here — consider appending it to the
+ // log line so the actual failure reason is visible.
+ LOG(ERROR) << "Error loading file, Jieba dict file error or not exist.";
+ exit(-1);
+ }
+
+ try{
+ jieba_model_ = new cppjieba::HMMModel(jieba_hmm_file);
+ LOG(INFO) << "Successfully load model from " << jieba_hmm_file;
+ }catch(exception const &e){
+ // NOTE(review): e.what() is swallowed here as well.
+ LOG(ERROR) << "Error loading file, Jieba hmm file error or not exist.";
+ exit(-1);
+ }
+
+ // Hand the loaded resources to the segmenter; ownership stays with this
+ // class (freed in ~CTokenizer()).
+ SetJiebaRes(jieba_dict_trie_, jieba_model_);
+ }else {
+ // NOTE(review): prefer nullptr over NULL in C++ code.
+ jieba_dict_trie_ = NULL;
+ jieba_model_ = NULL;
+ }
}
void CTokenizer::ReadYaml(const YAML::Node& node)
@@ -50,6 +85,11 @@
try
{
+ YAML::Node conf_seg_jieba = m_Config["seg_jieba"];
+ if (conf_seg_jieba.IsDefined()){
+ seg_jieba = conf_seg_jieba.as<bool>();
+ }
+
auto Tokens = m_Config["token_list"];
if (Tokens.IsSequence())
{
@@ -167,6 +207,14 @@
return list;
}
+// Segment a UTF-8 Chinese string into words with the jieba processor and
+// return them in order. Used by Tokenize() instead of SplitChineseString()
+// when seg_jieba is enabled.
+// NOTE(review): the third argument `false` presumably disables HMM-based
+// new-word discovery in Cut() — confirm against the cppjieba API.
+vector<string> CTokenizer::SplitChineseJieba(const string & str_info)
+{
+ vector<string> list;
+ jieba_processor_.Cut(str_info, list, false);
+
+ return list;
+}
+
void CTokenizer::StrSplit(const string& str, const char split, vector<string>& res)
{
if (str == "")
@@ -184,7 +232,7 @@
}
}
- void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
+void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
{
vector<string> strList;
StrSplit(str_info,' ', strList);
@@ -200,7 +248,12 @@
if (current_chinese.size() > 0)
{
// for utf-8 chinese
- auto chineseList = SplitChineseString(current_chinese);
+ vector<string> chineseList;
+ if(seg_jieba){
+ chineseList = SplitChineseJieba(current_chinese);
+ }else{
+ chineseList = SplitChineseString(current_chinese);
+ }
str_out.insert(str_out.end(), chineseList.begin(),chineseList.end());
current_chinese = "";
}
@@ -218,7 +271,13 @@
}
if (current_chinese.size() > 0)
{
- auto chineseList = SplitChineseString(current_chinese);
+ // for utf-8 chinese
+ vector<string> chineseList;
+ if(seg_jieba){
+ chineseList = SplitChineseJieba(current_chinese);
+ }else{
+ chineseList = SplitChineseString(current_chinese);
+ }
str_out.insert(str_out.end(), chineseList.begin(), chineseList.end());
current_chinese = "";
}
--
Gitblit v1.9.1