From bc723ea200144bd6fa8a5dff4b9a780feda144fc Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 29 Jun 2023 18:55:01 +0800
Subject: [PATCH] docs
---
funasr/runtime/onnxruntime/src/tokenizer.cpp | 127 +++++++++++++++++++++++++-----------------
 1 file changed, 76 insertions(+), 51 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/tokenizer.cpp b/funasr/runtime/onnxruntime/src/tokenizer.cpp
index 324def7..cd3f027 100644
--- a/funasr/runtime/onnxruntime/src/tokenizer.cpp
+++ b/funasr/runtime/onnxruntime/src/tokenizer.cpp
@@ -1,38 +1,53 @@
- #include "precomp.h"
+ /**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
-CTokenizer::CTokenizer(const char* szYmlFile):m_Ready(false)
+#include "precomp.h"
+
+namespace funasr {
+CTokenizer::CTokenizer(const char* sz_yamlfile):m_ready(false)
{
- OpenYaml(szYmlFile);
+ OpenYaml(sz_yamlfile);
}
-CTokenizer::CTokenizer():m_Ready(false)
+CTokenizer::CTokenizer():m_ready(false)
{
}
-void CTokenizer::read_yml(const YAML::Node& node)
+CTokenizer::~CTokenizer()
+{
+}
+
+void CTokenizer::ReadYaml(const YAML::Node& node)
{
if (node.IsMap())
{//锟斤拷map锟斤拷
for (auto it = node.begin(); it != node.end(); ++it)
{
- read_yml(it->second);
+ ReadYaml(it->second);
}
}
if (node.IsSequence()) {//锟斤拷锟斤拷锟斤拷锟斤拷
for (size_t i = 0; i < node.size(); ++i) {
- read_yml(node[i]);
+ ReadYaml(node[i]);
}
}
if (node.IsScalar()) {//锟角憋拷锟斤拷锟斤拷
- cout << node.as<string>() << endl;
+ LOG(INFO) << node.as<string>();
}
}
-bool CTokenizer::OpenYaml(const char* szYmlFile)
+bool CTokenizer::OpenYaml(const char* sz_yamlfile)
{
- YAML::Node m_Config = YAML::LoadFile(szYmlFile);
- if (m_Config.IsNull())
- return false;
+ YAML::Node m_Config;
+ try{
+ m_Config = YAML::LoadFile(sz_yamlfile);
+ }catch(exception const &e){
+ LOG(INFO) << "Error loading file, yaml file error or not exist.";
+ exit(-1);
+ }
+
try
{
auto Tokens = m_Config["token_list"];
@@ -42,8 +57,8 @@
{
if (Tokens[i].IsScalar())
{
- m_ID2Token.push_back(Tokens[i].as<string>());
- m_Token2ID.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
+ m_id2token.push_back(Tokens[i].as<string>());
+ m_token2id.insert(make_pair<string, int>(Tokens[i].as<string>(), i));
}
}
}
@@ -54,97 +69,105 @@
{
if (Puncs[i].IsScalar())
{
- m_ID2Punc.push_back(Puncs[i].as<string>());
- m_Punc2ID.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
+ m_id2punc.push_back(Puncs[i].as<string>());
+ m_punc2id.insert(make_pair<string, int>(Puncs[i].as<string>(), i));
}
}
}
}
catch (YAML::BadFile& e) {
- std::cout << "read error!" << std::endl;
+ LOG(ERROR) << "Read error!";
return false;
}
- m_Ready = true;
- return m_Ready;
+ m_ready = true;
+ return m_ready;
}
-vector<string> CTokenizer::ID2String(vector<int> Input)
+vector<string> CTokenizer::Id2String(vector<int> input)
{
vector<string> result;
- for (auto& item : Input)
+ for (auto& item : input)
{
- result.push_back(m_ID2Token[item]);
+ result.push_back(m_id2token[item]);
}
return result;
}
-int CTokenizer::String2ID(string Input)
+int CTokenizer::String2Id(string input)
{
int nID = 0; // <blank>
- if (m_Token2ID.find(Input) != m_Token2ID.end())
- nID=(m_Token2ID[Input]);
+ if (m_token2id.find(input) != m_token2id.end())
+ nID=(m_token2id[input]);
else
- nID=(m_Token2ID[UNK_CHAR]);
+ nID=(m_token2id[UNK_CHAR]);
return nID;
}
-vector<int> CTokenizer::String2IDs(vector<string> Input)
+vector<int> CTokenizer::String2Ids(vector<string> input)
{
vector<int> result;
- for (auto& item : Input)
+ for (auto& item : input)
{
transform(item.begin(), item.end(), item.begin(), ::tolower);
- if (m_Token2ID.find(item) != m_Token2ID.end())
- result.push_back(m_Token2ID[item]);
+ if (m_token2id.find(item) != m_token2id.end())
+ result.push_back(m_token2id[item]);
else
- result.push_back(m_Token2ID[UNK_CHAR]);
+ result.push_back(m_token2id[UNK_CHAR]);
}
return result;
}
-vector<string> CTokenizer::ID2Punc(vector<int> Input)
+vector<string> CTokenizer::Id2Punc(vector<int> input)
{
vector<string> result;
- for (auto& item : Input)
+ for (auto& item : input)
{
- result.push_back(m_ID2Punc[item]);
+ result.push_back(m_id2punc[item]);
}
return result;
}
-string CTokenizer::ID2Punc(int nPuncID)
+string CTokenizer::Id2Punc(int n_punc_id)
{
- return m_ID2Punc[nPuncID];
+ return m_id2punc[n_punc_id];
}
-vector<int> CTokenizer::Punc2IDs(vector<string> Input)
+vector<int> CTokenizer::Punc2Ids(vector<string> input)
{
vector<int> result;
- for (auto& item : Input)
+ for (auto& item : input)
{
- result.push_back(m_Punc2ID[item]);
+ result.push_back(m_punc2id[item]);
}
return result;
}
-vector<string> CTokenizer::SplitChineseString(const string & strInfo)
+bool CTokenizer::IsPunc(string& Punc)
+{
+ if (m_punc2id.find(Punc) != m_punc2id.end())
+ return true;
+ else
+ return false;
+}
+
+vector<string> CTokenizer::SplitChineseString(const string & str_info)
{
vector<string> list;
- int strSize = strInfo.size();
+ int strSize = str_info.size();
int i = 0;
while (i < strSize) {
int len = 1;
- for (int j = 0; j < 6 && (strInfo[i] & (0x80 >> j)); j++) {
+ for (int j = 0; j < 6 && (str_info[i] & (0x80 >> j)); j++) {
len = j + 1;
}
- list.push_back(strInfo.substr(i, len));
+ list.push_back(str_info.substr(i, len));
i += len;
}
return list;
}
-void CTokenizer::strSplit(const string& str, const char split, vector<string>& res)
+void CTokenizer::StrSplit(const string& str, const char split, vector<string>& res)
{
if (str == "")
{
@@ -161,10 +184,10 @@
}
}
- void CTokenizer::Tokenize(const char* strInfo, vector<string> & strOut, vector<int> & IDOut)
+ void CTokenizer::Tokenize(const char* str_info, vector<string> & str_out, vector<int> & id_out)
{
vector<string> strList;
- strSplit(strInfo,' ', strList);
+ StrSplit(str_info,' ', strList);
string current_eng,current_chinese;
for (auto& item : strList)
{
@@ -178,7 +201,7 @@
{
// for utf-8 chinese
auto chineseList = SplitChineseString(current_chinese);
- strOut.insert(strOut.end(), chineseList.begin(),chineseList.end());
+ str_out.insert(str_out.end(), chineseList.begin(),chineseList.end());
current_chinese = "";
}
current_eng += ch;
@@ -187,7 +210,7 @@
{
if (current_eng.size() > 0)
{
- strOut.push_back(current_eng);
+ str_out.push_back(current_eng);
current_eng = "";
}
current_chinese += ch;
@@ -196,13 +219,15 @@
if (current_chinese.size() > 0)
{
auto chineseList = SplitChineseString(current_chinese);
- strOut.insert(strOut.end(), chineseList.begin(), chineseList.end());
+ str_out.insert(str_out.end(), chineseList.begin(), chineseList.end());
current_chinese = "";
}
if (current_eng.size() > 0)
{
- strOut.push_back(current_eng);
+ str_out.push_back(current_eng);
}
}
- IDOut= String2IDs(strOut);
+ id_out= String2Ids(str_out);
}
+
+} // namespace funasr
\ No newline at end of file
--
Gitblit v1.9.1