From 752fd0b8a98c18cb30f5199eec6158eb9a37f869 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 21 三月 2023 20:05:26 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/onnxruntime/include/Model.h          |    2 
 funasr/runtime/grpc/paraformer_server.cc            |   17 +-
 funasr/runtime/onnxruntime/readme.md                |    4 
 funasr/runtime/grpc/Readme.md                       |    4 
 funasr/runtime/onnxruntime/tester/CMakeLists.txt    |    3 
 funasr/runtime/onnxruntime/include/librapidasrapi.h |   29 ----
 funasr/datasets/large_datasets/build_dataloader.py  |   44 ++++++
 funasr/runtime/onnxruntime/src/paraformer_onnx.cpp  |   16 ++
 funasr/runtime/onnxruntime/tester/tester.cpp        |   20 +--
 funasr/utils/postprocess_utils.py                   |    7 
 funasr/datasets/large_datasets/dataset.py           |    3 
 funasr/runtime/onnxruntime/src/Model.cpp            |    5 
 funasr/runtime/onnxruntime/tester/tester_rtf.cpp    |   99 ++++++++++++++++
 funasr/runtime/onnxruntime/src/librapidasrapi.cpp   |   35 -----
 funasr/utils/timestamp_tools.py                     |   44 ++++--
 funasr/datasets/large_datasets/utils/tokenize.py    |    6 
 funasr/runtime/onnxruntime/src/paraformer_onnx.h    |    7 -
 funasr/bin/asr_inference_paraformer_vad_punc.py     |    2 
 funasr/runtime/grpc/paraformer_server.h             |    2 
 19 files changed, 229 insertions(+), 120 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 3f57751..ab3e1e3 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -292,6 +292,8 @@
 
                 # remove blank symbol id, which is assumed to be 0
                 token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+                if len(token_int) == 0:
+                    continue
 
                 # Change integer-ids to tokens
                 token = self.converter.ids2tokens(token_int)
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 0ad1889..156f608 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -1,10 +1,16 @@
 import logging
+from pathlib import Path
+from typing import Iterable
+from typing import List
+from typing import Union
 
-import yaml
-
+import sentencepiece as spm
 from torch.utils.data import DataLoader
+from typeguard import check_argument_types
+
 from funasr.datasets.large_datasets.dataset import Dataset
 from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.text.abs_tokenizer import AbsTokenizer
 
 
 def read_symbol_table(symbol_table_file):
@@ -21,6 +27,7 @@
             symbol_table[char] = i
     return symbol_table
 
+
 def load_seg_dict(seg_dict_file):
     seg_dict = {}
     assert isinstance(seg_dict_file, str)
@@ -33,8 +40,33 @@
             seg_dict[key] = " ".join(value)
     return seg_dict
 
+
+class SentencepiecesTokenizer(AbsTokenizer):
+    def __init__(self, model: Union[Path, str]):
+        assert check_argument_types()
+        self.model = str(model)
+        self.sp = None
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(model="{self.model}")'
+
+    def _build_sentence_piece_processor(self):
+        if self.sp is None:
+            self.sp = spm.SentencePieceProcessor()
+            self.sp.load(self.model)
+
+    def text2tokens(self, line: str) -> List[str]:
+        self._build_sentence_piece_processor()
+        return self.sp.EncodeAsPieces(line)
+
+    def tokens2text(self, tokens: Iterable[str]) -> str:
+        self._build_sentence_piece_processor()
+        return self.sp.DecodePieces(list(tokens))
+
+
 class ArkDataLoader(AbsIterFactory):
-    def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None, mode="train"):
+    def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
+                 bpemodel_file=None, mode="train"):
         symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
         if seg_dict_file is not None:
             seg_dict = load_seg_dict(seg_dict_file)
@@ -48,7 +80,11 @@
         self.frontend_conf = frontend_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
         batch_mode = self.dataset_conf.get("batch_mode", "padding")
-        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict,
+        if bpemodel_file is not None:
+            bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
+        else:
+            bpe_tokenizer = None
+        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
                                self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)
 
     def build_iter(self, epoch, shuffle=True):
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 1942371..b0e1b8f 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -158,6 +158,7 @@
             dict,
             seg_dict,
             punc_dict,
+            bpe_tokenizer,
             conf,
             frontend_conf,
             mode="train",
@@ -173,7 +174,7 @@
     dataset = FilterIterDataPipe(dataset, fn=filter_fn)
 
     if "text" in data_names:
