雾聪
2024-09-25 d62d237a76e423fd1eec31e662162c135d2f93f5
runtime/onnxruntime/src/util.cpp
@@ -255,7 +255,8 @@
}
bool TimestampIsPunctuation(const std::string& str) {
    const std::string punctuation = u8",。?、,.?";
    const std::string punctuation = u8",。?、,?";
    // const std::string punctuation = u8",。?、,.?";
    for (char ch : str) {
        if (punctuation.find(ch) == std::string::npos) {
            return false;
@@ -304,6 +305,10 @@
}
bool TimestampIsPunctuation(U16CHAR_T &u16) {
    // (& ' -) in the dict
    if (u16 == 0x26 || u16 == 0x27 || u16 == 0x2D){
        return false;
    }
    return (u16 >= 0x21 && u16 <= 0x2F)     // 标准ASCII标点
        || (u16 >= 0x3A && u16 <= 0x40)     // 标准ASCII标点
        || (u16 >= 0x5B && u16 <= 0x60)     // 标准ASCII标点
@@ -360,9 +365,13 @@
  }
}
std::string VectorToString(const std::vector<std::vector<int>>& vec) {
std::string VectorToString(const std::vector<std::vector<int>>& vec, bool out_empty) {
    if(vec.size() == 0){
        return "";
        if(out_empty){
            return "";
        }else{
            return "[]";
        }
    }
    std::ostringstream out;
    out << "[";
@@ -557,6 +566,76 @@
    return timestamps_str;
}
std::string TimestampSentence(std::string &text, std::string &str_time){
    std::vector<std::string> characters;
    funasr::TimestampSplitChiEngCharacters(text, characters);
    vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
    int idx_str = 0, idx_ts = 0;
    int start = -1, end = -1;
    std::string text_seg = "";
    std::string ts_sentences = "";
    std::string ts_sent = "";
    vector<vector<int>> ts_seg;
    while(idx_str < characters.size()){
        if (TimestampIsPunctuation(characters[idx_str])){
            if(ts_seg.size() >0){
                if (ts_seg[0].size() == 2){
                    start = ts_seg[0][0];
                }
                if (ts_seg[ts_seg.size()-1].size() == 2){
                    end = ts_seg[ts_seg.size()-1][1];
                }
            }
            // format
            ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
            ts_sent += "\"punc\":\"" + characters[idx_str] + "\",";
            ts_sent += "\"start\":" + to_string(start) + ",";
            ts_sent += "\"end\":" + to_string(end) + ",";
            ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
            if (idx_str == characters.size()-1){
                ts_sentences += ts_sent;
            } else{
                ts_sentences += ts_sent + ",";
            }
            // clear
            text_seg = "";
            ts_sent = "";
            start = 0;
            end = 0;
            ts_seg.clear();
        } else if(idx_ts < timestamps.size()) {
            if (text_seg.empty()){
                text_seg = characters[idx_str];
            }else{
                text_seg += " " + characters[idx_str];
            }
            ts_seg.push_back(timestamps[idx_ts]);
            idx_ts++;
        }
        idx_str++;
    }
    // for none punc results
    if(ts_seg.size() >0){
        if (ts_seg[0].size() == 2){
            start = ts_seg[0][0];
        }
        if (ts_seg[ts_seg.size()-1].size() == 2){
            end = ts_seg[ts_seg.size()-1][1];
        }
        // format
        ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
        ts_sent += "\"punc\":\"\",";
        ts_sent += "\"start\":" + to_string(start) + ",";
        ts_sent += "\"end\":" + to_string(end) + ",";
        ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
        ts_sentences += ts_sent;
    }
    return "[" +ts_sentences + "]";
}
std::vector<std::string> split(const std::string &s, char delim) {
  std::vector<std::string> elems;
  std::stringstream ss(s);
@@ -565,6 +644,21 @@
    elems.push_back(item);
  }
  return elems;
}
std::vector<std::string> SplitStr(const std::string &s, string delimiter) {
    std::vector<std::string> tokens;
    size_t start = 0;
    size_t end = s.find(delimiter);
    while (end != std::string::npos) {
        tokens.push_back(s.substr(start, end - start));
        start = end + delimiter.length();
        end = s.find(delimiter, start);
    }
    tokens.push_back(s.substr(start, end - start));
    return tokens;
}
template<typename T>
@@ -791,6 +885,15 @@
                sum -=(1.0 - 1e-4);
            }            
        }
        // fix case: sum > 1
        int cif_idx = cif_peak.size()-1;
        while(sum>=1.0 - 1e-4 && cif_idx >= 0 ){
            if(cif_peak[cif_idx] < 1.0 - 1e-4){
                cif_peak[cif_idx] = sum;
                sum -=(1.0 - 1e-4);
            }
            cif_idx--;
        }
        fire_place.clear();
        for (int i = 0; i < num_frames; i++) {