From 6b0b94bdcdb40ca42e2b65dc7ff85b88876feada Mon Sep 17 00:00:00 2001
From: Xian Shi <40013335+R1ckShi@users.noreply.github.com>
Date: 星期二, 17 十月 2023 16:44:48 +0800
Subject: [PATCH] Update tokenizer.cpp

---
 funasr/runtime/onnxruntime/src/tokenizer.cpp |  190 ++++++++++++++++++++++++++++++++++-------------
 1 file changed, 137 insertions(+), 53 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/tokenizer.cpp b/funasr/runtime/onnxruntime/src/tokenizer.cpp
index 324def7..a111b91 100644
--- a/funasr/runtime/onnxruntime/src/tokenizer.cpp
+++ b/funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -1,40 +1,95 @@
- #include "precomp.h"
+ /**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License  (https://opensource.org/licenses/MIT)
+*/
 
-CTokenizer::CTokenizer(const char* szYmlFile):m_Ready(false)
+#include "precomp.h"
+
+namespace funasr {
+CTokenizer::CTokenizer(const char* sz_yamlfile):m_ready(false)
 {
-	OpenYaml(szYmlFile);
+	OpenYaml(sz_yamlfile);
 }
 
-CTokenizer::CTokenizer():m_Ready(false)
+CTokenizer::CTokenizer():m_ready(false)
 {
 }
 
-void CTokenizer::read_yml(const YAML::Node& node) 
+CTokenizer::~CTokenizer()
+{
+	delete jieba_dict_trie_;
+    delete jieba_model_;
+}
+
+void CTokenizer::SetJiebaRes(cppjieba::DictTrie *dict, cppjieba::HMMModel *hmm) {
+	jieba_processor_.SetJiebaRes(dict, hmm);
+}
+
+void CTokenizer::JiebaInit(std::string punc_config){
+    if (seg_jieba){
+        std::string model_path = punc_config.substr(0, punc_config.length() - (sizeof(PUNC_CONFIG_NAME)-1));
+        std::string jieba_dict_file = PathAppend(model_path, JIEBA_DICT);
+        std::string jieba_hmm_file = PathAppend(model_path, JIEBA_HMM_MODEL);
+        std::string jieba_userdict_file = PathAppend(model_path, JIEBA_USERDICT);
+		try{
+        	jieba_dict_trie_ = new cppjieba::DictTrie(jieba_dict_file, jieba_userdict_file);
+			LOG(INFO) << "Successfully load file from " << jieba_dict_file << ", " << jieba_userdict_file;
+		}catch(exception const &e){
+			LOG(ERROR) << "Error loading file, Jieba dict file error or not exist.";
+			exit(-1);
+		}
+
+		try{
+        	jieba_model_ = new cppjieba::HMMModel(jieba_hmm_file);
+			LOG(INFO) << "Successfully load model from " << jieba_hmm_file;
+		}catch(exception const &e){
+			LOG(ERROR) << "Error loading file, Jieba hmm file error or not exist.";
+			exit(-1);
+		}
+
+        SetJiebaRes(jieba_dict_trie_, jieba_model_);
+    }else {
+        jieba_dict_trie_ = NULL;
+        jieba_model_ = NULL;
+    }
+}
+
+void CTokenizer::ReadYaml(const YAML::Node& node) 
 {
 	if (node.IsMap()) 
 	{//锟斤拷map锟斤拷
 		for (auto it = node.begin(); it != node.end(); ++it) 
 		{
-			read_yml(it->second);
+			ReadYaml(it->second);
 		}
 	}
 	if (node.IsSequence()) {//锟斤拷锟斤拷锟斤拷锟斤拷
 		for (size_t i = 0; i < node.size(); ++i) {
-			read_yml(node[i]);
+			ReadYaml(node[i]);
 		}
 	}
 	if (node.IsScalar()) {//锟角憋拷锟斤拷锟斤拷
-		cout << node.as<string>() << endl;
+		LOG(INFO) << node.as<string>();
 	}
 }
 
-bool CTokenizer::OpenYaml(const char* szYmlFile)
+bool CTokenizer::OpenYaml(const char* sz_yamlfile)
 {
-	YAML::Node m_Config = YAML::LoadFile(szYmlFile);
-	if (m_Config.IsNull())
-		return false;
+	YAML::Node m_Config;
+	try{
+		m_Config = YAML::LoadFile(sz_yamlfile);
+	}catch(exception const &e){
+        LOG(INFO) << "Error loading file, yaml file error or not exist.";
+        exit(-1);
+    }
+
 	try
 	{
+		YAML::Node conf_seg_jieba = m_Config["seg_jieba"];
+        if (conf_seg_jieba.IsDefined()){
+            seg_jieba = conf_seg_jieba.as<bool>();
+        }
+
 		auto Tokens = m_Config["token_list"];
 		if (Tokens.IsSequence())
 		{
@@ -42,8 +97,8 @@
 			{
 				if (Tokens[i].IsScalar())
 				{
-					m_ID2Token.push_back(Tokens[i].as<string>());
-					m_Token2ID.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
+					m_id2token.push_back(Tokens[i].as<string>());
+					m_token2id.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
 				}
 			}
 		}
@@ -54,97 +109,113 @@
 			{
 				if (Puncs[i].IsScalar())
 				{ 
-					m_ID2Punc.push_back(Puncs[i].as<string>());
-					m_Punc2ID.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
+					m_id2punc.push_back(Puncs[i].as<string>());
+					m_punc2id.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
 				}
 			}
 		}
 	}
 	catch (YAML::BadFile& e) {
-		std::cout << "read error!" << std::endl;
+		LOG(ERROR) << "Read error!";
 		return  false;
 	}
-	m_Ready = true;
-	return m_Ready;
+	m_ready = true;
+	return m_ready;
 }
 
