kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
runtime/onnxruntime/src/vocab.cpp
@@ -14,7 +14,13 @@
Vocab::Vocab(const char *filename)
{
    ifstream in(filename);
    LoadVocabFromJson(filename);
}
Vocab::Vocab(const char *filename, const char *lex_file)
{
    ifstream in(filename);
    LoadVocabFromYaml(filename);
    LoadLex(lex_file);
}
Vocab::~Vocab()
{
@@ -37,11 +43,56 @@
    }
}
int Vocab::GetIdByToken(const std::string &token) {
    if (token_id.count(token)) {
        return token_id[token];
void Vocab::LoadVocabFromJson(const char* filename){
    nlohmann::json json_array;
    std::ifstream file(filename);
    if (file.is_open()) {
        file >> json_array;
        file.close();
    } else {
        LOG(INFO) << "Error loading token file, token file error or not exist.";
        exit(-1);
    }
    return 0;
    int i = 0;
    for (const auto& element : json_array) {
        vocab.push_back(element);
        token_id[element] = i;
        i++;
    }
}
void Vocab::LoadLex(const char* filename){
    std::ifstream file(filename);
    std::string line;
    while (std::getline(file, line)) {
        std::string key, value;
        std::istringstream iss(line);
        std::getline(iss, key, '\t');
        std::getline(iss, value);
        if (!key.empty() && !value.empty()) {
            lex_map[key] = value;
        }
    }
    file.close();
}
string Vocab::Word2Lex(const std::string &word) const {
    auto it = lex_map.find(word);
    if (it != lex_map.end()) {
        return it->second;
    }
    return "";
}
int Vocab::GetIdByToken(const std::string &token) const {
    auto it = token_id.find(token);
    if (it != token_id.end()) {
        return it->second;
    }
    return -1;
}
void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
@@ -120,8 +171,8 @@
    std::string combine = "";
    std::string unicodeChar = "▁";
    for (auto it = in.begin(); it != in.end(); it++) {
        string word = vocab[*it];
    for (i=0; i<in.size(); i++){
        string word = vocab[in[i]];
        // step1 space character skips
        if (word == "<s>" || word == "</s>" || word == "<unk>")
            continue;
@@ -146,9 +197,20 @@
            int sub_word = !(word.find("@@") == string::npos);
            // process word start and middle part
            if (sub_word) {
                combine += word.erase(word.length() - 2);
                is_combining = true;
                continue;
                // if badcase: lo@@ chinese
                if (i == in.size()-1 || i<in.size()-1 && IsChinese(vocab[in[i+1]])){
                    word = word.erase(word.length() - 2) + " ";
                    if (is_combining) {
                        combine += word;
                        is_combining = false;
                        word = combine;
                        combine = "";
                    }
                }else{
                    combine += word.erase(word.length() - 2);
                    is_combining = true;
                    continue;
                }
            }
            // process word end part
            else if (is_combining) {
@@ -199,7 +261,7 @@
        }
    }
    if (language == "en-bpe" and combine != ""){
    if (language == "en-bpe" && combine != ""){
        combine = WordFormat(combine);
        if (words.size() != 0){
            combine = " " + combine;