雾聪
2023-12-18 f72914003a8c4ab7ae72d52dbd7c5f70ea22313a
add sentence timestamp
10个文件已修改
121 ■■■■■ 已修改文件
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/bin/funasr-onnx-offline.cpp 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/include/funasrruntime.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/commonfunc.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/funasrruntime.cpp 20 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/util.cpp 73 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/src/util.h 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/websocket-server-2pass.cpp 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/websocket-server.cpp 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/websocket/bin/websocket-server.h 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -83,6 +83,10 @@
            if(stamp !=""){
                LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << stamp;
            }
            string stamp_sents = FunASRGetStampSents(result);
            if(stamp_sents !=""){
                LOG(INFO)<< wav_ids[i] <<" : "<<stamp_sents;
            }
            float snippet_time = FunASRGetRetSnippetTime(result);
            n_total_length += snippet_time;
            FunASRFreeResult(result);
runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -172,6 +172,10 @@
            if(stamp !=""){
                LOG(INFO)<< wav_id <<" : "<<stamp;
            }
            string stamp_sents = FunASRGetStampSents(result);
            if(stamp_sents !=""){
                LOG(INFO)<< wav_id <<" : "<<stamp_sents;
            }
            snippet_time += FunASRGetRetSnippetTime(result);
            FunASRFreeResult(result);
        }
runtime/onnxruntime/include/funasrruntime.h
@@ -68,6 +68,7 @@
_FUNASRAPI const char*    FunASRGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const char*    FunASRGetStamp(FUNASR_RESULT result);
// Sentence-level timestamps of a recognition result.
// The string is a bracketed list with one record per punctuation-delimited
// sentence: the sentence text, its start and end, and its per-token timestamps.
// The runtime fills it only when a word-level stamp is present, so it is
// expected to be empty otherwise. Returns nullptr if `result` is null.
// The pointer refers to storage owned by `result`; do not use it after the
// result has been freed.
_FUNASRAPI const char*    FunASRGetStampSents(FUNASR_RESULT result);
_FUNASRAPI const char*    FunASRGetTpassResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const int    FunASRGetRetNumber(FUNASR_RESULT result);
_FUNASRAPI void            FunASRFreeResult(FUNASR_RESULT result);
runtime/onnxruntime/src/commonfunc.h
@@ -9,6 +9,7 @@
{
    std::string msg;
    std::string stamp;
    std::string stamp_sents;
    std::string tpass_msg;
    float snippet_time;
}FUNASR_RECOG_RESULT;
runtime/onnxruntime/src/funasrruntime.cpp
@@ -303,7 +303,9 @@
            p_result->msg = msg_itn;
        }
#endif
        if (!(p_result->stamp).empty()){
            p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
        }
        return p_result;
    }
@@ -399,6 +401,9 @@
            p_result->msg = msg_itn;
        }
#endif
        if (!(p_result->stamp).empty()){
            p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
        }
        return p_result;
    }
@@ -546,7 +551,9 @@
                p_result->tpass_msg = msg_itn;
            }
