From 8e449d676d4f06ba3de02c451c3b4fa3433a3792 Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: 星期二, 11 四月 2023 10:35:29 +0800
Subject: [PATCH] debug onnxruntime multithread bugs

---
 funasr/runtime/grpc/paraformer_server.cc           |   48 +++-----
 funasr/runtime/onnxruntime/tester/tester_rtf.cpp   |   59 ++++-------
 funasr/runtime/onnxruntime/src/librapidasrapi.cpp  |    8 
 funasr/runtime/onnxruntime/src/FeatureExtract.cpp  |    4 
 funasr/runtime/python/grpc/grpc_main_client.py     |   62 ++++++++++++
 funasr/runtime/onnxruntime/src/paraformer_onnx.h   |    4 
 funasr/runtime/onnxruntime/src/paraformer_onnx.cpp |   22 ++-
 funasr/runtime/onnxruntime/tester/tester.cpp       |   42 --------
 funasr/runtime/grpc/paraformer_server.h            |    6 -
 9 files changed, 126 insertions(+), 129 deletions(-)

diff --git a/funasr/runtime/grpc/paraformer_server.cc b/funasr/runtime/grpc/paraformer_server.cc
index f2ab4e0..c918db7 100644
--- a/funasr/runtime/grpc/paraformer_server.cc
+++ b/funasr/runtime/grpc/paraformer_server.cc
@@ -35,37 +35,25 @@
     init_flag = 0;
 }
 
