From fbd9fbbde066a483fb903fe9c6c76fb95bc6fc2b Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 17 八月 2023 17:13:37 +0800
Subject: [PATCH] update timestamp

---
 funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp       |    4 
 funasr/runtime/onnxruntime/src/audio.cpp                     |   17 ++
 funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp   |    5 
 funasr/runtime/onnxruntime/src/funasrruntime.cpp             |   43 ++++++
 funasr/runtime/python/onnxruntime/demo_paraformer_offline.py |    1 
 funasr/runtime/onnxruntime/src/paraformer.h                  |    5 
 funasr/runtime/onnxruntime/src/commonfunc.h                  |    1 
 funasr/runtime/onnxruntime/src/util.h                        |    1 
 funasr/runtime/onnxruntime/include/audio.h                   |    1 
 funasr/runtime/onnxruntime/src/util.cpp                      |   10 +
 funasr/runtime/onnxruntime/src/paraformer.cpp                |  254 +++++++++++++++++++++++++++++++++++++++++-
 funasr/runtime/onnxruntime/include/funasrruntime.h           |    1 
 funasr/runtime/onnxruntime/src/vocab.h                       |    4 
 funasr/runtime/onnxruntime/src/vocab.cpp                     |    8 
 14 files changed, 335 insertions(+), 20 deletions(-)

diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index 85d6f03..83eb742 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -63,7 +63,10 @@
         if(result){
             string msg = FunASRGetResult(result, 0);
             LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg;
-
+            string stamp = FunASRGetStamp(result);
+            if(stamp !=""){
+                LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << stamp;
+            }
             float snippet_time = FunASRGetRetSnippetTime(result);
             n_total_length += snippet_time;
             FunASRFreeResult(result);
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
index caa8605..c43fd8d 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -128,6 +128,10 @@
         {
             string msg = FunASRGetResult(result, 0);
             LOG(INFO)<< wav_id <<" : "<<msg;
+            string stamp = FunASRGetStamp(result);
+            if(stamp !=""){
+                LOG(INFO)<< wav_id <<" : "<<stamp;
+            }
             snippet_time += FunASRGetRetSnippetTime(result);
             FunASRFreeResult(result);
         }
diff --git a/funasr/runtime/onnxruntime/include/audio.h b/funasr/runtime/onnxruntime/include/audio.h
index c8ca876..77a0021 100644
--- a/funasr/runtime/onnxruntime/include/audio.h
+++ b/funasr/runtime/onnxruntime/include/audio.h
@@ -69,6 +69,7 @@
     int FetchChunck(AudioFrame *&frame);
     int FetchTpass(AudioFrame *&frame);
     int Fetch(float *&dout, int &len, int &flag);
+    int Fetch(float *&dout, int &len, int &flag, float &start_time);
     void Padding();
     void Split(OfflineStream* offline_streamj);
     void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h
index c1059a6..30aada8 100644
--- a/funasr/runtime/onnxruntime/include/funasrruntime.h
+++ b/funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -68,6 +68,7 @@
 _FUNASRAPI FUNASR_RESULT	FunASRInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
 
 _FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT result,int n_index);
+_FUNASRAPI const char*	FunASRGetStamp(FUNASR_RESULT result);
 _FUNASRAPI const char*	FunASRGetTpassResult(FUNASR_RESULT result,int n_index);
 _FUNASRAPI const int	FunASRGetRetNumber(FUNASR_RESULT result);
 _FUNASRAPI void			FunASRFreeResult(FUNASR_RESULT result);
diff --git a/funasr/runtime/onnxruntime/src/audio.cpp b/funasr/runtime/onnxruntime/src/audio.cpp
index 2ba9c30..a882078 100644
--- a/funasr/runtime/onnxruntime/src/audio.cpp
+++ b/funasr/runtime/onnxruntime/src/audio.cpp
@@ -980,6 +980,23 @@
     }
 }
 