-        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict}
+        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer}
         tokenize_fn = partial(tokenize, **vocab)
         dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
 
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index a016e4e..3f20c5f 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -28,13 +28,17 @@
 def tokenize(data,
              vocab=None,
              seg_dict=None,
-             punc_dict=None):
+             punc_dict=None,
+             bpe_tokenizer=None):
     assert "text" in data
     assert isinstance(vocab, dict)
     text = data["text"]
     token = []
     vad = -2
 
+    if bpe_tokenizer is not None:
+        text = bpe_tokenizer.text2tokens(text)
+
     if seg_dict is not None:
         assert isinstance(seg_dict, dict)
         txt = forward_segment("".join(text).lower(), seg_dict)
diff --git a/funasr/runtime/grpc/Readme.md b/funasr/runtime/grpc/Readme.md
index 80e55aa..2bcad08 100644
--- a/funasr/runtime/grpc/Readme.md
+++ b/funasr/runtime/grpc/Readme.md
@@ -44,8 +44,8 @@
 
 #### Step 4. Start grpc paraformer server
 ```
-Usage: ./cmake/build/paraformer_server port thread_num /path/to/model_file
-./cmake/build/paraformer_server 10108 4 /data/asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+Usage: ./cmake/build/paraformer_server port thread_num /path/to/model_file quantize(true or false)
+./cmake/build/paraformer_server 10108 4 /data/asrmodel/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch false
 ```
 
 
diff --git a/funasr/runtime/grpc/paraformer_server.cc b/funasr/runtime/grpc/paraformer_server.cc
index e5814a5..69ce903 100644
--- a/funasr/runtime/grpc/paraformer_server.cc
+++ b/funasr/runtime/grpc/paraformer_server.cc
@@ -29,8 +29,8 @@
 using paraformer::Response;
 using paraformer::ASR;
 
-ASRServicer::ASRServicer(const char* model_path, int thread_num) {
-    AsrHanlde=RapidAsrInit(model_path, thread_num);
+ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
+    AsrHanlde=RapidAsrInit(model_path, thread_num, quantize);
     std::cout << "ASRServicer init" << std::endl;
     init_flag = 0;
 }
@@ -170,10 +170,10 @@
 }
 
 
-void RunServer(const std::string& port, int thread_num, const char* model_path) {
+void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
     std::string server_address;
     server_address = "0.0.0.0:" + port;
-    ASRServicer service(model_path, thread_num);
+    ASRServicer service(model_path, thread_num, quantize);
 
     ServerBuilder builder;
     builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
@@ -184,12 +184,15 @@
 }
 
 int main(int argc, char* argv[]) {
-    if (argc < 3)
+    if (argc < 5)
     {
-        printf("Usage: %s port thread_num /path/to/model_file\n", argv[0]);
+        printf("Usage: %s port thread_num /path/to/model_file quantize(true or false) \n", argv[0]);
         exit(-1);
     }
 
-    RunServer(argv[1], atoi(argv[2]), argv[3]);
+    // is quantize
+    bool quantize = false;
+    std::istringstream(argv[4]) >> std::boolalpha >> quantize;
+    RunServer(argv[1], atoi(argv[2]), argv[3], quantize);
     return 0;
 }
diff --git a/funasr/runtime/grpc/paraformer_server.h b/funasr/runtime/grpc/paraformer_server.h
index f356d94..e42e041 100644
--- a/funasr/runtime/grpc/paraformer_server.h
+++ b/funasr/runtime/grpc/paraformer_server.h
@@ -45,7 +45,7 @@
     std::unordered_map<std::string, std::string> client_transcription;
 
   public:
-    ASRServicer(const char* model_path, int thread_num);
+    ASRServicer(const char* model_path, int thread_num, bool quantize);
     void clear_states(const std::string& user);
     void clear_buffers(const std::string& user);
     void clear_transcriptions(const std::string& user);
diff --git a/funasr/runtime/onnxruntime/include/Model.h b/funasr/runtime/onnxruntime/include/Model.h
index 06267cb..6f45c38 100644
--- a/funasr/runtime/onnxruntime/include/Model.h
+++ b/funasr/runtime/onnxruntime/include/Model.h
@@ -13,5 +13,5 @@
     virtual std::string rescoring() = 0;
 };
 
-Model *create_model(const char *path,int nThread=0);
+Model *create_model(const char *path,int nThread=0,bool quantize=false);
 #endif
