From 0535db1c65180cfb4da046c5d865c764e6445746 Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: Mon, 24 Apr 2023 10:57:53 +0800
Subject: [PATCH] rename CT-transformer

---
 funasr/runtime/onnxruntime/include/Model.h             |    2 
 funasr/runtime/onnxruntime/src/Model.cpp               |    4 +-
 funasr/runtime/onnxruntime/readme.md                   |    4 +-
 funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp |   10 +++--
 funasr/runtime/onnxruntime/src/libfunasrapi.cpp        |   26 +++++++++---
 funasr/runtime/onnxruntime/src/CT-transformer.h        |    1 
 funasr/runtime/onnxruntime/src/CT-transformer.cpp      |    9 +++-
 funasr/runtime/onnxruntime/src/paraformer_onnx.h       |    3 +
 funasr/runtime/onnxruntime/src/paraformer_onnx.cpp     |    5 +-
 funasr/runtime/onnxruntime/src/precomp.h               |    2 
 funasr/runtime/onnxruntime/include/libfunasrapi.h      |   10 ++--
 11 files changed, 48 insertions(+), 28 deletions(-)

diff --git a/funasr/runtime/onnxruntime/include/Model.h b/funasr/runtime/onnxruntime/include/Model.h
index f92789f..2d7873f 100644
--- a/funasr/runtime/onnxruntime/include/Model.h
+++ b/funasr/runtime/onnxruntime/include/Model.h
@@ -15,5 +15,5 @@
     virtual std::string AddPunc(const char* szInput)=0;
 };
 
-Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false);
+Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false, bool use_punc=false);
 #endif
diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h
index 8d8ebd2..a967ad2 100644
--- a/funasr/runtime/onnxruntime/include/libfunasrapi.h
+++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -49,17 +49,17 @@
 typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
 	
 // APIs for funasr
-_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false);
+_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false, bool use_punc=false);
 
 
 // if not give a fnCallback ,it should be NULL 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
 
-_FUNASRAPI FUNASR_RESULT	FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT	FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
 
 _FUNASRAPI const char*	FunASRGetResult(FUNASR_RESULT Result,int nIndex);
 
diff --git a/funasr/runtime/onnxruntime/readme.md b/funasr/runtime/onnxruntime/readme.md
index 81030ba..6886d58 100644
--- a/funasr/runtime/onnxruntime/readme.md
+++ b/funasr/runtime/onnxruntime/readme.md
@@ -59,12 +59,12 @@
 ## Run the demo
 
 ```shell
-funasr-onnx-offline /path/models_dir /path/wave_file quantize(true or false) use_vad(true or false)
+funasr-onnx-offline /path/models_dir /path/wave_file quantize(true or false) use_vad(true or false) use_punc(true or false)
 ```
 
 The structure of /path/models_dir
 ```
-config.yaml, am.mvn, model.onnx(or model_quant.onnx), (vad_model.onnx, vad.mvn if you use vad)
+config.yaml, am.mvn, model.onnx(or model_quant.onnx), (vad_model.onnx, vad.mvn if you use vad), (punc_model.onnx, punc.yaml if you use punc)
 ```
 
 
diff --git a/funasr/runtime/onnxruntime/src/punc_infer.cpp b/funasr/runtime/onnxruntime/src/CT-transformer.cpp
similarity index 98%
rename from funasr/runtime/onnxruntime/src/punc_infer.cpp
rename to funasr/runtime/onnxruntime/src/CT-transformer.cpp
index 8dbb49d..5698703 100644
--- a/funasr/runtime/onnxruntime/src/punc_infer.cpp
+++ b/funasr/runtime/onnxruntime/src/CT-transformer.cpp
@@ -10,14 +10,19 @@
 	string strModelPath = pathAppend(sz_model_dir, PUNC_MODEL_FILE);
 	string strYamlPath = pathAppend(sz_model_dir, PUNC_YAML_FILE);
 
+    try{
 #ifdef _WIN32
 	std::wstring detPath = strToWstr(strModelPath);
     m_session = std::make_unique<Ort::Session>(env_, detPath.c_str(), session_options);
 #else
     m_session = std::make_unique<Ort::Session>(env_, strModelPath.c_str(), session_options);
 #endif
-    // read inputnames outputnames
-    vector<string> m_strInputNames, m_strOutputNames;
+    }
+    catch(exception e)
+    {
+        printf(e.what());
+    }
+    // read inputnames outputnames
     string strName;
     getInputName(m_session.get(), strName);
     m_strInputNames.push_back(strName.c_str());
diff --git a/funasr/runtime/onnxruntime/src/punc_infer.h b/funasr/runtime/onnxruntime/src/CT-transformer.h
similarity index 92%
rename from funasr/runtime/onnxruntime/src/punc_infer.h
rename to funasr/runtime/onnxruntime/src/CT-transformer.h
index e4ef0aa..77972c7 100644
--- a/funasr/runtime/onnxruntime/src/punc_infer.h
+++ b/funasr/runtime/onnxruntime/src/CT-transformer.h
@@ -10,6 +10,7 @@
 private:
 
 	CTokenizer m_tokenizer;
