From 13d57c776de1692200a6abfee88db1f9e82ee41d Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: Tue, 21 Mar 2023 16:32:53 +0800
Subject: [PATCH] Add ONNX quantized model support to the gRPC server

---
 funasr/runtime/grpc/paraformer_server.cc |   17 ++++++++++-------
 funasr/runtime/grpc/paraformer_server.h  |    2 +-
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/funasr/runtime/grpc/paraformer_server.cc b/funasr/runtime/grpc/paraformer_server.cc
index e5814a5..69ce903 100644
--- a/funasr/runtime/grpc/paraformer_server.cc
+++ b/funasr/runtime/grpc/paraformer_server.cc
@@ -29,8 +29,8 @@
 using paraformer::Response;
 using paraformer::ASR;
 
-ASRServicer::ASRServicer(const char* model_path, int thread_num) {
-    AsrHanlde=RapidAsrInit(model_path, thread_num);
+ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
+    AsrHanlde=RapidAsrInit(model_path, thread_num, quantize);
     std::cout << "ASRServicer init" << std::endl;
     init_flag = 0;
 }
@@ -170,10 +170,10 @@
 }
 
 
-void RunServer(const std::string& port, int thread_num, const char* model_path) {
+void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
     std::string server_address;
     server_address = "0.0.0.0:" + port;
-    ASRServicer service(model_path, thread_num);
+    ASRServicer service(model_path, thread_num, quantize);
 
     ServerBuilder builder;
     builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
@@ -184,12 +184,15 @@
 }
 
 int main(int argc, char* argv[]) {
-    if (argc < 3)
+    if (argc < 5)
     {
-        printf("Usage: %s port thread_num /path/to/model_file\n", argv[0]);
+        printf("Usage: %s port thread_num /path/to/model_file quantize(true or false) \n", argv[0]);
         exit(-1);
     }
 
-    RunServer(argv[1], atoi(argv[2]), argv[3]);
+    // Parse the quantize flag from argv[4] ("true" or "false"); starts out as false.
+    bool quantize = false;
+    std::istringstream(argv[4]) >> std::boolalpha >> quantize;
+    RunServer(argv[1], atoi(argv[2]), argv[3], quantize);
     return 0;
 }
diff --git a/funasr/runtime/grpc/paraformer_server.h b/funasr/runtime/grpc/paraformer_server.h
index f356d94..e42e041 100644
--- a/funasr/runtime/grpc/paraformer_server.h
+++ b/funasr/runtime/grpc/paraformer_server.h
@@ -45,7 +45,7 @@
     std::unordered_map<std::string, std::string> client_transcription;
 
   public:
-    ASRServicer(const char* model_path, int thread_num);
+    ASRServicer(const char* model_path, int thread_num, bool quantize);
     void clear_states(const std::string& user);
     void clear_buffers(const std::string& user);
     void clear_transcriptions(const std::string& user);

--
Gitblit v1.9.1