diff --git a/funasr/runtime/onnxruntime/include/librapidasrapi.h b/funasr/runtime/onnxruntime/include/librapidasrapi.h
index a83098f..918e574 100644
--- a/funasr/runtime/onnxruntime/include/librapidasrapi.h
+++ b/funasr/runtime/onnxruntime/include/librapidasrapi.h
@@ -1,33 +1,20 @@
 #pragma once
 
-
 #ifdef WIN32
-
-
 #ifdef _RPASR_API_EXPORT
-
 #define  _RAPIDASRAPI __declspec(dllexport)
 #else
 #define  _RAPIDASRAPI __declspec(dllimport)
 #endif
-	
-
 #else
-#define _RAPIDASRAPI  
+#define _RAPIDASRAPI
 #endif
 
-
-
-
-
 #ifndef _WIN32
-
 #define RPASR_CALLBCK_PREFIX __attribute__((__stdcall__))
-
 #else
 #define RPASR_CALLBCK_PREFIX __stdcall
 #endif
-	
 
 #ifdef __cplusplus 
 
@@ -35,15 +22,12 @@
 #endif
 
 typedef void* RPASR_HANDLE;
-
 typedef void* RPASR_RESULT;
-
 typedef unsigned char RPASR_BOOL;
 
 #define RPASR_TRUE 1
 #define RPASR_FALSE 0
 #define QM_DEFAULT_THREAD_NUM  4
-
 
 typedef enum
 {
@@ -55,7 +39,6 @@
 }RPASR_MODE;
 
 typedef enum {
-
 	RPASR_MODEL_PADDLE = 0,
 	RPASR_MODEL_PADDLE_2 = 1,
 	RPASR_MODEL_K2 = 2,
@@ -63,17 +46,15 @@
 
 }RPASR_MODEL_TYPE;
 
-
 typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
 	
-	// APIs for qmasr
-
-_RAPIDASRAPI RPASR_HANDLE  RapidAsrInit(const char* szModelDir, int nThread);
-
+// APIs for qmasr
+_RAPIDASRAPI RPASR_HANDLE  RapidAsrInit(const char* szModelDir, int nThread, bool quantize);
 
 
 // if not give a fnCallback ,it should be NULL 
 _RAPIDASRAPI RPASR_RESULT	RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
+
 _RAPIDASRAPI RPASR_RESULT	RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
 
 _RAPIDASRAPI RPASR_RESULT	RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback);
@@ -83,8 +64,8 @@
 _RAPIDASRAPI const char*	RapidAsrGetResult(RPASR_RESULT Result,int nIndex);
 
 _RAPIDASRAPI const int		RapidAsrGetRetNumber(RPASR_RESULT Result);
-_RAPIDASRAPI void			RapidAsrFreeResult(RPASR_RESULT Result);
 
+_RAPIDASRAPI void			RapidAsrFreeResult(RPASR_RESULT Result);
 
 _RAPIDASRAPI void			RapidAsrUninit(RPASR_HANDLE Handle);
 
diff --git a/funasr/runtime/onnxruntime/readme.md b/funasr/runtime/onnxruntime/readme.md
index 41c63c6..16d9dc7 100644
--- a/funasr/runtime/onnxruntime/readme.md
+++ b/funasr/runtime/onnxruntime/readme.md
@@ -16,9 +16,9 @@
 
 ###  杩愯绋嬪簭
 
-tester  /path/to/models/dir /path/to/wave/file
+tester  /path/to/models/dir /path/to/wave/file quantize(true or false)
 
- 渚嬪锛� tester /data/models  /data/test.wav
+ 渚嬪锛� tester /data/models  /data/test.wav false
 
 /data/models 闇�瑕佸寘鎷涓嬩袱涓枃浠讹細 model.onnx 鍜寁ocab.txt
 
diff --git a/funasr/runtime/onnxruntime/src/Model.cpp b/funasr/runtime/onnxruntime/src/Model.cpp
index ddd4fd0..7ddb635 100644
--- a/funasr/runtime/onnxruntime/src/Model.cpp
+++ b/funasr/runtime/onnxruntime/src/Model.cpp
@@ -1,11 +1,10 @@
 #include "precomp.h"
 
-Model *create_model(const char *path,int nThread)
+Model *create_model(const char *path, int nThread, bool quantize)
 {
     Model *mm;
 
-
-    mm = new paraformer::ModelImp(path, nThread);
+    mm = new paraformer::ModelImp(path, nThread, quantize);
 
     return mm;
 }