-vector<string> CTokenizer::ID2String(vector<int> Input)
+vector<string> CTokenizer::Id2String(vector<int> input)
 {
 	vector<string> result;
-	for (auto& item : Input)
+	for (auto& item : input)
 	{
-		result.push_back(m_ID2Token[item]);
+		result.push_back(m_id2token[item]);
 	}
 	return result;
 }
 
-int CTokenizer::String2ID(string Input)
+int CTokenizer::String2Id(string input)
 {
 	int nID = 0; // <blank>
-	if (m_Token2ID.find(Input) != m_Token2ID.end())
-		nID=(m_Token2ID[Input]);
+	if (m_token2id.find(input) != m_token2id.end())
+		nID=(m_token2id[input]);
 	else
-		nID=(m_Token2ID[UNK_CHAR]);
+		nID=(m_token2id[UNK_CHAR]);
 	return nID;
 }
 
-vector<int> CTokenizer::String2IDs(vector<string> Input)
+vector<int> CTokenizer::String2Ids(vector<string> input)
 {
 	vector<int> result;
-	for (auto& item : Input)
+	for (auto& item : input)
 	{	
 		transform(item.begin(), item.end(), item.begin(), ::tolower);
-		if (m_Token2ID.find(item) != m_Token2ID.end())
-			result.push_back(m_Token2ID[item]);
+		if (m_token2id.find(item) != m_token2id.end())
+			result.push_back(m_token2id[item]);
 		else
-			result.push_back(m_Token2ID[UNK_CHAR]);
+			result.push_back(m_token2id[UNK_CHAR]);
 	}
 	return result;
 }
 
-vector<string> CTokenizer::ID2Punc(vector<int> Input)
+vector<string> CTokenizer::Id2Punc(vector<int> input)
 {
 	vector<string> result;
-	for (auto& item : Input)
+	for (auto& item : input)
 	{
-		result.push_back(m_ID2Punc[item]);
+		result.push_back(m_id2punc[item]);
 	}
 	return result;
 }
 
-string CTokenizer::ID2Punc(int nPuncID)
+string CTokenizer::Id2Punc(int n_punc_id)
 {
-	return m_ID2Punc[nPuncID];
+	return m_id2punc[n_punc_id];
 }
 