+int Audio::Fetch(float *&dout, int &len, int &flag, float &start_time)
+{
+    if (frame_queue.size() > 0) {
+        AudioFrame *frame = frame_queue.front();
+        frame_queue.pop();
+
+        start_time = (float)(frame->GetStart())/MODEL_SAMPLE_RATE;
+        dout = speech_data + frame->GetStart();
+        len = frame->GetLen();
+        delete frame;
+        flag = S_END;
+        return 1;
+    } else {
+        return 0;
+    }
+}
+
 void Audio::Padding()
 {
     float num_samples = speech_len;
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index 103e329..d7e5f13 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -5,6 +5,7 @@
 typedef struct
 {
     std::string msg="";
+    std::string stamp="";
     std::string tpass_msg="";
     float snippet_time=0;
 }FUNASR_RECOG_RESULT;
diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
index 991e516..a4753c5 100644
--- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -246,9 +246,22 @@
 
 		int n_step = 0;
 		int n_total = audio.GetQueueSize();
-		while (audio.Fetch(buff, len, flag) > 0) {
+		float start_time = 0.0;
+		while (audio.Fetch(buff, len, flag, start_time) > 0) {
 			string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
-			p_result->msg += msg;
+			std::vector<std::string> msg_vec = funasr::split(msg, '|');
+			p_result->msg += msg_vec[0];
+			//timestamp
+			if(msg_vec.size() > 1){
+				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
+				std::string cur_stamp = "";
+				for(int i=0; i<msg_stamp.size()-1; i+=2){
+					float begin = std::stof(msg_stamp[i])+start_time;
+					float end = std::stof(msg_stamp[i+1])+start_time;
+					cur_stamp += "["+std::to_string(begin)+","+std::to_string(end)+"],";
+				}
+				p_result->stamp += cur_stamp;
+			}
 			n_step++;
 			if (fn_callback)
 				fn_callback(n_step, n_total);
@@ -293,9 +306,22 @@
 		int flag = 0;
 		int n_step = 0;
 		int n_total = audio.GetQueueSize();
-		while (audio.Fetch(buff, len, flag) > 0) {
+		float start_time = 0.0;
+		while (audio.Fetch(buff, len, flag, start_time) > 0) {
 			string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
-			p_result->msg+= msg;
+			std::vector<std::string> msg_vec = funasr::split(msg, '|');
+			p_result->msg += msg_vec[0];
+			//timestamp
+			if(msg_vec.size() > 1){
+				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
+				std::string cur_stamp = "";
+				for(int i=0; i<msg_stamp.size()-1; i+=2){
+					float begin = std::stof(msg_stamp[i])+start_time;
+					float end = std::stof(msg_stamp[i+1])+start_time;
+					cur_stamp += "["+std::to_string(begin)+","+std::to_string(end)+"],";
+				}
+				p_result->stamp += cur_stamp;
+			}
 			n_step++;
 			if (fn_callback)
 				fn_callback(n_step, n_total);
@@ -431,6 +457,15 @@
 		return p_result->msg.c_str();
 	}
 
+	_FUNASRAPI const char* FunASRGetStamp(FUNASR_RESULT result)
+	{
+		funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
+		if(!p_result)
+			return nullptr;
+
+		return p_result->stamp.c_str();
+	}
+
 	_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
 	{
 		funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index ef2a182..e2c695c 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -46,10 +46,11 @@
     GetInputName(m_session_.get(), strName,1);
     m_strInputNames.push_back(strName);
     
-    GetOutputName(m_session_.get(), strName);
-    m_strOutputNames.push_back(strName);
-    GetOutputName(m_session_.get(), strName,1);
-    m_strOutputNames.push_back(strName);
+    size_t numOutputNodes = m_session_->GetOutputCount();
+    for(int index=0; index<numOutputNodes; index++){
+        GetOutputName(m_session_.get(), strName, index);
+        m_strOutputNames.push_back(strName);
+    }
 
     for (auto& item : m_strInputNames)
         m_szInputNames.push_back(item.c_str());
@@ -274,7 +275,7 @@
     }
 }
 
-string Paraformer::GreedySearch(float * in, int n_len,  int64_t token_nums)
+string Paraformer::GreedySearch(float * in, int n_len,  int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
 {
     vector<int> hyps;
     int Tmax = n_len;
@@ -284,8 +285,229 @@
         FindMax(in + i * token_nums, token_nums, max_val, max_idx);
         hyps.push_back(max_idx);
     }
+    if(!is_stamp){
+        return vocab->Vector2StringV2(hyps);
+    }else{
+        std::vector<string> char_list;
+        std::vector<std::vector<float>> timestamp_list;
+        std::string res_str;
+        vocab->Vector2String(hyps, char_list);
+        std::vector<string> raw_char(char_list);
+        TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
 
-    return vocab->Vector2StringV2(hyps);
+        return PostProcess(raw_char, timestamp_list);
+    }
+}
+
+string Paraformer::PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>> &timestamp_list){
+    std::vector<std::vector<float>> timestamp_merge;
+    int i;
+    list<string> words;
+    int is_pre_english = false;
+    int pre_english_len = 0;
+    int is_combining = false;
+    string combine = "";
+
+    float begin=-1;
+    for (i=0; i<raw_char.size(); i++){
+        string word = raw_char[i];
+        // step1 space character skips
+        if (word == "<s>" || word == "</s>" || word == "<unk>")
+            continue;
+        // step2 combie phoneme to full word
+        {
+            int sub_word = !(word.find("@@") == string::npos);
+            // process word start and middle part
+            if (sub_word) {
+                combine += word.erase(word.length() - 2);
+                if(!is_combining){
+                    begin = timestamp_list[i][0];
+                }
+                is_combining = true;
+                continue;
+            }
+            // process word end part
+            else if (is_combining) {
+                combine += word;
+                is_combining = false;
+                word = combine;
+                combine = "";
+            }
+        }
+
+        // step3 process english word deal with space , turn abbreviation to upper case
+        {
+            // input word is chinese, not need process 
+            if (vocab->IsChinese(word)) {
+                words.push_back(word);
+                timestamp_merge.emplace_back(timestamp_list[i]);
+                is_pre_english = false;
+            }
+            // input word is english word
+            else {
+                // pre word is chinese
+                if (!is_pre_english) {
+                    // word[0] = word[0] - 32;
+                    words.push_back(word);
+                    begin = (begin==-1)?timestamp_list[i][0]:begin;
+                    std::vector<float> vec = {begin, timestamp_list[i][1]};
+                    timestamp_merge.emplace_back(vec);
+                    begin = -1;
+                    pre_english_len = word.size();
+                }
+                // pre word is english word
+                else {
+                    // single letter turn to upper case
+                    // if (word.size() == 1) {
+                    //     word[0] = word[0] - 32;
+                    // }
+
+                    if (pre_english_len > 1) {
+                        words.push_back(" ");
+                        words.push_back(word);
+                        begin = (begin==-1)?timestamp_list[i][0]:begin;
+                        std::vector<float> vec = {begin, timestamp_list[i][1]};
+                        timestamp_merge.emplace_back(vec);
+                        begin = -1;
+                        pre_english_len = word.size();
+                    }
+                    else {
+                        // if (word.size() > 1) {
+                        //     words.push_back(" ");
+                        // }
+                        words.push_back(" ");
+                        words.push_back(word);
+                        begin = (begin==-1)?timestamp_list[i][0]:begin;
+                        std::vector<float> vec = {begin, timestamp_list[i][1]};
+                        timestamp_merge.emplace_back(vec);
+                        begin = -1;
+                        pre_english_len = word.size();
+                    }
+                }
+                is_pre_english = true;
+            }
+        }
+    }
+    string stamp_str="";
+    for (i=0; i<timestamp_list.size(); i++) {
+        stamp_str += std::to_string(timestamp_list[i][0]);
+        stamp_str += ", ";
+        stamp_str += std::to_string(timestamp_list[i][1]);
+        if(i!=timestamp_list.size()-1){
+            stamp_str += ",";
+        }
+    }
+
+    stringstream ss;
+    for (auto it = words.begin(); it != words.end(); it++) {
+        ss << *it;
+    }
+
+    return ss.str()+" | "+stamp_str;
+}
+
+void Paraformer::TimestampOnnx(std::vector<float>& us_alphas,
+                                std::vector<float> us_cif_peak, 
+                                std::vector<string>& char_list, 
+                                std::string &res_str, 
+                                std::vector<std::vector<float>> &timestamp_vec, 
+                                float begin_time, 
+                                float total_offset){
+    if (char_list.empty()) {
+        return ;
+    }
+
+    const float START_END_THRESHOLD = 5.0;
+    const float MAX_TOKEN_DURATION = 30.0;
+    const float TIME_RATE = 10.0 * 6 / 1000 / 3;
+    // 3 times upsampled, cif_peak is flattened into a 1D array
+    std::vector<float> cif_peak = us_cif_peak;
+    int num_frames = cif_peak.size();
+    if (char_list.back() == "</s>") {
+        char_list.pop_back();
+    }
+
+    vector<vector<float>> timestamp_list;
+    vector<string> new_char_list;
+    vector<float> fire_place;
+    // for bicif model trained with large data, cif2 actually fires when a character starts
+    // so treat the frames between two peaks as the duration of the former token
+    for (int i = 0; i < num_frames; i++) {
+        if (cif_peak[i] > 1.0 - 1e-4) {
+            fire_place.push_back(i + total_offset);
+        }
+    }
+    int num_peak = fire_place.size();
+    if(num_peak != (int)char_list.size() + 1){
+        float sum = std::accumulate(us_alphas.begin(), us_alphas.end(), 0.0f);
+        float scale = sum/((int)char_list.size() + 1);
+        cif_peak.clear();
+        sum = 0.0;
+        for(auto &alpha:us_alphas){
+            alpha = alpha/scale;
+            sum += alpha;
+            cif_peak.emplace_back(sum);
+            if(sum>=1.0 - 1e-4){
+                sum -=(1.0 - 1e-4);
+            }            
+        }
+
+        fire_place.clear();
+        for (int i = 0; i < num_frames; i++) {
+            if (cif_peak[i] > 1.0 - 1e-4) {
+                fire_place.push_back(i + total_offset);
+            }
+        }
+    }
+
+    // begin silence
+    if (fire_place[0] > START_END_THRESHOLD) {
+        new_char_list.push_back("<sil>");
+        timestamp_list.push_back({0.0, fire_place[0] * TIME_RATE});
+    }
+
+    // tokens timestamp
+    for (int i = 0; i < num_peak - 1; i++) {
+        new_char_list.push_back(char_list[i]);
+        if (i == num_peak - 2 || MAX_TOKEN_DURATION < 0 || fire_place[i + 1] - fire_place[i] < MAX_TOKEN_DURATION) {
+            timestamp_list.push_back({fire_place[i] * TIME_RATE, fire_place[i + 1] * TIME_RATE});
+        } else {
+            // cut the duration to token and sil of the 0-weight frames last long
+            float _split = fire_place[i] + MAX_TOKEN_DURATION;
+            timestamp_list.push_back({fire_place[i] * TIME_RATE, _split * TIME_RATE});
+            timestamp_list.push_back({_split * TIME_RATE, fire_place[i + 1] * TIME_RATE});
+            new_char_list.push_back("<sil>");
+        }
+    }
+
+    // tail token and end silence
+    if (num_frames - fire_place.back() > START_END_THRESHOLD) {
+        float _end = (num_frames + fire_place.back()) / 2.0;
+        timestamp_list.back()[1] = _end * TIME_RATE;
+        timestamp_list.push_back({_end * TIME_RATE, num_frames * TIME_RATE});
+        new_char_list.push_back("<sil>");
+    } else {
+        timestamp_list.back()[1] = num_frames * TIME_RATE;
+    }
+
+    if (begin_time) {  // add offset time in model with vad
+        for (auto& timestamp : timestamp_list) {
+            timestamp[0] += begin_time / 1000.0;
+            timestamp[1] += begin_time / 1000.0;
+        }
+    }
+
+    assert(new_char_list.size() == timestamp_list.size());
+
+    for (int i = 0; i < (int)new_char_list.size(); i++) {
+        res_str += new_char_list[i] + " " + to_string(timestamp_list[i][0]) + " " + to_string(timestamp_list[i][1]) + ";";
+    }
+
+    for (int i = 0; i < (int)new_char_list.size(); i++) {
+        if(new_char_list[i] != "<sil>"){
+            timestamp_vec.push_back(timestamp_list[i]);
+        }
+    }
 }
 
 vector<float> Paraformer::ApplyLfr(const std::vector<float> &in) 
@@ -369,7 +591,25 @@
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         float* floatData = outputTensor[0].GetTensorMutableData<float>();
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
-        result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
+        // timestamp
+        if(outputTensor.size() == 4){
+            std::vector<int64_t> us_alphas_shape = outputTensor[2].GetTensorTypeAndShapeInfo().GetShape();
+            float* us_alphas_data = outputTensor[2].GetTensorMutableData<float>();
+            std::vector<float> us_alphas(us_alphas_shape[1]);
+            for (int i = 0; i < us_alphas_shape[1]; i++) {
+                us_alphas[i] = us_alphas_data[i];
+            }
+
+            std::vector<int64_t> us_peaks_shape = outputTensor[3].GetTensorTypeAndShapeInfo().GetShape();
+            float* us_peaks_data = outputTensor[3].GetTensorMutableData<float>();
+            std::vector<float> us_peaks(us_peaks_shape[1]);
+            for (int i = 0; i < us_peaks_shape[1]; i++) {
+                us_peaks[i] = us_peaks_data[i];
+            }
+            result = GreedySearch(floatData, *encoder_out_lens, outputShape[2], true, us_alphas, us_peaks);
+        }else{
+            result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
+        }
     }
     catch (std::exception const &e)
     {
diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h
index 16460bf..0dd55b5 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.h
+++ b/funasr/runtime/onnxruntime/src/paraformer.h
@@ -35,7 +35,10 @@
         void Reset();
         vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
         string Forward(float* din, int len, bool input_finished=true);
-        string GreedySearch( float* in, int n_len, int64_t token_nums);
+        string GreedySearch( float* in, int n_len, int64_t token_nums, bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
+        void TimestampOnnx(std::vector<float> &us_alphas, vector<float> us_cif_peak, vector<string>& char_list, std::string &res_str, 
+                           vector<vector<float>> &timestamp_list, float begin_time = 0.0, float total_offset = -1.5);
+        string PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>> &timestamp_list);
         string Rescoring();
 
         knf::FbankOptions fbank_opts_;
diff --git a/funasr/runtime/onnxruntime/src/util.cpp b/funasr/runtime/onnxruntime/src/util.cpp
index 755913c..e09caee 100644
--- a/funasr/runtime/onnxruntime/src/util.cpp
+++ b/funasr/runtime/onnxruntime/src/util.cpp
@@ -189,4 +189,14 @@
     return (extension == target);
 }
 
+std::vector<std::string> split(const std::string &s, char delim) {
+  std::vector<std::string> elems;
+  std::stringstream ss(s);
+  std::string item;
+  while(std::getline(ss, item, delim)) {
+    elems.push_back(item);
+  }
+  return elems;
+}
+
 } // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/util.h b/funasr/runtime/onnxruntime/src/util.h