diff --git a/funasr/runtime/onnxruntime/src/librapidasrapi.cpp b/funasr/runtime/onnxruntime/src/librapidasrapi.cpp
index f5f9d66..62f47a5 100644
--- a/funasr/runtime/onnxruntime/src/librapidasrapi.cpp
+++ b/funasr/runtime/onnxruntime/src/librapidasrapi.cpp
@@ -4,24 +4,16 @@
 extern "C" {
 #endif
 
-
 	// APIs for qmasr
-	_RAPIDASRAPI RPASR_HANDLE  RapidAsrInit(const char* szModelDir, int nThreadNum)
+	_RAPIDASRAPI RPASR_HANDLE  RapidAsrInit(const char* szModelDir, int nThreadNum, bool quantize)
 	{
-
-
-		Model* mm = create_model(szModelDir, nThreadNum); 
-
+		Model* mm = create_model(szModelDir, nThreadNum, quantize);
 		return mm;
 	}
 
-
 	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
-
-
 		Model* pRecogObj = (Model*)handle;
-
 		if (!pRecogObj)
 			return nullptr;
 
@@ -46,15 +38,12 @@
 				fnCallback(nStep, nTotal);
 		}
 
-
 		return pResult;
 	}
 
 	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
-
 		Model* pRecogObj = (Model*)handle;
-
 		if (!pRecogObj)
 			return nullptr;
 
@@ -79,16 +68,12 @@
 				fnCallback(nStep, nTotal);
 		}
 
-
 		return pResult;
-
 	}
 
 	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
-
 		Model* pRecogObj = (Model*)handle;
-
 		if (!pRecogObj)
 			return nullptr;
 
@@ -113,15 +98,12 @@
 				fnCallback(nStep, nTotal);
 		}
 
-
 		return pResult;
-
 	}
 
 	_RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback)
 	{
 		Model* pRecogObj = (Model*)handle;
-
 		if (!pRecogObj)
 			return nullptr;
 
@@ -146,9 +128,6 @@
 				fnCallback(nStep, nTotal);
 		}
 	
-	
-
-
 		return pResult;
 	}
 
@@ -158,7 +137,6 @@
 			return 0;
 
 		return 1;
-		
 	}
 
 
@@ -168,7 +146,6 @@
 			return 0.0f;
 
 		return ((RPASR_RECOG_RESULT*)Result)->snippet_time;
-
 	}
 
 	_RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex)
@@ -178,33 +155,25 @@
 			return nullptr;
 
 		return pResult->msg.c_str();
-	
 	}
 
 	_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result)
 	{
-
 		if (Result)
 		{
 			delete (RPASR_RECOG_RESULT*)Result;
-
 		}
 	}
 
 	_RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE handle)
 	{
-
 		Model* pRecogObj = (Model*)handle;
-
 
 		if (!pRecogObj)
 			return;
 
 		delete pRecogObj;
-
 	}
-
-
 
 #ifdef __cplusplus 
 
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 46b5211..8eb0e89 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -3,14 +3,22 @@
 using namespace std;
 using namespace paraformer;
 