-void ASRServicer::clear_states(const std::string& user) {
-    clear_buffers(user);
-    clear_transcriptions(user);
-}
-
-void ASRServicer::clear_buffers(const std::string& user) {
-    if (client_buffers.count(user)) {
-        client_buffers.erase(user);
-    }
-}
-
-void ASRServicer::clear_transcriptions(const std::string& user) {
-    if (client_transcription.count(user)) {
-        client_transcription.erase(user);
-    }
-}
-
-void ASRServicer::disconnect(const std::string& user) {
-    clear_states(user);
-    std::cout << "Disconnecting user: " << user << std::endl;
-}
-
 grpc::Status ASRServicer::Recognize(
     grpc::ServerContext* context,
     grpc::ServerReaderWriter<Response, Request>* stream) {
 
     Request req;
+    std::unordered_map<std::string, std::string> client_buffers;
+    std::unordered_map<std::string, std::string> client_transcription;
+    
     while (stream->Read(&req)) {
         if (req.isend()) {
             std::cout << "asr end" << std::endl;
-            disconnect(req.user());
+            // disconnect 
+            if (client_buffers.count(req.user())) {
+                client_buffers.erase(req.user());
+            }
+            if (client_transcription.count(req.user())) {
+                client_transcription.erase(req.user());
+            }
+
             Response res;
             res.set_sentence(
                 R"({"success": true, "detail": "asr end"})"
@@ -103,8 +91,14 @@
                   auto& buf = client_buffers[req.user()];
                   buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
                 }
-                std::string tmp_data = this->client_buffers[req.user()];
-                this->clear_states(req.user());
+                std::string tmp_data = client_buffers[req.user()];
+                // clear_states
+                if (client_buffers.count(req.user())) {
+                    client_buffers.erase(req.user());
+                }
+                if (client_transcription.count(req.user())) {
+                    client_transcription.erase(req.user());
+                }
 
                 Response res;
                 res.set_sentence(
@@ -133,9 +127,6 @@
                     res.set_user(req.user());
                     res.set_action("finish");
                     res.set_language(req.language());
-
-
-
                     stream->Write(res);
                 }
                 else {
@@ -154,7 +145,6 @@
                     res.set_user(req.user());
                     res.set_action("finish");
                     res.set_language(req.language());
-
 
                     stream->Write(res);
                 }
diff --git a/funasr/runtime/grpc/paraformer_server.h b/funasr/runtime/grpc/paraformer_server.h
index e42e041..d385b85 100644
--- a/funasr/runtime/grpc/paraformer_server.h
+++ b/funasr/runtime/grpc/paraformer_server.h
@@ -41,15 +41,9 @@
 class ASRServicer final : public ASR::Service {
   private:
     int init_flag;
-    std::unordered_map<std::string, std::string> client_buffers;
-    std::unordered_map<std::string, std::string> client_transcription;
 
   public:
     ASRServicer(const char* model_path, int thread_num, bool quantize);
-    void clear_states(const std::string& user);
-    void clear_buffers(const std::string& user);
-    void clear_transcriptions(const std::string& user);
-    void disconnect(const std::string& user);
     grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
     RPASR_HANDLE AsrHanlde;
 	
diff --git a/funasr/runtime/onnxruntime/src/FeatureExtract.cpp b/funasr/runtime/onnxruntime/src/FeatureExtract.cpp
index 1b0c3c4..0bcb645 100644
--- a/funasr/runtime/onnxruntime/src/FeatureExtract.cpp
+++ b/funasr/runtime/onnxruntime/src/FeatureExtract.cpp
@@ -12,7 +12,9 @@
 {
     fftwf_free(fft_input);
     fftwf_free(fft_out);
-    fftwf_destroy_plan(p);
+    if(p){
+        fftwf_destroy_plan(p);
+    }
 }
 
 void FeatureExtract::reset()
diff --git a/funasr/runtime/onnxruntime/src/librapidasrapi.cpp b/funasr/runtime/onnxruntime/src/librapidasrapi.cpp
index 62f47a5..b9169bd 100644
--- a/funasr/runtime/onnxruntime/src/librapidasrapi.cpp
+++ b/funasr/runtime/onnxruntime/src/librapidasrapi.cpp
@@ -30,7 +30,7 @@
 		int nStep = 0;
 		int nTotal = audio.get_queue_size();
 		while (audio.fetch(buff, len, flag) > 0) {
-			pRecogObj->reset();
+			//pRecogObj->reset();
 			string msg = pRecogObj->forward(buff, len, flag);
 			pResult->msg += msg;
 			nStep++;
@@ -60,7 +60,7 @@
 		int nStep = 0;
 		int nTotal = audio.get_queue_size();
 		while (audio.fetch(buff, len, flag) > 0) {
-			pRecogObj->reset();
+			//pRecogObj->reset();
 			string msg = pRecogObj->forward(buff, len, flag);
 			pResult->msg += msg;
 			nStep++;
@@ -90,7 +90,7 @@
 		int nStep = 0;
 		int nTotal = audio.get_queue_size();
 		while (audio.fetch(buff, len, flag) > 0) {
-			pRecogObj->reset();
+			//pRecogObj->reset();
 			string msg = pRecogObj->forward(buff, len, flag);
 			pResult->msg += msg;
 			nStep++;
@@ -120,7 +120,7 @@
 		RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
 		pResult->snippet_time = audio.get_time_len();
 		while (audio.fetch(buff, len, flag) > 0) {
-			pRecogObj->reset();
+			//pRecogObj->reset();
 			string msg = pRecogObj->forward(buff, len, flag);
 			pResult->msg+= msg;
 			nStep++;
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index a49069c..bb00849 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -18,7 +18,7 @@
     cmvn_path = pathAppend(path, "am.mvn");
     config_path = pathAppend(path, "config.yaml");
 
-    fe = new FeatureExtract(3);
+    //fe = new FeatureExtract(3);
 
     //sessionOptions.SetInterOpNumThreads(1);
     sessionOptions.SetIntraOpNumThreads(nNumThread);
@@ -52,8 +52,8 @@
 
 ModelImp::~ModelImp()
 {
-    if(fe)
-        delete fe;
+    //if(fe)
+    //    delete fe;
     if (m_session)
     {
         delete m_session;
@@ -65,7 +65,8 @@
 
 void ModelImp::reset()
 {
-    fe->reset();
+    //fe->reset();
+    printf("Not Imp!!!!!!\n");
 }
 
 void ModelImp::apply_lfr(Tensor<float>*& din)
@@ -159,8 +160,9 @@
 
 string ModelImp::forward(float* din, int len, int flag)
 {
-
     Tensor<float>* in;
+    FeatureExtract* fe = new FeatureExtract(3);
+    fe->reset();
     fe->insert(din, len, flag);
     fe->fetch(in);
     apply_lfr(in);
@@ -192,7 +194,6 @@
         auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
         std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
 
-
         int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
         float* floatData = outputTensor[0].GetTensorMutableData<float>();
         auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
@@ -203,9 +204,14 @@
         result = "";
     }
 
-
-    if(in)
+    if(in){
         delete in;
+        in = nullptr;
+    }
+    if(fe){
+        delete fe;
+        fe = nullptr;
+    }
 
     return result;
 }
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.h b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
index 395c328..5d3b0fe 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.h
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -8,7 +8,7 @@
 
     class ModelImp : public Model {
     private:
-        FeatureExtract* fe;
+        //FeatureExtract* fe;
 
         Vocab* vocab;
         vector<float> means_list;
@@ -34,8 +34,6 @@
         vector<string> m_strInputNames, m_strOutputNames;
         vector<const char*> m_szInputNames;
         vector<const char*> m_szOutputNames;
-        //string m_strInputName, m_strInputNameLen;
-        //string m_strOutputName, m_strOutputNameLen;
 
     public:
         ModelImp(const char* path, int nNumThread=0, bool quantize=false);
diff --git a/funasr/runtime/onnxruntime/tester/tester.cpp b/funasr/runtime/onnxruntime/tester/tester.cpp
index 35d534f..5bef810 100644
--- a/funasr/runtime/onnxruntime/tester/tester.cpp
+++ b/funasr/runtime/onnxruntime/tester/tester.cpp
@@ -50,8 +50,7 @@
     {
         string msg = RapidAsrGetResult(Result, 0);
         setbuf(stdout, NULL);
-        cout << "Result: \"";
-        cout << msg << "\"." << endl;
+        printf("Result: %s \n", msg.c_str());
         snippet_time = RapidAsrGetRetSnippetTime(Result);
         RapidAsrFreeResult(Result);
     }
@@ -59,45 +58,6 @@
     {
         cout <<"no return data!";
     }
- 
-    //char* buff = nullptr;
-    //int len = 0;
-    //ifstream ifs(argv[2], std::ios::binary | std::ios::in);
-    //if (ifs.is_open())
-    //{
-    //    ifs.seekg(0, std::ios::end);
-    //    len = ifs.tellg();
-    //    ifs.seekg(0, std::ios::beg);
-
-    //    buff = new char[len];
-
-    //    ifs.read(buff, len);
-
-
-    //    //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
-
-    //    RPASR_RESULT Result=RapidAsrRecogPCMBuffer(AsrHanlde, buff,len, RASR_NONE, NULL);
-    //    //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
-    //    gettimeofday(&end, NULL);
-    //   
-    //    if (Result)
-    //    {
-    //        string msg = RapidAsrGetResult(Result, 0);
-    //        setbuf(stdout, NULL);
-    //        cout << "Result: \"";
-    //        cout << msg << endl;
-    //        cout << "\"." << endl;
-    //        snippet_time = RapidAsrGetRetSnippetTime(Result);
-    //        RapidAsrFreeResult(Result);
-    //    }
-    //    else
-    //    {
-    //        cout <<"no return data!";
-    //    }
-  
-    //   
-    //delete[]buff;
-    //}
  
     printf("Audio length %lfs.\n", (double)snippet_time);
     seconds = (end.tv_sec - start.tv_sec);
diff --git a/funasr/runtime/onnxruntime/tester/tester_rtf.cpp b/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
index 9651900..131ca64 100644
--- a/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
+++ b/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
@@ -11,7 +11,24 @@
 #include <fstream>
 #include <sstream>
 #include <vector>
+#include <thread>
 using namespace std;
+
+void runReg(vector<string> wav_list, RPASR_HANDLE AsrHanlde)
+{
+    for (size_t i = 0; i < wav_list.size(); i++)
+    {
+        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
+
+        if(Result){
+            string msg = RapidAsrGetResult(Result, 0);
+            printf("Result: %s \n", msg.c_str());
+            RapidAsrFreeResult(Result);
+        }else{
+            cout <<"No return data!";
+        }
+    }
+}
 
 int main(int argc, char *argv[])
 {
@@ -53,46 +70,14 @@
         printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
         exit(-1);
     }
-    gettimeofday(&end, NULL);
-    long seconds = (end.tv_sec - start.tv_sec);
-    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
-    printf("Model initialization takes %lfs.\n", (double)modle_init_micros / 1000000);
-
-    // warm up
-    for (size_t i = 0; i < 30; i++)
-    {
-        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
-    }
-
-    // forward
-    float snippet_time = 0.0f;
-    float total_length = 0.0f;
-    long total_time = 0.0f;
     
-    for (size_t i = 0; i < wav_list.size(); i++)
-    {
-        gettimeofday(&start, NULL);
-        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
-        gettimeofday(&end, NULL);
-        seconds = (end.tv_sec - start.tv_sec);
-        long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
-        total_time += taking_micros;
+    std::thread t1(runReg, wav_list, AsrHanlde);
+    std::thread t2(runReg, wav_list, AsrHanlde);
 
-        if(Result){
-            string msg = RapidAsrGetResult(Result, 0);
-            printf("Result: %s \n", msg);
+    t1.join();
+    t2.join();
 
-            snippet_time = RapidAsrGetRetSnippetTime(Result);
-            total_length += snippet_time;
-            RapidAsrFreeResult(Result);
-        }else{
-            cout <<"No return data!";
-        }
-    }
-
-    printf("total_time_wav %ld ms.\n", (long)(total_length * 1000));
-    printf("total_time_comput %ld ms.\n", total_time / 1000);
-    printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
+    //runReg(wav_list, AsrHanlde);
 
     RapidAsrUninit(AsrHanlde);
     return 0;
diff --git a/funasr/runtime/python/grpc/grpc_main_client.py b/funasr/runtime/python/grpc/grpc_main_client.py
new file mode 100644
index 0000000..b6491df
--- /dev/null
+++ b/funasr/runtime/python/grpc/grpc_main_client.py
@@ -0,0 +1,62 @@
+import grpc
+import json
+import time
+import asyncio
+import soundfile as sf
+import argparse
+
+from grpc_client import transcribe_audio_bytes
+from paraformer_pb2_grpc import ASRStub
+
+# send the audio data once
+async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
+    with grpc.insecure_channel(grpc_uri) as channel:
+        stub = ASRStub(channel)
+        for line in wav_scp:
+            wav_file = line.split()[1]
+            wav, _ = sf.read(wav_file, dtype='int16')
+            
+            b = time.time()
+            response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
+            resp = response.next()
+            text = ''
+            if 'decoding' == resp.action:
+                resp = response.next()
+                if 'finish' == resp.action:
+                    text = json.loads(resp.sentence)['text']
+            response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
+            res= {'text': text, 'time': time.time() - b}
+            print(res)
+
+async def test(args):
+    wav_scp = open(args.wav_scp, "r").readlines()
+    uri = '{}:{}'.format(args.host, args.port)
+    res = await grpc_rec(wav_scp, uri, args.user_allowed, language = 'zh-CN')
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host",
+                        type=str,
+                        default="127.0.0.1",
+                        required=False,
+                        help="grpc server host ip")
+    parser.add_argument("--port",
+                        type=int,
+                        default=10108,
+                        required=False,
+                        help="grpc server port")              
+    parser.add_argument("--user_allowed",
+                        type=str,
+                        default="project1_user1",
+                        help="allowed user for grpc client")
+    parser.add_argument("--sample_rate",
+                        type=int,
+                        default=16000,
+                        help="audio sample_rate from client") 
+    parser.add_argument("--wav_scp",
+                        type=str,
+                        required=True,
+                        help="audio wav scp")                    
+    args = parser.parse_args()
+    
+    asyncio.run(test(args))

--
Gitblit v1.9.1