From 3e44172c8b927ffc69b585d4fd80b458cb18ba97 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Wed, 25 Sep 2024 23:43:30 +0800
Subject: [PATCH] update websocket for sensevoice & onnx models

---
 runtime/websocket/bin/funasr-wss-server.cpp       |   19 +++++++--
 runtime/websocket/bin/websocket-server.cpp        |   19 ++++++++-
 runtime/websocket/bin/websocket-server.h          |    4 +
 runtime/websocket/bin/funasr-wss-client.cpp       |   21 +++++++---
 runtime/websocket/bin/funasr-wss-server-2pass.cpp |   10 ++--
 5 files changed, 54 insertions(+), 19 deletions(-)

diff --git a/runtime/websocket/bin/funasr-wss-client.cpp b/runtime/websocket/bin/funasr-wss-client.cpp
index 7af3fbb..72e41f3 100644
--- a/runtime/websocket/bin/funasr-wss-client.cpp
+++ b/runtime/websocket/bin/funasr-wss-client.cpp
@@ -115,7 +115,7 @@
 
     // This method will block until the connection is complete  
     void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, 
-             int audio_fs, const std::unordered_map<std::string, int>& hws_map, int use_itn=1) {
+             int audio_fs, const std::unordered_map<std::string, int>& hws_map, int use_itn=1, int svs_itn=1) {
         // Create a new connection to the given URI
         websocketpp::lib::error_code ec;
         typename websocketpp::client<T>::connection_ptr con =
@@ -147,7 +147,7 @@
                 cv.wait(lock);
             }
             total_send += 1;
-            send_wav_data(wav_list[i], wav_ids[i], audio_fs, hws_map, send_hotword, use_itn);
+            send_wav_data(wav_list[i], wav_ids[i], audio_fs, hws_map, send_hotword, use_itn, svs_itn);
             if(send_hotword){
                 send_hotword = false;
             }