-ModelImp::ModelImp(const char* path,int nNumThread)
+ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
 {
-    string model_path = pathAppend(path, "model.onnx");
-    string vocab_path = pathAppend(path, "vocab.txt");
+    string model_path;
+    string vocab_path;
+    if(quantize)
+    {
+        model_path = pathAppend(path, "model_quant.onnx");
+    }else{
+        model_path = pathAppend(path, "model.onnx");
+    }
+    vocab_path = pathAppend(path, "vocab.txt");
 
     fe = new FeatureExtract(3);
 
-    sessionOptions.SetInterOpNumThreads(nNumThread);
+    //sessionOptions.SetInterOpNumThreads(1);
+    sessionOptions.SetIntraOpNumThreads(nNumThread);
     sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
 
 #ifdef _WIN32
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.h b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
index ebbbb51..db00842 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.h
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -4,10 +4,6 @@
 #ifndef PARAFORMER_MODELIMP_H
 #define PARAFORMER_MODELIMP_H
 
-
-
-
-
 namespace paraformer {
 
     class ModelImp : public Model {
@@ -19,7 +15,6 @@
         void apply_lfr(Tensor<float>*& din);
         void apply_cmvn(Tensor<float>* din);
 
-        
         string greedy_search( float* in, int nLen);
 
 #ifdef _WIN_X86
@@ -39,7 +34,7 @@
         //string m_strOutputName, m_strOutputNameLen;
 
     public:
-        ModelImp(const char* path, int nNumThread=0);
+        ModelImp(const char* path, int nNumThread=0, bool quantize=false);
         ~ModelImp();
         void reset();
         string forward_chunk(float* din, int len, int flag);
diff --git a/funasr/runtime/onnxruntime/tester/CMakeLists.txt b/funasr/runtime/onnxruntime/tester/CMakeLists.txt
index d794271..f66319d 100644
--- a/funasr/runtime/onnxruntime/tester/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/tester/CMakeLists.txt
@@ -13,8 +13,11 @@
 
 include_directories(${CMAKE_SOURCE_DIR}/include)
 set(EXECNAME "tester")
+set(EXECNAMERTF "tester_rtf")
 
 add_executable(${EXECNAME} "tester.cpp")
 target_link_libraries(${EXECNAME} PUBLIC ${EXTRA_LIBS})
 
+add_executable(${EXECNAMERTF} "tester_rtf.cpp")
+target_link_libraries(${EXECNAMERTF} PUBLIC ${EXTRA_LIBS})
 
diff --git a/funasr/runtime/onnxruntime/tester/tester.cpp b/funasr/runtime/onnxruntime/tester/tester.cpp
index ba5c61c..2bba39a 100644
--- a/funasr/runtime/onnxruntime/tester/tester.cpp
+++ b/funasr/runtime/onnxruntime/tester/tester.cpp
@@ -9,41 +9,40 @@
 
 #include <iostream>
 #include <fstream>
+#include <sstream>
 using namespace std;
 
 int main(int argc, char *argv[])
 {
 
-    if (argc < 2)
+    if (argc < 4)
     {
-        printf("Usage: %s /path/to/model_dir /path/to/wav/file", argv[0]);
+        printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) \n", argv[0]);
         exit(-1);
     }
     struct timeval start, end;
     gettimeofday(&start, NULL);
     int nThreadNum = 4;
-    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum);
+    // is quantize
+    bool quantize = false;
+    istringstream(argv[3]) >> boolalpha >> quantize;
+    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
 
     if (!AsrHanlde)
     {
         printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
         exit(-1);
     }
-    
- 
 
     gettimeofday(&end, NULL);
     long seconds = (end.tv_sec - start.tv_sec);
     long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
     printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
 
-
-
     gettimeofday(&start, NULL);
     float snippet_time = 0.0f;
 
-
-     RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
+    RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
 
     gettimeofday(&end, NULL);
    
@@ -61,7 +60,6 @@
     {
         cout <<"no return data!";
     }
- 
  
     //char* buff = nullptr;
     //int len = 0;
@@ -101,13 +99,11 @@
     //   
     //delete[]buff;
     //}
-
  
     printf("Audio length %lfs.\n", (double)snippet_time);
     seconds = (end.tv_sec - start.tv_sec);
     long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
     printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
-
     printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));
 
     RapidAsrUninit(AsrHanlde);
diff --git a/funasr/runtime/onnxruntime/tester/tester_rtf.cpp b/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
new file mode 100644
index 0000000..9651900
--- /dev/null
+++ b/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
@@ -0,0 +1,99 @@
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include "librapidasrapi.h"
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <vector>
+using namespace std;
+
+int main(int argc, char *argv[])
+{
+
+    if (argc < 4)
+    {
+        printf("Usage: %s /path/to/model_dir /path/to/wav.scp quantize(true or false) \n", argv[0]);
+        exit(-1);
+    }
+
+    // read wav.scp
+    vector<string> wav_list;
+    ifstream in(argv[2]);
+    if (!in.is_open()) {
+        printf("Failed to open file: %s", argv[2]);
+        return 0;
+    }
+    string line;
+    while(getline(in, line))
+    {
+        istringstream iss(line);
+        string column1, column2;
+        iss >> column1 >> column2;
+        wav_list.push_back(column2); 
+    }
+    in.close();
+
+    // model init
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    int nThreadNum = 1;
+    // is quantize
+    bool quantize = false;
+    istringstream(argv[3]) >> boolalpha >> quantize;
+
+    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
+    if (!AsrHanlde)
+    {
+        printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
+        exit(-1);
+    }
+    gettimeofday(&end, NULL);
+    long seconds = (end.tv_sec - start.tv_sec);
+    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+    printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
+
+    // warm up
+    for (size_t i = 0; i < 30; i++)
+    {
+        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
+    }
+
+    // forward
+    float snippet_time = 0.0f;
+    float total_length = 0.0f;
+    long total_time = 0.0f;
+    
+    for (size_t i = 0; i < wav_list.size(); i++)
+    {
+        gettimeofday(&start, NULL);
+        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
+        gettimeofday(&end, NULL);
+        seconds = (end.tv_sec - start.tv_sec);
+        long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+        total_time += taking_micros;
+
+        if(Result){
+            string msg = RapidAsrGetResult(Result, 0);
+            printf("Result: %s \n", msg);
+
+            snippet_time = RapidAsrGetRetSnippetTime(Result);
+            total_length += snippet_time;
+            RapidAsrFreeResult(Result);
+        }else{
+            cout <<"No return data!";
+        }
+    }
+
+    printf("total_time_wav %ld ms.\n", (long)(total_length * 1000));
+    printf("total_time_comput %ld ms.\n", total_time / 1000);
+    printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
+
+    RapidAsrUninit(AsrHanlde);
+    return 0;
+}
diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py
index 2475548..40756d8 100644
--- a/funasr/utils/postprocess_utils.py
+++ b/funasr/utils/postprocess_utils.py
@@ -106,17 +106,18 @@
         if num in abbr_begin:
             if time_stamp is not None:
                 begin = time_stamp[ts_nums[num]][0]