+	vector<string> m_strInputNames, m_strOutputNames;
 	vector<const char*> m_szInputNames;
 	vector<const char*> m_szOutputNames;
 
diff --git a/funasr/runtime/onnxruntime/src/Model.cpp b/funasr/runtime/onnxruntime/src/Model.cpp
index 2f864a9..bd1ba3c 100644
--- a/funasr/runtime/onnxruntime/src/Model.cpp
+++ b/funasr/runtime/onnxruntime/src/Model.cpp
@@ -1,10 +1,10 @@
 #include "precomp.h"
 
-Model *create_model(const char *path, int nThread, bool quantize, bool use_vad)
+Model *create_model(const char *path, int nThread, bool quantize, bool use_vad, bool use_punc)
 {
     Model *mm;
 
-    mm = new paraformer::ModelImp(path, nThread, quantize, use_vad);
+    mm = new paraformer::ModelImp(path, nThread, quantize, use_vad, use_punc);
 
     return mm;
 }
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
index a2684c3..53e6c1d 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
+++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
@@ -11,9 +11,9 @@
 
 int main(int argc, char *argv[])
 {
-    if (argc < 5)
+    if (argc < 6)
     {
-        printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) use_vad(true or false) \n", argv[0]);
+        printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) use_vad(true or false) use_punc(true or false)\n", argv[0]);
         exit(-1);
     }
     struct timeval start, end;
@@ -22,9 +22,11 @@
     // is quantize
     bool quantize = false;
     bool use_vad = false;
+    bool use_punc = false;
     istringstream(argv[3]) >> boolalpha >> quantize;
     istringstream(argv[4]) >> boolalpha >> use_vad;
-    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad);
+    istringstream(argv[5]) >> boolalpha >> use_punc;
+    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad, use_punc);
 
     if (!AsrHanlde)
     {
@@ -38,7 +40,7 @@
     printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
 
     gettimeofday(&start, NULL);
-    FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad);
+    FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad, use_punc);
     gettimeofday(&end, NULL);
 
     float snippet_time = 0.0f;
diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
index 0adef89..60414bf 100644
--- a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
+++ b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -5,13 +5,13 @@
 #endif
 
 	// APIs for funasr
-	_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad)
+	_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad, bool use_punc)
 	{
-		Model* mm = create_model(szModelDir, nThreadNum, quantize, use_vad);
+		Model* mm = create_model(szModelDir, nThreadNum, quantize, use_vad, use_punc);
 		return mm;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -39,11 +39,15 @@
 			if (fnCallback)
 				fnCallback(nStep, nTotal);
 		}
+		if(use_punc){
+			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+			pResult->msg = punc_res;
+		}
 
 		return pResult;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -70,11 +74,15 @@
 			if (fnCallback)
 				fnCallback(nStep, nTotal);
 		}
+		if(use_punc){
+			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+			pResult->msg = punc_res;
+		}
 
 		return pResult;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -101,11 +109,15 @@
 			if (fnCallback)
 				fnCallback(nStep, nTotal);
 		}
+		if(use_punc){
+			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+			pResult->msg = punc_res;
+		}
 
 		return pResult;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+	_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
 	{
 		Model* pRecogObj = (Model*)handle;
 		if (!pRecogObj)
@@ -133,7 +145,7 @@
 			if (fnCallback)
 				fnCallback(nStep, nTotal);
 		}
-		if(true){
+		if(use_punc){
 			string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
 			pResult->msg = punc_res;
 		}
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 69d1554..289eab1 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -3,7 +3,7 @@
 using namespace std;
 using namespace paraformer;
 
-ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad)
+ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad, bool use_punc)
 :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
     string model_path;
     string cmvn_path;
@@ -18,7 +18,7 @@
     }
 
     // PUNC model
-    if(true){
+    if(use_punc){
         puncHandle = make_unique<CTTransformer>(path, nNumThread);
     }
 
@@ -55,7 +55,6 @@
     m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
 #endif
 
-    vector<string> m_strInputNames, m_strOutputNames;
     string strName;
     getInputName(m_session.get(), strName);
     m_strInputNames.push_back(strName.c_str());
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.h b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
index cde2937..9008d10 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.h
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -33,11 +33,12 @@
         Ort::Env env_;
         Ort::SessionOptions sessionOptions;
 
+        vector<string> m_strInputNames, m_strOutputNames;
         vector<const char*> m_szInputNames;
         vector<const char*> m_szOutputNames;
 
     public:
-        ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false);
+        ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false, bool use_punc=false);
         ~ModelImp();
         void reset();
         vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index 40d8928..7bfa1a6 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -29,7 +29,7 @@
 #include "commonfunc.h"
 #include "predefine_coe.h"
 #include "tokenizer.h"
-#include "punc_infer.h"
+#include "CT-transformer.h"
 #include "FsmnVad.h"
 #include "e2e_vad.h"
 #include "Vocab.h"

--
Gitblit v1.9.1