@@ -186,7 +186,7 @@
     // send wav to server
     void send_wav_data(string wav_path, string wav_id, int audio_fs,
         const std::unordered_map<std::string, int>& hws_map, 
-        bool send_hotword, bool use_itn) {
+        bool send_hotword, bool use_itn, bool svs_itn) {
         uint64_t count = 0;
         std::stringstream val;
 
@@ -239,8 +239,12 @@
         jsonbegin["wav_format"] = wav_format;
         jsonbegin["audio_fs"] = sampling_rate;
         jsonbegin["itn"] = true;
+        jsonbegin["svs_itn"] = true;
         if(use_itn == 0){
             jsonbegin["itn"] = false;
+        }
+        if(svs_itn == 0){
+            jsonbegin["svs_itn"] = false;
         }
         jsonbegin["is_speaking"] = true;
         if(send_hotword){
@@ -368,6 +372,9 @@
     TCLAP::ValueArg<int> use_itn_(
         "", "use-itn",
         "use-itn is 1 means use itn, 0 means not use itn", false, 1, "int");
+    TCLAP::ValueArg<int> svs_itn_(
+        "", "svs-itn",
+        "svs-itn is 1 means use itn and punc, 0 means not use", false, 1, "int");
     TCLAP::ValueArg<std::string> hotword_("", HOTWORD,
         "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
 
@@ -378,6 +385,7 @@
     cmd.add(thread_num_);
     cmd.add(is_ssl_);
     cmd.add(use_itn_);
+    cmd.add(svs_itn_);
     cmd.add(hotword_);
     cmd.parse(argc, argv);
 
@@ -387,6 +395,7 @@
     int threads_num = thread_num_.getValue();
     int is_ssl = is_ssl_.getValue();
     int use_itn = use_itn_.getValue();
+    int svs_itn = svs_itn_.getValue();
 
     std::vector<websocketpp::lib::thread> client_threads;
     std::string uri = "";
@@ -431,17 +440,17 @@
     
     int audio_fs = audio_fs_.getValue();
     for (size_t i = 0; i < threads_num; i++) {
-        client_threads.emplace_back([uri, wav_list, wav_ids, audio_fs, is_ssl, hws_map, use_itn]() {
+        client_threads.emplace_back([uri, wav_list, wav_ids, audio_fs, is_ssl, hws_map, use_itn, svs_itn]() {
           if (is_ssl == 1) {
             WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
 
             c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
 
-            c.run(uri, wav_list, wav_ids, audio_fs, hws_map, use_itn);
+            c.run(uri, wav_list, wav_ids, audio_fs, hws_map, use_itn, svs_itn);
           } else {
             WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
 
-            c.run(uri, wav_list, wav_ids, audio_fs, hws_map, use_itn);
+            c.run(uri, wav_list, wav_ids, audio_fs, hws_map, use_itn, svs_itn);
           }
         });
     }
diff --git a/runtime/websocket/bin/funasr-wss-server-2pass.cpp b/runtime/websocket/bin/funasr-wss-server-2pass.cpp
index d42679b..9c59254 100644
--- a/runtime/websocket/bin/funasr-wss-server-2pass.cpp
+++ b/runtime/websocket/bin/funasr-wss-server-2pass.cpp
@@ -55,11 +55,11 @@
 
     TCLAP::ValueArg<std::string> offline_model_revision(
         "", "offline-model-revision", "ASR offline model revision", false,
-        "v2.0.4", "string");
+        "v2.0.5", "string");
 
     TCLAP::ValueArg<std::string> online_model_revision(
         "", "online-model-revision", "ASR online model revision", false,
-        "v2.0.4", "string");
+        "v2.0.5", "string");
 
     TCLAP::ValueArg<std::string> quantize(
         "", QUANTIZE,
@@ -85,7 +85,7 @@
         "model_quant.onnx, punc.yaml",
         false, "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx", "string");
     TCLAP::ValueArg<std::string> punc_revision(
-        "", "punc-revision", "PUNC model revision", false, "v2.0.4", "string");
+        "", "punc-revision", "PUNC model revision", false, "v2.0.5", "string");
     TCLAP::ValueArg<std::string> punc_quant(
         "", PUNC_QUANT,
         "true (Default), load the model of model_quant.onnx in punc_dir. If "
@@ -262,7 +262,7 @@
 
         size_t found = s_offline_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
         if (found != std::string::npos) {
-            model_path["offline-model-revision"]="v2.0.4";
+            model_path["offline-model-revision"]="v2.0.5";
         }
 
         found = s_offline_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
@@ -272,7 +272,7 @@
 
         found = s_offline_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
         if (found != std::string::npos) {
-            model_path["model-revision"]="v2.0.4";
+            model_path["model-revision"]="v2.0.5";
             s_itn_path="";
             s_lm_path="";
         }
diff --git a/runtime/websocket/bin/funasr-wss-server.cpp b/runtime/websocket/bin/funasr-wss-server.cpp
index 3c5b81c..956c40e 100644
--- a/runtime/websocket/bin/funasr-wss-server.cpp
+++ b/runtime/websocket/bin/funasr-wss-server.cpp
@@ -50,7 +50,7 @@
     TCLAP::ValueArg<std::string> model_revision(
         "", "model-revision",
         "ASR model revision",
-        false, "v2.0.4", "string");
+        false, "v2.0.5", "string");
     TCLAP::ValueArg<std::string> quantize(
         "", QUANTIZE,
         "true (Default), load the model of model_quant.onnx in model_dir. If set "
@@ -81,7 +81,7 @@
     TCLAP::ValueArg<std::string> punc_revision(
         "", "punc-revision",
         "PUNC model revision",
-        false, "v2.0.4", "string");
+        false, "v2.0.5", "string");
     TCLAP::ValueArg<std::string> punc_quant(
         "", PUNC_QUANT,
         "true (Default), load the model of model_quant.onnx in punc_dir. If set "
@@ -247,7 +247,7 @@
             // modify model-revision by model name
             size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
             if (found != std::string::npos) {
-                model_path["model-revision"]="v2.0.4";
+                model_path["model-revision"]="v2.0.5";
             }
 
             found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
@@ -257,11 +257,22 @@
 
             found = s_asr_path.find("speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
             if (found != std::string::npos) {
-                model_path["model-revision"]="v2.0.4";
+                model_path["model-revision"]="v2.0.5";
                 s_itn_path="";
                 s_lm_path="";
             }
 
+            found = s_asr_path.find(MODEL_SVS);
+            if (found != std::string::npos) {
+                model_path["model-revision"]="v2.0.5";
+                s_itn_path="";
+                model_path[ITN_DIR]="";
+                s_lm_path="";
+                model_path[LM_DIR]="";
+                s_punc_path="";
+                model_path[PUNC_DIR]="";
+            }
+
             if (use_gpu_){
                 model_type = "torchscript";
                 if (s_blade=="true" || s_blade=="True" || s_blade=="TRUE"){
diff --git a/runtime/websocket/bin/websocket-server.cpp b/runtime/websocket/bin/websocket-server.cpp
index 49d8ead..74f1a2e 100644
--- a/runtime/websocket/bin/websocket-server.cpp
+++ b/runtime/websocket/bin/websocket-server.cpp
@@ -67,7 +67,9 @@
                                  bool itn,
                                  int audio_fs,
                                  std::string wav_format,
-                                 FUNASR_DEC_HANDLE& decoder_handle) {
+                                 FUNASR_DEC_HANDLE& decoder_handle,
+                                 std::string svs_lang,
+                                 bool sys_itn) {
   try {
     int num_samples = buffer.size();  // the size of the buf
 
@@ -78,7 +80,8 @@
       try{
         FUNASR_RESULT Result = FunOfflineInferBuffer(
             asr_handle, buffer.data(), buffer.size(), RASR_NONE, nullptr, 
-            hotwords_embedding, audio_fs, wav_format, itn, decoder_handle);
+            hotwords_embedding, audio_fs, wav_format, itn, decoder_handle,
+            svs_lang, sys_itn);
         if (Result != nullptr){
           asr_result = FunASRGetResult(Result, 0);  // get decode result
           stamp_res = FunASRGetStamp(Result);
@@ -162,6 +165,8 @@
   data_msg->msg["audio_fs"] = 16000; // default is 16k
   data_msg->msg["access_num"] = 0; // the number of access for this object, when it is 0, we can free it saftly
   data_msg->msg["is_eof"]=false;
+  data_msg->msg["svs_lang"]="auto";
+  data_msg->msg["svs_itn"]=true;
   FUNASR_DEC_HANDLE decoder_handle =
     FunASRWfstDecoderInit(asr_handle, ASR_OFFLINE, global_beam_, lattice_beam_, am_scale_);
   data_msg->decoder_handle = decoder_handle;
@@ -357,6 +362,12 @@
       if (jsonresult.contains("itn")) {
         msg_data->msg["itn"] = jsonresult["itn"];
       }
+      if (jsonresult.contains("svs_lang")) {
+        msg_data->msg["svs_lang"] = jsonresult["svs_lang"];
+      }
+      if (jsonresult.contains("svs_itn")) {
+        msg_data->msg["svs_itn"] = jsonresult["svs_itn"];
+      }
       if ((jsonresult["is_speaking"] == false ||
           jsonresult["is_finished"] == true) && 
           msg_data->msg["is_eof"] != true && 
@@ -375,7 +386,9 @@
                               msg_data->msg["itn"],
                               msg_data->msg["audio_fs"],
                               msg_data->msg["wav_format"],
-                              std::ref(msg_data->decoder_handle)));
+                              std::ref(msg_data->decoder_handle),
+                              msg_data->msg["svs_lang"],
+                              msg_data->msg["svs_itn"]));
         msg_data->msg["access_num"]=(int)(msg_data->msg["access_num"])+1;
       }
       break;
diff --git a/runtime/websocket/bin/websocket-server.h b/runtime/websocket/bin/websocket-server.h
index c1389bf..dad9cf5 100644
--- a/runtime/websocket/bin/websocket-server.h
+++ b/runtime/websocket/bin/websocket-server.h
@@ -122,7 +122,9 @@
                   bool itn,
                   int audio_fs,
                   std::string wav_format,
-                  FUNASR_DEC_HANDLE& decoder_handle);
+                  FUNASR_DEC_HANDLE& decoder_handle,
+                  std::string svs_lang,
+                  bool sys_itn);
 
   void initAsr(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
   void on_message(websocketpp::connection_hdl hdl, message_ptr msg);

--
Gitblit v1.9.1