-            word_lists.append(words[num].upper())
+            abbr_word = words[num].upper()
             num += 1
             while num < words_size:
                 if num in abbr_end:
-                    word_lists.append(words[num].upper())
+                    abbr_word += words[num].upper()
                     last_num = num
                     break
                 else:
                     if words[num].encode('utf-8').isalpha():
-                        word_lists.append(words[num].upper())
+                        abbr_word += words[num].upper()
                 num += 1
+            word_lists.append(abbr_word)
             if time_stamp is not None:
                 end = time_stamp[ts_nums[num]][1]
                 ts_lists.append([begin, end])
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 423110c..87cc49e 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -1,3 +1,5 @@
+from itertools import zip_longest
+
 import torch
 import copy
 import codecs
@@ -17,7 +19,7 @@
                        sil_in_str=True
                        ):
     if not len(char_list):
-        return []
+        return "", []
     START_END_THRESHOLD = 5
     MAX_TOKEN_DURATION = 12
     TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
@@ -87,6 +89,7 @@
         return res
     if len(text_postprocessed) == 0:
         return res
+
     if punc_id_list is None or len(punc_id_list) == 0:
         res.append({
             'text': text_postprocessed.split(),
@@ -95,36 +98,45 @@
         })
         return res
     if len(punc_id_list) != len(time_stamp_postprocessed):
-        res.append({
-            'text': text_postprocessed.split(),
-            "start": time_stamp_postprocessed[0][0],
-            "end": time_stamp_postprocessed[-1][1]
-        })
-        return res
-
+        print("  warning length mistach!!!!!!")
     sentence_text = ''
     sentence_start = time_stamp_postprocessed[0][0]
+    sentence_end = time_stamp_postprocessed[0][1]
     texts = text_postprocessed.split()
-    for i in range(len(punc_id_list)):
-        sentence_text += texts[i]
-        if punc_id_list[i] == 2:
+    punc_stamp_text_list = list(zip_longest(punc_id_list, time_stamp_postprocessed, texts, fillvalue=None))
+    for punc_stamp_text in punc_stamp_text_list:
+        punc_id, time_stamp, text = punc_stamp_text
+        sentence_text += text if text is not None else ''
+        punc_id = int(punc_id) if punc_id is not None else 1
+        sentence_end = time_stamp[1] if time_stamp is not None else sentence_end
+
+        if punc_id == 2:
             sentence_text += ','
             res.append({
                 'text': sentence_text,
                 "start": sentence_start,
-                "end": time_stamp_postprocessed[i][1]
+                "end": sentence_end
             })
             sentence_text = ''
-            sentence_start = time_stamp_postprocessed[i][1]
-        elif punc_id_list[i] == 3:
+            sentence_start = sentence_end
+        elif punc_id == 3:
             sentence_text += '.'
             res.append({
                 'text': sentence_text,
                 "start": sentence_start,
-                "end": time_stamp_postprocessed[i][1]
+                "end": sentence_end
             })
             sentence_text = ''
-            sentence_start = time_stamp_postprocessed[i][1]
+            sentence_start = sentence_end
+        elif punc_id == 4:
+            sentence_text += '?'
+            res.append({
+                'text': sentence_text,
+                "start": sentence_start,
+                "end": sentence_end
+            })
+            sentence_text = ''
+            sentence_start = sentence_end
     return res
 
 

--
Gitblit v1.9.1