zhifu gao
2023-10-16 1d7bbbffb6a024a33859b48a7a656d0455dc0be1
funasr/runtime/onnxruntime/src/vocab.cpp
@@ -29,9 +29,19 @@
        exit(-1);
    }
    YAML::Node myList = config["token_list"];
    int i = 0;
    for (YAML::const_iterator it = myList.begin(); it != myList.end(); ++it) {
        vocab.push_back(it->as<string>());
        token_id[it->as<string>()] = i;
        i ++;
    }
}
int Vocab::GetIdByToken(const std::string &token) {
    if (token_id.count(token)) {
        return token_id[token];
    }
    return 0;
}
void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
@@ -65,20 +75,52 @@
    return false;
}
string Vocab::Vector2StringV2(vector<int> in)
string Vocab::WordFormat(std::string word)
{
    if(word == "i"){
        return "I";
    }else if(word == "i'm"){
        return "I'm";
    }else if(word == "i've"){
        return "I've";
    }else if(word == "i'll"){
        return "I'll";
    }else{
        return word;
    }
}
string Vocab::Vector2StringV2(vector<int> in, std::string language)
{
    int i;
    list<string> words;
    int is_pre_english = false;
    int pre_english_len = 0;
    int is_combining = false;
    string combine = "";
    std::string combine = "";
    std::string unicodeChar = "▁";
    for (auto it = in.begin(); it != in.end(); it++) {
        string word = vocab[*it];
        // step1 space character skips
        if (word == "<s>" || word == "</s>" || word == "<unk>")
            continue;
        if (language == "en-bpe"){
            size_t found = word.find(unicodeChar);
            if(found != std::string::npos){
                if (combine != ""){
                    combine = WordFormat(combine);
                    if (words.size() != 0){
                        combine = " " + combine;
                    }
                    words.push_back(combine);
                }
                combine = word.substr(3);
            }else{
                combine += word;
            }
            continue;
        }
        // step2 combie phoneme to full word
        {
            int sub_word = !(word.find("@@") == string::npos);
@@ -137,6 +179,14 @@
        }
    }
    if (language == "en-bpe" and combine != ""){
        combine = WordFormat(combine);
        if (words.size() != 0){
            combine = " " + combine;
        }
        words.push_back(combine);
    }
    stringstream ss;
    for (auto it = words.begin(); it != words.end(); it++) {
        ss << *it;
@@ -150,4 +200,4 @@
    return vocab.size();
}
} // namespace funasr
} // namespace funasr