雾聪
2024-01-23 46e9334c3150c942de0a57de3381c19d9db309ed
runtime/onnxruntime/src/util.cpp
@@ -247,6 +247,395 @@
  }
}
// Timestamp Smooth
void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word){
    if(!TimestampIsPunctuation(str_word)){
        alignment_str1.push_front(str_word);
    }
}
// Returns true when every UTF-8 character of `str` is one of the punctuation
// marks listed in `punctuation` below. An empty string yields true (there is
// no character that fails the check), matching the previous behaviour.
//
// The comparison is done per UTF-8 sequence, not per byte: a byte-wise lookup
// would accept any multi-byte character assembled from bytes that merely
// occur somewhere inside the multi-byte punctuation marks (for example
// "\xE3\x81\x82", which shares its bytes with the ideographic comma and
// full stop but is a different character).
bool TimestampIsPunctuation(const std::string& str) {
    // Note: '.' is not part of the set.
    const std::string punctuation = u8",。?、,?";

    // Length in bytes of the UTF-8 sequence announced by a lead byte.
    // Stray continuation bytes and invalid lead bytes count as one byte so
    // that malformed input is consumed (and rejected) byte by byte.
    auto sequence_length = [](unsigned char lead) -> size_t {
        if (lead < 0x80) return 1;
        if ((lead >> 5) == 0x06) return 2;
        if ((lead >> 4) == 0x0E) return 3;
        if ((lead >> 3) == 0x1E) return 4;
        return 1;
    };

    for (size_t pos = 0; pos < str.size();) {
        size_t len = sequence_length(static_cast<unsigned char>(str[pos]));
        // Clamp a truncated trailing sequence to the bytes that exist.
        if (len > str.size() - pos) {
            len = str.size() - pos;
        }

        bool matched = false;
        for (size_t p = 0; p < punctuation.size();) {
            size_t plen = sequence_length(static_cast<unsigned char>(punctuation[p]));
            if (plen > punctuation.size() - p) {
                plen = punctuation.size() - p;
            }
            if (plen == len && punctuation.compare(p, plen, str, pos, len) == 0) {
                matched = true;
                break;
            }
            p += plen;
        }
        if (!matched) {
            return false;
        }
        pos += len;
    }
    return true;
}
vector<vector<int>> ParseTimestamps(const std::string& str) {
    vector<vector<int>> timestamps;
    std::istringstream ss(str);
    std::string segment;
    // skip first'['
    ss.ignore(1);
    while (std::getline(ss, segment, ']')) {
        std::istringstream segmentStream(segment);
        std::string number;
        vector<int> ts;
        // skip'['
        segmentStream.ignore(1);
        while (std::getline(segmentStream, number, ',')) {
            ts.push_back(std::stoi(number));
        }
        if(ts.size() != 2){
            LOG(ERROR) << "ParseTimestamps Failed";
            timestamps.clear();
            return timestamps;
        }
        timestamps.push_back(ts);
        ss.ignore(1);
    }
    return timestamps;
}
bool TimestampIsDigit(U16CHAR_T &u16) {
    return u16 >= L'0' && u16 <= L'9';
}
bool TimestampIsAlpha(U16CHAR_T &u16) {
    return (u16 >= L'A' && u16 <= L'Z') || (u16 >= L'a' && u16 <= L'z');
}
bool TimestampIsPunctuation(U16CHAR_T &u16) {
    // (& ' -) in the dict
    if (u16 == 0x26 || u16 == 0x27 || u16 == 0x2D){
        return false;
    }
    return (u16 >= 0x21 && u16 <= 0x2F)     // 标准ASCII标点
        || (u16 >= 0x3A && u16 <= 0x40)     // 标准ASCII标点
        || (u16 >= 0x5B && u16 <= 0x60)     // 标准ASCII标点
        || (u16 >= 0x7B && u16 <= 0x7E)     // 标准ASCII标点
        || (u16 >= 0x2000 && u16 <= 0x206F) // 常用的Unicode标点
        || (u16 >= 0x3000 && u16 <= 0x303F); // CJK符号和标点
}
void TimestampSplitChiEngCharacters(const std::string &input_str,
                                  std::vector<std::string> &characters) {
  characters.resize(0);
  std::string eng_word = "";
  U16CHAR_T space = 0x0020;
  std::vector<U16CHAR_T> u16_buf;
  u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1));
  U16CHAR_T* pu16 = u16_buf.data();
  U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data();
  size_t ilen = input_str.size();
  size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1);
  for (size_t i = 0; i < len; i++) {
    if (EncodeConverter::IsChineseCharacter(pu16[i])) {
      if(!eng_word.empty()){
        characters.push_back(eng_word);
        eng_word = "";
      }
      U8CHAR_T u8buf[4];
      size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
      u8buf[n] = '\0';
      characters.push_back((const char*)u8buf);
    } else if (TimestampIsDigit(pu16[i]) || TimestampIsPunctuation(pu16[i])){
      if(!eng_word.empty()){
        characters.push_back(eng_word);
        eng_word = "";
      }
      U8CHAR_T u8buf[4];
      size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
      u8buf[n] = '\0';
      characters.push_back((const char*)u8buf);
    } else if (pu16[i] == space){
      if(!eng_word.empty()){
        characters.push_back(eng_word);
        eng_word = "";
      }
    }else{
      U8CHAR_T u8buf[4];
      size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
      u8buf[n] = '\0';
      eng_word += (const char*)u8buf;
    }
  }
  if(!eng_word.empty()){
    characters.push_back(eng_word);
    eng_word = "";
  }
}
// Serialises a list of integer rows as "[[a,b],[c,d],...]" with no spaces.
//
// For an empty `vec` the result is "" when out_empty is true and "[]"
// otherwise. An empty row is written as "[]".
std::string VectorToString(const std::vector<std::vector<int>>& vec, bool out_empty) {
    if (vec.empty()) {
        return out_empty ? "" : "[]";
    }

    std::ostringstream out;
    out << "[";
    bool first_row = true;
    for (const auto &row : vec) {
        // separator goes in front of every row but the first
        if (!first_row) {
            out << ",";
        }
        first_row = false;

        out << "[";
        bool first_value = true;
        for (const int value : row) {
            if (!first_value) {
                out << ",";
            }
            first_value = false;
            out << value;
        }
        out << "]";
    }
    out << "]";
    return out.str();
}
std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time){
    vector<vector<int>> timestamps_out;
    std::string timestamps_str = "";
    // process string to vector<string>
    std::vector<std::string> characters;
    funasr::TimestampSplitChiEngCharacters(text, characters);
    std::vector<std::string> characters_itn;
    funasr::TimestampSplitChiEngCharacters(text_itn, characters_itn);
    //convert string to vector<vector<int>>
    vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
    if (timestamps.size() == 0){
        LOG(ERROR) << "Timestamp Smooth Failed: Length of timestamp is zero";
        return timestamps_str;
    }
    // edit distance
    int m = characters.size();
    int n = characters_itn.size();
    std::vector<std::vector<int>> dp(m + 1, std::vector<int>(n + 1, 0));
    // init
    for (int i = 0; i <= m; ++i) {
        dp[i][0] = i;
    }
    for (int j = 0; j <= n; ++j) {
        dp[0][j] = j;
    }
    // dp
    for (int i = 1; i <= m; ++i) {
        for (int j = 1; j <= n; ++j) {
            if (characters[i - 1] == characters_itn[j - 1]) {
                dp[i][j] = dp[i - 1][j - 1];
            } else {
                dp[i][j] = std::min({dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]}) + 1;
            }
        }
    }
    // backtrack
    std::deque<string> alignment_str1, alignment_str2;
    int i = m, j = n;
    while (i > 0 || j > 0) {
        if (i > 0 && j > 0 && dp[i][j] == dp[i - 1][j - 1]) {
            funasr::TimestampAdd(alignment_str1, characters[i - 1]);
            funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
            i -= 1;
            j -= 1;
        } else if (i > 0 && dp[i][j] == dp[i - 1][j] + 1) {
            funasr::TimestampAdd(alignment_str1, characters[i - 1]);
            alignment_str2.push_front("");
            i -= 1;
        } else if (j > 0 && dp[i][j] == dp[i][j - 1] + 1) {
            alignment_str1.push_front("");
            funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
            j -= 1;
        } else{
            funasr::TimestampAdd(alignment_str1, characters[i - 1]);
            funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
            i -= 1;
            j -= 1;
        }
    }
    // smooth
    int itn_count = 0;
    int idx_tp = 0;
    int idx_itn = 0;
    vector<vector<int>> timestamps_tmp;
    for(int index = 0; index < alignment_str1.size(); index++){
        if (alignment_str1[index] == alignment_str2[index]){
            bool subsidy = false;
            if (itn_count > 0 && timestamps_tmp.size() == 0){
                if(idx_tp >= timestamps.size()){
                    LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
                    return timestamps_str;
                }
                timestamps_tmp.push_back(timestamps[idx_tp]);
                subsidy = true;
                itn_count++;
            }
            if (timestamps_tmp.size() > 0){
                if (itn_count > 0){
                    int begin = timestamps_tmp[0][0];
                    int end = timestamps_tmp.back()[1];
                    int total_time = end - begin;
                    int interval = total_time / itn_count;
                    for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){
                        vector<int> ts;
                        ts.push_back(begin + interval*idx_cnt);
                        if(idx_cnt == itn_count-1){
                            ts.push_back(end);
                        }else {
                            ts.push_back(begin + interval*(idx_cnt + 1));
                        }
                        timestamps_out.push_back(ts);
                    }
                }
                timestamps_tmp.clear();
            }
            if(!subsidy){
                if(idx_tp >= timestamps.size()){
                    LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
                    return timestamps_str;
                }
                timestamps_out.push_back(timestamps[idx_tp]);
            }
            idx_tp++;
            itn_count = 0;
        }else{
            if (!alignment_str1[index].empty()){
                if(idx_tp >= timestamps.size()){
                    LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
                    return timestamps_str;
                }
                timestamps_tmp.push_back(timestamps[idx_tp]);
                idx_tp++;
            }
            if (!alignment_str2[index].empty()){
                itn_count++;
            }
        }
        // count length of itn
        if (!alignment_str2[index].empty()){
            idx_itn++;
        }
    }
    {
        if (itn_count > 0 && timestamps_tmp.size() == 0){
            if (timestamps_out.size() > 0){
                timestamps_tmp.push_back(timestamps_out.back());
                itn_count++;
                timestamps_out.pop_back();
            } else{
                LOG(ERROR) << "Timestamp Smooth Failed: Last itn has no timestamp.";
                return timestamps_str;
            }
        }
        if (timestamps_tmp.size() > 0){
            if (itn_count > 0){
                int begin = timestamps_tmp[0][0];
                int end = timestamps_tmp.back()[1];
                int total_time = end - begin;
                int interval = total_time / itn_count;
                for(int idx_cnt=0; idx_cnt < itn_count; idx_cnt++){
                    vector<int> ts;
                    ts.push_back(begin + interval*idx_cnt);
                    if(idx_cnt == itn_count-1){
                        ts.push_back(end);
                    }else {
                        ts.push_back(begin + interval*(idx_cnt + 1));
                    }
                    timestamps_out.push_back(ts);
                }
            }
            timestamps_tmp.clear();
        }
    }
    if(timestamps_out.size() != idx_itn){
        LOG(ERROR) << "Timestamp Smooth Failed: Timestamp length does not matched.";
        return timestamps_str;
    }
    timestamps_str = VectorToString(timestamps_out);
    return timestamps_str;
}
std::string TimestampSentence(std::string &text, std::string &str_time){
    std::vector<std::string> characters;
    funasr::TimestampSplitChiEngCharacters(text, characters);
    vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
    int idx_str = 0, idx_ts = 0;
    int start = -1, end = -1;
    std::string text_seg = "";
    std::string ts_sentences = "";
    std::string ts_sent = "";
    vector<vector<int>> ts_seg;
    while(idx_str < characters.size()){
        if (TimestampIsPunctuation(characters[idx_str])){
            if(ts_seg.size() >0){
                if (ts_seg[0].size() == 2){
                    start = ts_seg[0][0];
                }
                if (ts_seg[ts_seg.size()-1].size() == 2){
                    end = ts_seg[ts_seg.size()-1][1];
                }
            }
            // format
            ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
            ts_sent += "\"punc\":\"" + characters[idx_str] + "\",";
            ts_sent += "\"start\":" + to_string(start) + ",";
            ts_sent += "\"end\":" + to_string(end) + ",";
            ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
            if (idx_str == characters.size()-1){
                ts_sentences += ts_sent;
            } else{
                ts_sentences += ts_sent + ",";
            }
            // clear
            text_seg = "";
            ts_sent = "";
            start = 0;
            end = 0;
            ts_seg.clear();
        } else if(idx_ts < timestamps.size()) {
            if (text_seg.empty()){
                text_seg = characters[idx_str];
            }else{
                text_seg += " " + characters[idx_str];
            }
            ts_seg.push_back(timestamps[idx_ts]);
            idx_ts++;
        }
        idx_str++;
    }
    // for none punc results
    if(ts_seg.size() >0){
        if (ts_seg[0].size() == 2){
            start = ts_seg[0][0];
        }
        if (ts_seg[ts_seg.size()-1].size() == 2){
            end = ts_seg[ts_seg.size()-1][1];
        }
        // format
        ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
        ts_sent += "\"punc\":\"\",";
        ts_sent += "\"start\":" + to_string(start) + ",";
        ts_sent += "\"end\":" + to_string(end) + ",";
        ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
        ts_sentences += ts_sent;
    }
    return "[" +ts_sentences + "]";
}
std::vector<std::string> split(const std::string &s, char delim) {
  std::vector<std::string> elems;
  std::stringstream ss(s);
@@ -333,12 +722,23 @@
            int sub_word = !(word.find("@@") == string::npos);
            // process word start and middle part
            if (sub_word) {
                combine += word.erase(word.length() - 2);
                if(!is_combining){
                    begin = timestamp_list[i][0];
                // if badcase: lo@@ chinese
                if (i == raw_char.size()-1 || i<raw_char.size()-1 && IsChinese(raw_char[i+1])){
                    word = word.erase(word.length() - 2) + " ";
                    if (is_combining) {
                        combine += word;
                        is_combining = false;
                        word = combine;
                        combine = "";
                    }
                }else{
                    combine += word.erase(word.length() - 2);
                    if(!is_combining){
                        begin = timestamp_list[i][0];
                    }
                    is_combining = true;
                    continue;
                }
                is_combining = true;
                continue;
            }
            // process word end part
            else if (is_combining) {
@@ -669,4 +1069,9 @@
    ifs_hws.close();
}
// Placeholder: performs no work and leaves all three arguments unmodified.
void SmoothTimestamps(std::string &str_punc, std::string &str_itn, std::string &str_timetamp){
}
} // namespace funasr