index 8823a32..cd3c3c9 100644
--- a/funasr/runtime/onnxruntime/src/util.h
+++ b/funasr/runtime/onnxruntime/src/util.h
@@ -27,5 +27,6 @@
 string PathAppend(const string &p1, const string &p2);
 bool is_target_file(const std::string& filename, const std::string target);
 
+std::vector<std::string> split(const std::string &s, char delim);
 } // namespace funasr
 #endif
diff --git a/funasr/runtime/onnxruntime/src/vocab.cpp b/funasr/runtime/onnxruntime/src/vocab.cpp
index 70553df..dc40978 100644
--- a/funasr/runtime/onnxruntime/src/vocab.cpp
+++ b/funasr/runtime/onnxruntime/src/vocab.cpp
@@ -34,14 +34,12 @@
     }
 }
 
-string Vocab::Vector2String(vector<int> in)
+void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
 {
-    int i;
-    stringstream ss;
     for (auto it = in.begin(); it != in.end(); it++) {
-        ss << vocab[*it];
+        string word = vocab[*it];
+        preds.emplace_back(word);
     }
-    return ss.str();
 }
 
 int Str2Int(string str)
diff --git a/funasr/runtime/onnxruntime/src/vocab.h b/funasr/runtime/onnxruntime/src/vocab.h
index 6c4e523..9b462b7 100644
--- a/funasr/runtime/onnxruntime/src/vocab.h
+++ b/funasr/runtime/onnxruntime/src/vocab.h
@@ -11,7 +11,6 @@
 class Vocab {
   private:
     vector<string> vocab;
-    bool IsChinese(string ch);
     bool IsEnglish(string ch);
     void LoadVocabFromYaml(const char* filename);
 
@@ -19,7 +18,8 @@
     Vocab(const char *filename);
     ~Vocab();
     int Size();
-    string Vector2String(vector<int> in);
+    bool IsChinese(string ch);
+    void Vector2String(vector<int> in, std::vector<std::string> &preds);
     string Vector2StringV2(vector<int> in);
 };
 
diff --git a/funasr/runtime/python/onnxruntime/demo_paraformer_offline.py b/funasr/runtime/python/onnxruntime/demo_paraformer_offline.py
index 33229a3..bc8355b 100644
--- a/funasr/runtime/python/onnxruntime/demo_paraformer_offline.py
+++ b/funasr/runtime/python/onnxruntime/demo_paraformer_offline.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 
 model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model_dir = "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model = Paraformer(model_dir, batch_size=1, quantize=True)
 # model = Paraformer(model_dir, batch_size=1, device_id=0)  # gpu
 

--
Gitblit v1.9.1