From 0535db1c65180cfb4da046c5d865c764e6445746 Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: Mon, 24 Apr 2023 10:57:53 +0800
Subject: [PATCH] rename CT-transformer
---
funasr/runtime/onnxruntime/include/Model.h | 2
funasr/runtime/onnxruntime/src/Model.cpp | 4 +-
funasr/runtime/onnxruntime/readme.md | 4 +-
funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp | 10 +++--
funasr/runtime/onnxruntime/src/libfunasrapi.cpp | 26 +++++++++---
funasr/runtime/onnxruntime/src/CT-transformer.h | 1
funasr/runtime/onnxruntime/src/CT-transformer.cpp | 9 +++-
funasr/runtime/onnxruntime/src/paraformer_onnx.h | 3 +
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp | 5 +-
funasr/runtime/onnxruntime/src/precomp.h | 2
funasr/runtime/onnxruntime/include/libfunasrapi.h | 10 ++--
11 files changed, 48 insertions(+), 28 deletions(-)
diff --git a/funasr/runtime/onnxruntime/include/Model.h b/funasr/runtime/onnxruntime/include/Model.h
index f92789f..2d7873f 100644
--- a/funasr/runtime/onnxruntime/include/Model.h
+++ b/funasr/runtime/onnxruntime/include/Model.h
@@ -15,5 +15,5 @@
virtual std::string AddPunc(const char* szInput)=0;
};
-Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false);
+Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false, bool use_punc=false);
#endif
diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h
index 8d8ebd2..a967ad2 100644
--- a/funasr/runtime/onnxruntime/include/libfunasrapi.h
+++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -49,17 +49,17 @@
typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
// APIs for funasr
-_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false);
+_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false, bool use_punc=false);
// if not give a fnCallback ,it should be NULL
-_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
-_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
-_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
-_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
+_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false, bool use_punc=false);
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex);
diff --git a/funasr/runtime/onnxruntime/readme.md b/funasr/runtime/onnxruntime/readme.md
index 81030ba..6886d58 100644
--- a/funasr/runtime/onnxruntime/readme.md
+++ b/funasr/runtime/onnxruntime/readme.md
@@ -59,12 +59,12 @@
## Run the demo
```shell
-funasr-onnx-offline /path/models_dir /path/wave_file quantize(true or false) use_vad(true or false)
+funasr-onnx-offline /path/models_dir /path/wave_file quantize(true or false) use_vad(true or false) use_punc(true or false)
```
The structure of /path/models_dir
```
-config.yaml, am.mvn, model.onnx(or model_quant.onnx), (vad_model.onnx, vad.mvn if you use vad)
+config.yaml, am.mvn, model.onnx(or model_quant.onnx), (vad_model.onnx, vad.mvn if you use vad), (punc_model.onnx, punc.yaml if you use punc)
```
diff --git a/funasr/runtime/onnxruntime/src/punc_infer.cpp b/funasr/runtime/onnxruntime/src/CT-transformer.cpp
similarity index 98%
rename from funasr/runtime/onnxruntime/src/punc_infer.cpp
rename to funasr/runtime/onnxruntime/src/CT-transformer.cpp
index 8dbb49d..5698703 100644
--- a/funasr/runtime/onnxruntime/src/punc_infer.cpp
+++ b/funasr/runtime/onnxruntime/src/CT-transformer.cpp
@@ -10,14 +10,19 @@
string strModelPath = pathAppend(sz_model_dir, PUNC_MODEL_FILE);
string strYamlPath = pathAppend(sz_model_dir, PUNC_YAML_FILE);
+ try{
#ifdef _WIN32
std::wstring detPath = strToWstr(strModelPath);
m_session = std::make_unique<Ort::Session>(env_, detPath.c_str(), session_options);
#else
m_session = std::make_unique<Ort::Session>(env_, strModelPath.c_str(), session_options);
#endif
- // read inputnames outputnames
- vector<string> m_strInputNames, m_strOutputNames;
+ }
+ catch(exception e)
+ {
+ printf(e.what());
+ }
+ // read inputnames outputnames
string strName;
getInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
diff --git a/funasr/runtime/onnxruntime/src/punc_infer.h b/funasr/runtime/onnxruntime/src/CT-transformer.h
similarity index 92%
rename from funasr/runtime/onnxruntime/src/punc_infer.h
rename to funasr/runtime/onnxruntime/src/CT-transformer.h
index e4ef0aa..77972c7 100644
--- a/funasr/runtime/onnxruntime/src/punc_infer.h
+++ b/funasr/runtime/onnxruntime/src/CT-transformer.h
@@ -10,6 +10,7 @@
private:
CTokenizer m_tokenizer;
+ vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
diff --git a/funasr/runtime/onnxruntime/src/Model.cpp b/funasr/runtime/onnxruntime/src/Model.cpp
index 2f864a9..bd1ba3c 100644
--- a/funasr/runtime/onnxruntime/src/Model.cpp
+++ b/funasr/runtime/onnxruntime/src/Model.cpp
@@ -1,10 +1,10 @@
#include "precomp.h"
-Model *create_model(const char *path, int nThread, bool quantize, bool use_vad)
+Model *create_model(const char *path, int nThread, bool quantize, bool use_vad, bool use_punc)
{
Model *mm;
- mm = new paraformer::ModelImp(path, nThread, quantize, use_vad);
+ mm = new paraformer::ModelImp(path, nThread, quantize, use_vad, use_punc);
return mm;
}
diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
index a2684c3..53e6c1d 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
+++ b/funasr/runtime/onnxruntime/src/funasr-onnx-offline.cpp
@@ -11,9 +11,9 @@
int main(int argc, char *argv[])
{
- if (argc < 5)
+ if (argc < 6)
{
- printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) use_vad(true or false) \n", argv[0]);
+ printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) use_vad(true or false) use_punc(true or false)\n", argv[0]);
exit(-1);
}
struct timeval start, end;
@@ -22,9 +22,11 @@
// is quantize
bool quantize = false;
bool use_vad = false;
+ bool use_punc = false;
istringstream(argv[3]) >> boolalpha >> quantize;
istringstream(argv[4]) >> boolalpha >> use_vad;
- FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad);
+ istringstream(argv[5]) >> boolalpha >> use_punc;
+ FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad, use_punc);
if (!AsrHanlde)
{
@@ -38,7 +40,7 @@
printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
gettimeofday(&start, NULL);
- FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad);
+ FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad, use_punc);
gettimeofday(&end, NULL);
float snippet_time = 0.0f;
diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
index 0adef89..60414bf 100644
--- a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
+++ b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -5,13 +5,13 @@
#endif
// APIs for funasr
- _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad)
+ _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad, bool use_punc)
{
- Model* mm = create_model(szModelDir, nThreadNum, quantize, use_vad);
+ Model* mm = create_model(szModelDir, nThreadNum, quantize, use_vad, use_punc);
return mm;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
@@ -39,11 +39,15 @@
if (fnCallback)
fnCallback(nStep, nTotal);
}
+ if(use_punc){
+ string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+ pResult->msg = punc_res;
+ }
return pResult;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
@@ -70,11 +74,15 @@
if (fnCallback)
fnCallback(nStep, nTotal);
}
+ if(use_punc){
+ string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+ pResult->msg = punc_res;
+ }
return pResult;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
@@ -101,11 +109,15 @@
if (fnCallback)
fnCallback(nStep, nTotal);
}
+ if(use_punc){
+ string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
+ pResult->msg = punc_res;
+ }
return pResult;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad, bool use_punc)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
@@ -133,7 +145,7 @@
if (fnCallback)
fnCallback(nStep, nTotal);
}
- if(true){
+ if(use_punc){
string punc_res = pRecogObj->AddPunc((pResult->msg).c_str());
pResult->msg = punc_res;
}
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index 69d1554..289eab1 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -3,7 +3,7 @@
using namespace std;
using namespace paraformer;
-ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad)
+ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad, bool use_punc)
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
string model_path;
string cmvn_path;
@@ -18,7 +18,7 @@
}
// PUNC model
- if(true){
+ if(use_punc){
puncHandle = make_unique<CTTransformer>(path, nNumThread);
}
@@ -55,7 +55,6 @@
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
- vector<string> m_strInputNames, m_strOutputNames;
string strName;
getInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.h b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
index cde2937..9008d10 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.h
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -33,11 +33,12 @@
Ort::Env env_;
Ort::SessionOptions sessionOptions;
+ vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
public:
- ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false);
+ ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false, bool use_punc=false);
~ModelImp();
void reset();
vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index 40d8928..7bfa1a6 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -29,7 +29,7 @@
#include "commonfunc.h"
#include "predefine_coe.h"
#include "tokenizer.h"
-#include "punc_infer.h"
+#include "CT-transformer.h"
#include "FsmnVad.h"
#include "e2e_vad.h"
#include "Vocab.h"
--
Gitblit v1.9.1