From d674c29323c930842727d0689100f827798d6ba2 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期一, 11 十二月 2023 15:51:38 +0800
Subject: [PATCH] add timestamp smooth
---
runtime/onnxruntime/bin/funasr-onnx-offline.cpp | 2
runtime/onnxruntime/src/util.cpp | 336 +++++++++++++++++++++++++++++++++++++++++++++++
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 4
runtime/onnxruntime/src/vocab.cpp | 21 ++
runtime/onnxruntime/src/funasrruntime.cpp | 19 ++
runtime/onnxruntime/src/util.h | 12 +
runtime/onnxruntime/src/paraformer.cpp | 9 +
7 files changed, 388 insertions(+), 15 deletions(-)
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index 41cd038..b248bca 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -55,7 +55,7 @@
for (size_t i = 0; i < 1; i++)
{
FunOfflineReset(asr_handle, decoder_handle);
- FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
+ FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, true, decoder_handle);
if(result){
FunASRFreeResult(result);
}
@@ -69,7 +69,7 @@
}
gettimeofday(&start, NULL);
- FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
+ FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, true, decoder_handle);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
index 67a267d..87b57a8 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -157,7 +157,7 @@
auto& wav_file = wav_list[i];
auto& wav_id = wav_ids[i];
gettimeofday(&start, NULL);
- FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, false, decoder_handle);
+ FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, hotwords_embedding, 16000, true, decoder_handle);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 3523bba..5c2653f 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -294,6 +294,12 @@
#if !defined(__APPLE__)
if(offline_stream->UseITN() && itn){
string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
+ if(!(p_result->stamp).empty()){
+ std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp);
+ if(!new_stamp.empty()){
+ p_result->stamp = new_stamp;
+ }
+ }
p_result->msg = msg_itn;
}
#endif
@@ -384,6 +390,12 @@
#if !defined(__APPLE__)
if(offline_stream->UseITN() && itn){
string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
+ if(!(p_result->stamp).empty()){
+ std::string new_stamp = funasr::TimestampSmooth(p_result->msg, msg_itn, p_result->stamp);
+ if(!new_stamp.empty()){
+ p_result->stamp = new_stamp;
+ }
+ }
p_result->msg = msg_itn;
}
#endif
@@ -524,6 +536,13 @@
#if !defined(__APPLE__)
if(tpass_stream->UseITN() && itn){
string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
+ // TimestampSmooth
+ if(!(p_result->stamp).empty()){
+ std::string new_stamp = funasr::TimestampSmooth(p_result->tpass_msg, msg_itn, p_result->stamp);
+ if(!new_stamp.empty()){
+ p_result->stamp = new_stamp;
+ }
+ }
p_result->tpass_msg = msg_itn;
}
#endif
diff --git a/runtime/onnxruntime/src/paraformer.cpp b/runtime/onnxruntime/src/paraformer.cpp
index 4e89ea2..b3dc619 100644
--- a/runtime/onnxruntime/src/paraformer.cpp
+++ b/runtime/onnxruntime/src/paraformer.cpp
@@ -300,10 +300,15 @@
Paraformer::~Paraformer()
{
- if(vocab)
+ if(vocab){
delete vocab;
- if(seg_dict)
+ }
+ if(seg_dict){
delete seg_dict;
+ }
+ if(phone_set_){
+ delete phone_set_;
+ }
}
void Paraformer::StartUtterance()
diff --git a/runtime/onnxruntime/src/util.cpp b/runtime/onnxruntime/src/util.cpp
index 005de57..2738d35 100644
--- a/runtime/onnxruntime/src/util.cpp
+++ b/runtime/onnxruntime/src/util.cpp
@@ -247,6 +247,316 @@
}
}
+// Timestamp Smooth
+// Prepends `str_word` to the front of `alignment_str1`, unless
+// TimestampIsPunctuation() reports the token as punctuation-only, in which
+// case the deque is left untouched.
+void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word){
+    if(TimestampIsPunctuation(str_word)){
+        return;
+    }
+    alignment_str1.push_front(str_word);
+}
+
+// Returns true when `str` consists exclusively of the sentence punctuation
+// marks recognised here: fullwidth comma (U+FF0C), ideographic full stop
+// (U+3002), fullwidth question mark (U+FF1F), ideographic comma (U+3001) and
+// the ASCII characters ',', '.', '?'.
+// An empty string yields true, since it contains no non-punctuation content.
+//
+// Matching is done per UTF-8 encoded character, not per byte: a byte-wise
+// test also accepts unrelated characters that merely reuse bytes of the marks
+// above (for example U+3042 encodes as E3 81 82, all of which occur in them).
+bool TimestampIsPunctuation(const std::string& str) {
+    // Spelled as byte escapes so the table does not depend on the encoding
+    // the source file happens to be saved or displayed in.
+    static const std::string kPunctuation[] = {
+        "\xEF\xBC\x8C",  // U+FF0C fullwidth comma
+        "\xE3\x80\x82",  // U+3002 ideographic full stop
+        "\xEF\xBC\x9F",  // U+FF1F fullwidth question mark
+        "\xE3\x80\x81",  // U+3001 ideographic comma
+        ",",
+        ".",
+        "?",
+    };
+
+    size_t pos = 0;
+    while (pos < str.size()) {
+        bool matched = false;
+        for (const std::string& mark : kPunctuation) {
+            // compare() clamps the length, so a truncated mark at the end of
+            // `str` simply fails to match instead of reading out of range.
+            if (str.compare(pos, mark.size(), mark) == 0) {
+                pos += mark.size();
+                matched = true;
+                break;
+            }
+        }
+        if (!matched) {
+            return false;
+        }
+    }
+    return true;
+}
+
+// Parses a serialized timestamp list of the form "[[b0,e0],[b1,e1],...]"
+// into a vector of two-element {begin, end} vectors.
+//
+// Returns an empty vector (after logging an error) when any entry does not
+// contain exactly two fields or when a field is not a valid integer. An
+// empty input also yields an empty vector.
+vector<vector<int>> ParseTimestamps(const std::string& str) {
+    vector<vector<int>> timestamps;
+    std::istringstream ss(str);
+    std::string segment;
+
+    // skip the outer '['
+    ss.ignore(1);
+
+    while (std::getline(ss, segment, ']')) {
+        std::istringstream segmentStream(segment);
+        std::string number;
+        vector<int> ts;
+
+        // skip this entry's own '['
+        segmentStream.ignore(1);
+
+        while (std::getline(segmentStream, number, ',')) {
+            // Stream extraction is used instead of std::stoi so that a
+            // non-numeric, empty or out-of-range field is reported through
+            // the normal failure path rather than an uncaught exception.
+            std::istringstream field(number);
+            int value = 0;
+            if (!(field >> value)) {
+                LOG(ERROR) << "ParseTimestamps Failed";
+                timestamps.clear();
+                return timestamps;
+            }
+            ts.push_back(value);
+        }
+        if(ts.size() != 2){
+            LOG(ERROR) << "ParseTimestamps Failed";
+            timestamps.clear();
+            return timestamps;
+        }
+        timestamps.push_back(ts);
+        // skip the separator (',' between entries, or the closing ']')
+        ss.ignore(1);
+    }
+
+    return timestamps;
+}
+
+// True when the code unit lies in the range '0'..'9'.
+bool TimestampIsDigit(U16CHAR_T &u16) {
+    if (u16 < L'0') {
+        return false;
+    }
+    return u16 <= L'9';
+}
+
+// True when the code unit is an ASCII letter, 'A'..'Z' or 'a'..'z'.
+bool TimestampIsAlpha(U16CHAR_T &u16) {
+    const bool is_upper = (u16 >= L'A' && u16 <= L'Z');
+    const bool is_lower = (u16 >= L'a' && u16 <= L'z');
+    return is_upper || is_lower;
+}
+
+// True when the code unit falls inside one of the punctuation ranges below.
+bool TimestampIsPunctuation(U16CHAR_T &u16) {
+    // Inclusive [low, high] ranges.
+    static const int kRanges[][2] = {
+        {0x21, 0x2F},      // standard ASCII punctuation
+        {0x3A, 0x40},      // standard ASCII punctuation
+        {0x5B, 0x60},      // standard ASCII punctuation
+        {0x7B, 0x7E},      // standard ASCII punctuation
+        {0x2000, 0x206F},  // common Unicode punctuation
+        {0x3000, 0x303F},  // CJK symbols and punctuation
+    };
+    for (const auto &range : kRanges) {
+        if (u16 >= range[0] && u16 <= range[1]) {
+            return true;
+        }
+    }
+    return false;
+}
+
+// Splits the UTF-8 string `input_str` into tokens, replacing whatever
+// `characters` held before:
+//   - a code unit accepted by EncodeConverter::IsChineseCharacter,
+//     TimestampIsDigit or TimestampIsPunctuation becomes a token of its own;
+//   - a space (0x0020) closes the token being accumulated and is dropped;
+//   - every other code unit is appended to the token being accumulated.
+void TimestampSplitChiEngCharacters(const std::string &input_str,
+                                std::vector<std::string> &characters) {
+    characters.clear();
+
+    // Convert to UTF-16 first; the buffer holds one unit per input byte
+    // plus one extra slot.
+    const size_t byte_len = input_str.size();
+    std::vector<U16CHAR_T> utf16(byte_len + 1);
+    U16CHAR_T* units = utf16.data();
+    U8CHAR_T * bytes = (U8CHAR_T*)input_str.data();
+    const size_t unit_count = EncodeConverter::Utf8ToUtf16(bytes, byte_len, units, byte_len + 1);
+
+    const U16CHAR_T kSpace = 0x0020;
+    std::string pending = "";
+
+    // Emits the accumulated run, if any, as one token.
+    auto flush_pending = [&]() {
+        if(!pending.empty()){
+            characters.push_back(pending);
+            pending = "";
+        }
+    };
+    // Re-encodes the single code unit at `unit` as a UTF-8 string.
+    auto to_utf8 = [](U16CHAR_T* unit) -> std::string {
+        U8CHAR_T u8buf[4];
+        size_t n = EncodeConverter::Utf16ToUtf8(unit, u8buf);
+        u8buf[n] = '\0';
+        return std::string((const char*)u8buf);
+    };
+
+    for (size_t idx = 0; idx < unit_count; ++idx) {
+        U16CHAR_T &unit = units[idx];
+        const bool standalone = EncodeConverter::IsChineseCharacter(unit)
+            || TimestampIsDigit(unit)
+            || TimestampIsPunctuation(unit);
+        if (standalone) {
+            flush_pending();
+            characters.push_back(to_utf8(&unit));
+        } else if (unit == kSpace) {
+            flush_pending();
+        } else {
+            pending += to_utf8(&unit);
+        }
+    }
+    flush_pending();
+}
+
+// Serializes a list of integer rows as "[[a,b],[c,d],...]" with no
+// whitespace. An empty outer vector produces an empty string (not "[]").
+std::string VectorToString(const std::vector<std::vector<int>>& vec) {
+    if (vec.empty()) {
+        return "";
+    }
+
+    std::ostringstream out;
+    out << "[";
+    bool first_row = true;
+    for (const auto &row : vec) {
+        if (!first_row) {
+            out << ",";
+        }
+        first_row = false;
+
+        out << "[";
+        bool first_value = true;
+        for (int value : row) {
+            if (!first_value) {
+                out << ",";
+            }
+            first_value = false;
+            out << value;
+        }
+        out << "]";
+    }
+    out << "]";
+    return out.str();
+}
+
+// Rebuilds the timestamp list so that it lines up with the tokens of the
+// ITN-normalized text.
+//
+//   text      - text before normalization (the one `str_time` refers to)
+//   text_itn  - text after normalization
+//   str_time  - serialized timestamps, "[[begin,end],...]"
+//
+// Both texts are tokenized with TimestampSplitChiEngCharacters and aligned by
+// edit distance. Tokens that are identical in both texts keep their original
+// timestamp. Where a run of source tokens was replaced by a different run of
+// ITN tokens, the time span from the first replaced token's begin to the last
+// one's end is divided evenly among the ITN tokens.
+//
+// Returns the new serialized timestamps, or an empty string when smoothing
+// fails (every failure is logged).
+std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time){
+    vector<vector<int>> timestamps_out;
+    std::string timestamps_str = "";
+    // process string to vector<string>
+    std::vector<std::string> characters;
+    funasr::TimestampSplitChiEngCharacters(text, characters);
+
+    std::vector<std::string> characters_itn;
+    funasr::TimestampSplitChiEngCharacters(text_itn, characters_itn);
+
+    //convert string to vector<vector<int>>
+    vector<vector<int>> timestamps = funasr::ParseTimestamps(str_time);
+
+    if (timestamps.size() == 0){
+        LOG(ERROR) << "Timestamp Smooth Failed: Length of timestamp is zero";
+        return timestamps_str;
+    }
+
+    // edit distance
+    int m = characters.size();
+    int n = characters_itn.size();
+    std::vector<std::vector<int>> dp(m + 1, std::vector<int>(n + 1, 0));
+
+    // init
+    for (int i = 0; i <= m; ++i) {
+        dp[i][0] = i;
+    }
+    for (int j = 0; j <= n; ++j) {
+        dp[0][j] = j;
+    }
+
+    // dp
+    for (int i = 1; i <= m; ++i) {
+        for (int j = 1; j <= n; ++j) {
+            if (characters[i - 1] == characters_itn[j - 1]) {
+                dp[i][j] = dp[i - 1][j - 1];
+            } else {
+                dp[i][j] = std::min({dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]}) + 1;
+            }
+        }
+    }
+
+    // backtrack: walk from dp[m][n] to dp[0][0], building both aligned
+    // sequences front-first; "" marks a gap on the opposite side
+    std::deque<string> alignment_str1, alignment_str2;
+    int i = m, j = n;
+    while (i > 0 || j > 0) {
+        if (i > 0 && j > 0 && dp[i][j] == dp[i - 1][j - 1]) {
+            funasr::TimestampAdd(alignment_str1, characters[i - 1]);
+            funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
+            i -= 1;
+            j -= 1;
+        } else if (i > 0 && dp[i][j] == dp[i - 1][j] + 1) {
+            funasr::TimestampAdd(alignment_str1, characters[i - 1]);
+            alignment_str2.push_front("");
+            i -= 1;
+        } else if (j > 0 && dp[i][j] == dp[i][j - 1] + 1) {
+            alignment_str1.push_front("");
+            funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
+            j -= 1;
+        } else{
+            funasr::TimestampAdd(alignment_str1, characters[i - 1]);
+            funasr::TimestampAdd(alignment_str2, characters_itn[j - 1]);
+            i -= 1;
+            j -= 1;
+        }
+    }
+
+    // TimestampAdd drops punctuation-only tokens while gap placeholders are
+    // pushed unconditionally, so the two sequences can end up with different
+    // lengths. The loop below indexes alignment_str2 with every index of
+    // alignment_str1; refuse to continue if that would read out of range.
+    if (alignment_str2.size() < alignment_str1.size()){
+        LOG(ERROR) << "Timestamp Smooth Failed: Alignment length does not matched.";
+        return timestamps_str;
+    }
+
+    // smooth
+    int itn_count = 0;      // ITN tokens waiting for a timestamp
+    size_t idx_tp = 0;      // next unread entry of `timestamps`
+    size_t idx_itn = 0;     // non-empty ITN tokens seen so far
+    vector<vector<int>> timestamps_tmp;  // timestamps of replaced source tokens
+
+    // Splits the span covered by timestamps_tmp evenly over the itn_count
+    // waiting ITN tokens (the last slice absorbs the integer-division
+    // remainder), then clears timestamps_tmp. With no waiting ITN token the
+    // collected timestamps are simply discarded.
+    auto spread_pending = [&]() {
+        if (timestamps_tmp.empty()){
+            return;
+        }
+        if (itn_count > 0){
+            const int begin = timestamps_tmp.front()[0];
+            const int end = timestamps_tmp.back()[1];
+            const int interval = (end - begin) / itn_count;
+            for(int idx_cnt = 0; idx_cnt < itn_count; idx_cnt++){
+                vector<int> ts;
+                ts.push_back(begin + interval*idx_cnt);
+                if(idx_cnt == itn_count-1){
+                    ts.push_back(end);
+                }else {
+                    ts.push_back(begin + interval*(idx_cnt + 1));
+                }
+                timestamps_out.push_back(ts);
+            }
+        }
+        timestamps_tmp.clear();
+    };
+
+    for(size_t index = 0; index < alignment_str1.size(); index++){
+        const std::string &src_token = alignment_str1[index];
+        const std::string &itn_token = alignment_str2[index];
+        if (src_token == itn_token){
+            // ITN tokens are waiting but no source timestamp was collected
+            // for them: let them share the current token's timestamp.
+            bool subsidy = false;
+            if (itn_count > 0 && timestamps_tmp.empty()){
+                if(idx_tp >= timestamps.size()){
+                    LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
+                    return timestamps_str;
+                }
+                timestamps_tmp.push_back(timestamps[idx_tp]);
+                subsidy = true;
+                itn_count++;
+            }
+
+            spread_pending();
+
+            if(!subsidy){
+                if(idx_tp >= timestamps.size()){
+                    LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
+                    return timestamps_str;
+                }
+                timestamps_out.push_back(timestamps[idx_tp]);
+            }
+            idx_tp++;
+            itn_count = 0;
+        }else{
+            if (!src_token.empty()){
+                if(idx_tp >= timestamps.size()){
+                    LOG(ERROR) << "Timestamp Smooth Failed: Index of tp is out of range. ";
+                    return timestamps_str;
+                }
+                timestamps_tmp.push_back(timestamps[idx_tp]);
+                idx_tp++;
+            }
+            if (!itn_token.empty()){
+                itn_count++;
+            }
+        }
+        // count length of itn
+        if (!itn_token.empty()){
+            idx_itn++;
+        }
+    }
+
+    // Trailing ITN tokens without any collected timestamp borrow the last
+    // emitted one and share it.
+    if (itn_count > 0 && timestamps_tmp.empty()){
+        if (timestamps_out.empty()){
+            LOG(ERROR) << "Timestamp Smooth Failed: Last itn has no timestamp.";
+            return timestamps_str;
+        }
+        timestamps_tmp.push_back(timestamps_out.back());
+        itn_count++;
+        timestamps_out.pop_back();
+    }
+    spread_pending();
+
+    if(timestamps_out.size() != idx_itn){
+        LOG(ERROR) << "Timestamp Smooth Failed: Timestamp length does not matched.";
+        return timestamps_str;
+    }
+
+    timestamps_str = VectorToString(timestamps_out);
+    return timestamps_str;
+}
+
std::vector<std::string> split(const std::string &s, char delim) {
std::vector<std::string> elems;
std::stringstream ss(s);
@@ -333,12 +643,23 @@
int sub_word = !(word.find("@@") == string::npos);
// process word start and middle part
if (sub_word) {
- combine += word.erase(word.length() - 2);
- if(!is_combining){
- begin = timestamp_list[i][0];
+ // if badcase: lo@@ chinese
+ if (i == raw_char.size()-1 || i<raw_char.size()-1 && IsChinese(raw_char[i+1])){
+ word = word.erase(word.length() - 2) + " ";
+ if (is_combining) {
+ combine += word;
+ is_combining = false;
+ word = combine;
+ combine = "";
+ }
+ }else{
+ combine += word.erase(word.length() - 2);
+ if(!is_combining){
+ begin = timestamp_list[i][0];
+ }
+ is_combining = true;
+ continue;
}
- is_combining = true;
- continue;
}
// process word end part
else if (is_combining) {
@@ -669,4 +990,9 @@
ifs_hws.close();
}
+// Placeholder with an empty body: none of the three arguments is read or
+// modified.
+void SmoothTimestamps(std::string &str_punc, std::string &str_itn, std::string &str_timetamp){
+}
+
} // namespace funasr
diff --git a/runtime/onnxruntime/src/util.h b/runtime/onnxruntime/src/util.h
index 3ccfa6b..46d24b3 100644
--- a/runtime/onnxruntime/src/util.h
+++ b/runtime/onnxruntime/src/util.h
@@ -3,11 +3,13 @@
#include <vector>
#include <memory>
#include <unordered_map>
+#include <deque>
#include "tensor.h"
using namespace std;
namespace funasr {
+typedef unsigned short U16CHAR_T;
extern float *LoadParams(const char *filename);
extern void SaveDataFile(const char *filename, void *data, uint32_t len);
@@ -35,6 +37,16 @@
std::vector<std::string> &chinese_characters);
void SplitChiEngCharacters(const std::string &input_str,
std::vector<std::string> &characters);
+void TimestampAdd(std::deque<string> &alignment_str1, std::string str_word);
+vector<vector<int>> ParseTimestamps(const std::string& str);
+bool TimestampIsDigit(U16CHAR_T &u16);
+bool TimestampIsAlpha(U16CHAR_T &u16);
+bool TimestampIsPunctuation(U16CHAR_T &u16);
+bool TimestampIsPunctuation(const std::string& str);
+void TimestampSplitChiEngCharacters(const std::string &input_str,
+ std::vector<std::string> &characters);
+std::string VectorToString(const std::vector<std::vector<int>>& vec);
+std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time);
std::vector<std::string> split(const std::string &s, char delim);
diff --git a/runtime/onnxruntime/src/vocab.cpp b/runtime/onnxruntime/src/vocab.cpp
index d29281c..20571c9 100644
--- a/runtime/onnxruntime/src/vocab.cpp
+++ b/runtime/onnxruntime/src/vocab.cpp
@@ -120,8 +120,8 @@
std::string combine = "";
std::string unicodeChar = "鈻�";
- for (auto it = in.begin(); it != in.end(); it++) {
- string word = vocab[*it];
+ for (i=0; i<in.size(); i++){
+ string word = vocab[in[i]];
// step1 space character skips
if (word == "<s>" || word == "</s>" || word == "<unk>")
continue;
@@ -146,9 +146,20 @@
int sub_word = !(word.find("@@") == string::npos);
// process word start and middle part
if (sub_word) {
- combine += word.erase(word.length() - 2);
- is_combining = true;
- continue;
+ // if badcase: lo@@ chinese
+ if (i == in.size()-1 || i<in.size()-1 && IsChinese(vocab[in[i+1]])){
+ word = word.erase(word.length() - 2) + " ";
+ if (is_combining) {
+ combine += word;
+ is_combining = false;
+ word = combine;
+ combine = "";
+ }
+ }else{
+ combine += word.erase(word.length() - 2);
+ is_combining = true;
+ continue;
+ }
}
// process word end part
else if (is_combining) {
--
Gitblit v1.9.1