-vector<int> CTokenizer::Punc2IDs(vector<string> Input)
+vector<int> CTokenizer::Punc2Ids(vector<string> input)
 {
 	vector<int> result;
-	for (auto& item : Input)
+	for (auto& item : input)
 	{
-		result.push_back(m_Punc2ID[item]);
+		result.push_back(m_punc2id[item]);
 	}
 	return result;
 }
 
-vector<string> CTokenizer::SplitChineseString(const string & strInfo)
+bool CTokenizer::IsPunc(string& Punc)
+{
+	if (m_punc2id.find(Punc) != m_punc2id.end())
+		return true;
+	else
+		return false;
+}
+
+vector<string> CTokenizer::SplitChineseString(const string & str_info)
 {
 	vector<string> list;
-	int strSize = strInfo.size();
+	int strSize = str_info.size();
 	int i = 0;
 
 	while (i < strSize) {
 		int len = 1;
-		for (int j = 0; j < 6 && (strInfo[i] & (0x80 >> j)); j++) {
+		for (int j = 0; j < 6 && (str_info[i] & (0x80 >> j)); j++) {
 			len = j + 1;
 		}
-		list.push_back(strInfo.substr(i, len));
+		list.push_back(str_info.substr(i, len));
 		i += len;
 	}
 	return list;
 }
 
-void CTokenizer::strSplit(const string& str, const char split, vector<string>& res)
+vector<string> CTokenizer::SplitChineseJieba(const string & str_info)
+{
+	vector<string> list;
+	jieba_processor_.Cut(str_info, list, false);
+
+	return list;
+}
+
+void CTokenizer::StrSplit(const string& str, const char split, vector<string>& res)
 {
 	if (str == "")
 	{
@@ -161,10 +232,10 @@
 	}
 }
 
- void CTokenizer::Tokenize(const char* strInfo, vector<string> & strOut, vector<int> & IDOut)
+void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
 {
 	vector<string>  strList;
-	strSplit(strInfo,' ', strList);
+	StrSplit(str_info,' ', strList);
 	string current_eng,current_chinese;
 	for (auto& item : strList)
 	{
@@ -177,8 +248,13 @@
 				if (current_chinese.size() > 0)
 				{
 					// for utf-8 chinese
-					auto chineseList = SplitChineseString(current_chinese);
-					strOut.insert(strOut.end(), chineseList.begin(),chineseList.end());
+					vector<string> chineseList;
+					if(seg_jieba){
+						chineseList = SplitChineseJieba(current_chinese);
+					}else{
+						chineseList = SplitChineseString(current_chinese);
+					}
+					str_out.insert(str_out.end(), chineseList.begin(),chineseList.end());
 					current_chinese = "";
 				}
 				current_eng += ch;
@@ -187,7 +263,7 @@
 			{
 				if (current_eng.size() > 0)
 				{
-					strOut.push_back(current_eng);
+					str_out.push_back(current_eng);
 					current_eng = "";
 				}
 				current_chinese += ch;
@@ -195,14 +271,22 @@
 		}
 		if (current_chinese.size() > 0)
 		{
-			auto chineseList = SplitChineseString(current_chinese);
-			strOut.insert(strOut.end(), chineseList.begin(), chineseList.end());
+			// for utf-8 chinese
+			vector<string> chineseList;
+			if(seg_jieba){
+				chineseList = SplitChineseJieba(current_chinese);
+			}else{
+				chineseList = SplitChineseString(current_chinese);
+			}
+			str_out.insert(str_out.end(), chineseList.begin(), chineseList.end());
 			current_chinese = "";
 		}
 		if (current_eng.size() > 0)
 		{
-			strOut.push_back(current_eng);
+			str_out.push_back(current_eng);
 		}
 	}
-	IDOut= String2IDs(strOut);
+	id_out= String2Ids(str_out);
 }
+
+} // namespace funasr
\ No newline at end of file

--
Gitblit v1.9.1