From 91231a03f5c16fff0d9d54f859c7a9aa02fd239c Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期一, 16 十月 2023 14:47:17 +0800
Subject: [PATCH] add jieba for ct-transformer

---
 funasr/runtime/onnxruntime/src/tokenizer.cpp |   65 +++++++++++++++++++++++++++++++-
 1 file changed, 62 insertions(+), 3 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/tokenizer.cpp b/funasr/runtime/onnxruntime/src/tokenizer.cpp
index cd3f027..a111b91 100644
--- a/funasr/runtime/onnxruntime/src/tokenizer.cpp
+++ b/funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -17,6 +17,41 @@
 
 CTokenizer::~CTokenizer()
 {
+	delete jieba_dict_trie_;
+	delete jieba_model_;
+}
+
+void CTokenizer::SetJiebaRes(cppjieba::DictTrie *dict, cppjieba::HMMModel *hmm) {
+	jieba_processor_.SetJiebaRes(dict, hmm);
+}
+
+void CTokenizer::JiebaInit(std::string punc_config){
+    if (seg_jieba){
+        std::string model_path = punc_config.substr(0, punc_config.length() - (sizeof(PUNC_CONFIG_NAME)-1));
+        std::string jieba_dict_file = PathAppend(model_path, JIEBA_DICT);
+        std::string jieba_hmm_file = PathAppend(model_path, JIEBA_HMM_MODEL);
+        std::string jieba_userdict_file = PathAppend(model_path, JIEBA_USERDICT);
+		try{
+        	jieba_dict_trie_ = new cppjieba::DictTrie(jieba_dict_file, jieba_userdict_file);
+			LOG(INFO) << "Successfully load file from " << jieba_dict_file << ", " << jieba_userdict_file;
+		}catch(std::exception const &e){
+			LOG(ERROR) << "Error loading file, Jieba dict file error or not exist.";
+			exit(-1);
+		}
+
+		try{
+        	jieba_model_ = new cppjieba::HMMModel(jieba_hmm_file);
+			LOG(INFO) << "Successfully load model from " << jieba_hmm_file;
+		}catch(std::exception const &e){
+			LOG(ERROR) << "Error loading file, Jieba hmm file error or not exist.";
+			exit(-1);
+		}
+
+        SetJiebaRes(jieba_dict_trie_, jieba_model_);
+    }else {
+        jieba_dict_trie_ = nullptr;
+        jieba_model_ = nullptr;
+    }
 }
 
 void CTokenizer::ReadYaml(const YAML::Node& node) 
@@ -50,6 +85,11 @@
 
 	try
 	{
+		YAML::Node conf_seg_jieba = m_Config["seg_jieba"];
+        if (conf_seg_jieba.IsDefined()){
+            seg_jieba = conf_seg_jieba.as<bool>();
+        }
+
 		auto Tokens = m_Config["token_list"];
 		if (Tokens.IsSequence())
 		{
@@ -167,6 +207,14 @@
 	return list;
 }
 
+vector<string> CTokenizer::SplitChineseJieba(const string & str_info)
+{
+	vector<string> list;
+	jieba_processor_.Cut(str_info, list, false);
+
+	return list;
+}
+
 void CTokenizer::StrSplit(const string& str, const char split, vector<string>& res)
 {
 	if (str == "")
@@ -184,7 +232,7 @@
 	}
 }
 
- void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
+void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
 {
 	vector<string>  strList;
 	StrSplit(str_info,' ', strList);
@@ -200,7 +248,12 @@
 				if (current_chinese.size() > 0)
 				{
 					// for utf-8 chinese
-					auto chineseList = SplitChineseString(current_chinese);
+					vector<string> chineseList;
+					if(seg_jieba){
+						chineseList = SplitChineseJieba(current_chinese);
+					}else{
+						chineseList = SplitChineseString(current_chinese);
+					}
 					str_out.insert(str_out.end(), chineseList.begin(),chineseList.end());
 					current_chinese = "";
 				}
@@ -218,7 +271,13 @@
 		}
 		if (current_chinese.size() > 0)
 		{
-			auto chineseList = SplitChineseString(current_chinese);
+			// for utf-8 chinese
+			vector<string> chineseList;
+			if(seg_jieba){
+				chineseList = SplitChineseJieba(current_chinese);
+			}else{
+				chineseList = SplitChineseString(current_chinese);
+			}
 			str_out.insert(str_out.end(), chineseList.begin(), chineseList.end());
 			current_chinese = "";
 		}

--
Gitblit v1.9.1