#endif
            if (!(p_result->stamp).empty()){
                p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp);
            }
            if(frame != NULL){
                delete frame;
                frame = NULL;
@@ -603,6 +610,15 @@
        return p_result->stamp.c_str();
    }
        _FUNASRAPI const char* FunASRGetStampSents(FUNASR_RESULT result)
    {
        funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
        if(!p_result)
            return nullptr;
        return p_result->stamp_sents.c_str();
    }
    _FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
    {
        funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
runtime/onnxruntime/src/util.cpp
@@ -255,7 +255,8 @@
}
bool TimestampIsPunctuation(const std::string& str) {
    const std::string punctuation = u8",。?、,.?";
    const std::string punctuation = u8",。?、,?";
    // const std::string punctuation = u8",。?、,.?";
    for (char ch : str) {
        if (punctuation.find(ch) == std::string::npos) {
            return false;
@@ -557,6 +558,76 @@
    return timestamps_str;
}
std::string TimestampSentence(std::string &text, std::string &str_time){
    std::vector<std::string> characters;
    funasr::TimestampSplitChiEngCharacters(text, characters);
    vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
    int idx_str = 0, idx_ts = 0;
    int start = -1, end = -1;
    std::string text_seg = "";
    std::string ts_sentences = "";
    std::string ts_sent = "";
    vector<vector<int>> ts_seg;
    while(idx_str < characters.size()){
        if (TimestampIsPunctuation(characters[idx_str])){
            if(ts_seg.size() >0){
                if (ts_seg[0].size() == 2){
                    start = ts_seg[0][0];
                }
                if (ts_seg[ts_seg.size()-1].size() == 2){
                    end = ts_seg[ts_seg.size()-1][1];
                }
            }
            // format
            ts_sent += "{'text':'" + text_seg + "',";
            ts_sent += "'start':'" + to_string(start) + "',";
            ts_sent += "'end':'" + to_string(end) + "',";
            ts_sent += "'ts_list':" + VectorToString(ts_seg) + "}";
            if (idx_str == characters.size()-1){
                ts_sentences += ts_sent;
            } else{
                ts_sentences += ts_sent + ",";
            }
            // clear
            idx_str++;
            text_seg = "";
            ts_sent = "";
            start = 0;
            end = 0;
            ts_seg.clear();
        } else if(idx_ts < timestamps.size()) {
            if (text_seg.empty()){
                text_seg = characters[idx_str];
            }else{
                text_seg += " " + characters[idx_str];
            }
            ts_seg.push_back(timestamps[idx_ts]);
            idx_str++;
            idx_ts++;
        }
    }
    // for none punc results
    if(ts_seg.size() >0){
        if (ts_seg[0].size() == 2){
            start = ts_seg[0][0];
        }
        if (ts_seg[ts_seg.size()-1].size() == 2){
            end = ts_seg[ts_seg.size()-1][1];
        }
        // format
        ts_sent += "{'text':'" + text_seg + "',";
        ts_sent += "'start':'" + to_string(start) + "',";
        ts_sent += "'end':'" + to_string(end) + "',";
        ts_sent += "'ts_list':" + VectorToString(ts_seg) + "}";
        ts_sentences += ts_sent;
    }
    return "[" +ts_sentences + "]";
}
std::vector<std::string> split(const std::string &s, char delim) {
  std::vector<std::string> elems;
  std::stringstream ss(s);
runtime/onnxruntime/src/util.h
@@ -47,7 +47,7 @@
                                  std::vector<std::string> &characters);
std::string VectorToString(const std::vector<std::vector<int>>& vec);
std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time);
// Splits `text` at punctuation and groups the word-level timestamps in
// `str_time` into per-sentence records (text, start, end, timestamp list).
// Returns the records as a bracketed, comma-separated list string.
std::string TimestampSentence(std::string &text, std::string &str_time);
std::vector<std::string> split(const std::string &s, char delim);
template<typename T>
runtime/websocket/bin/websocket-server-2pass.cpp
@@ -80,6 +80,12 @@
    jsonresult["timestamp"] = tmp_stamp_msg;
  }
  std::string tmp_stamp_sents = FunASRGetStampSents(result);
  if (tmp_stamp_sents != "") {
    LOG(INFO) << "offline stamp_sents : " << tmp_stamp_sents;
    jsonresult["stamp_sents"] = tmp_stamp_sents;
  }
  return jsonresult;
}
// feed buffer to asr engine for decoder
@@ -318,7 +324,7 @@
        data_msg->msg["is_eof"]=true;
        guard_decoder.unlock();
        to_remove.push_back(hdl);
        LOG(INFO)<<"connection is closed: "<<e.what();
        LOG(INFO)<<"connection is closed.";
        
      }
      iter++;
runtime/websocket/bin/websocket-server.cpp
@@ -74,6 +74,7 @@
    if (!buffer.empty() && hotwords_embedding.size() > 0) {
      std::string asr_result;
      std::string stamp_res;
      std::string stamp_sents;
      try{
        FUNASR_RESULT Result = FunOfflineInferBuffer(
            asr_handle, buffer.data(), buffer.size(), RASR_NONE, NULL, 
@@ -81,6 +82,7 @@
        asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;  // get decode result
        stamp_res = ((FUNASR_RECOG_RESULT*)Result)->stamp;
        stamp_sents = ((FUNASR_RECOG_RESULT*)Result)->stamp_sents;
        FunASRFreeResult(Result);
      }catch (std::exception const& e) {
        LOG(ERROR) << e.what();
@@ -94,6 +96,9 @@
        jsonresult["is_final"] = false;
      if(stamp_res != ""){
        jsonresult["timestamp"] = stamp_res;
      }
      if(stamp_sents != ""){
        jsonresult["stamp_sents"] = stamp_sents;
      }
      jsonresult["wav_name"] = wav_name;
@@ -227,7 +232,7 @@
        data_msg->msg["is_eof"]=true;
        guard_decoder.unlock();
        to_remove.push_back(hdl);
        LOG(INFO)<<"connection is closed: "<<e.what();
        LOG(INFO)<<"connection is closed.";
        
      }
      iter++;
runtime/websocket/bin/websocket-server.h
@@ -50,6 +50,7 @@
// Server-side view of a recognition result.
// websocket-server.cpp casts the FUNASR_RESULT returned by
// FunOfflineInferBuffer to this type and reads its fields directly. The member
// order and types are therefore expected to match funasr::FUNASR_RECOG_RESULT
// in the runtime library; keep the two definitions in sync.
typedef struct {
    std::string msg="";          // recognized text
    std::string stamp="";        // word-level timestamp string
    std::string stamp_sents="";  // sentence-level timestamp string
    std::string tpass_msg="";    // two-pass (offline) result text
    float snippet_time=0;
} FUNASR_RECOG_RESULT;