From 009f4fed79b14def6c42c98d9e481062bce9a2c7 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Fri, 15 Dec 2023 10:30:26 +0800
Subject: [PATCH] update funasr-wss-client

---
 runtime/websocket/bin/funasr-wss-client-2pass.cpp |   30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/runtime/websocket/bin/funasr-wss-client-2pass.cpp b/runtime/websocket/bin/funasr-wss-client-2pass.cpp
index e2cce28..2bd814c 100644
--- a/runtime/websocket/bin/funasr-wss-client-2pass.cpp
+++ b/runtime/websocket/bin/funasr-wss-client-2pass.cpp
@@ -122,7 +122,7 @@
 
   // This method will block until the connection is complete
   void run(const std::string& uri, const std::vector<string>& wav_list,
-           const std::vector<string>& wav_ids, std::string asr_mode,
+           const std::vector<string>& wav_ids, int audio_fs, std::string asr_mode,
            std::vector<int> chunk_size, const std::unordered_map<std::string, int>& hws_map,
            bool is_record=false, int use_itn=1) {
     // Create a new connection to the given URI
@@ -148,7 +148,7 @@
     if(is_record){
       send_rec_data(asr_mode, chunk_size, hws_map, use_itn);
     }else{
-      send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size, hws_map, use_itn);
+      send_wav_data(wav_list[0], wav_ids[0], audio_fs, asr_mode, chunk_size, hws_map, use_itn);
     }
 
     WaitABit();
@@ -183,20 +183,19 @@
     m_done = true;
   }
   // send wav to server
-  void send_wav_data(string wav_path, string wav_id, std::string asr_mode,
+  void send_wav_data(string wav_path, string wav_id, int audio_fs, std::string asr_mode,
                      std::vector<int> chunk_vector, const std::unordered_map<std::string, int>& hws_map,
                      int use_itn) {
     uint64_t count = 0;
     std::stringstream val;
 
     funasr::Audio audio(1);
-    int32_t sampling_rate = 16000;
+    int32_t sampling_rate = audio_fs;
     std::string wav_format = "pcm";
     if (funasr::IsTargetFile(wav_path.c_str(), "wav")) {
-      int32_t sampling_rate = -1;
-      if (!audio.LoadWav(wav_path.c_str(), &sampling_rate)) return;
+      if (!audio.LoadWav(wav_path.c_str(), &sampling_rate, false)) return;
     } else if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) {
-      if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate)) return;
+      if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate, false)) return;
     } else {
       wav_format = "others";
       if (!audio.LoadOthers2Char(wav_path.c_str())) return;
@@ -238,6 +237,7 @@
     jsonbegin["chunk_size"] = chunk_size;
     jsonbegin["wav_name"] = wav_id;
     jsonbegin["wav_format"] = wav_format;
+    jsonbegin["audio_fs"] = sampling_rate;
     jsonbegin["is_speaking"] = true;
     jsonbegin["itn"] = true;
     if(use_itn == 0){
@@ -360,6 +360,7 @@
     }
     websocketpp::lib::error_code ec;
 
+    float sample_rate = 16000;
     nlohmann::json jsonbegin;
     nlohmann::json chunk_size = nlohmann::json::array();
     chunk_size.push_back(chunk_vector[0]);
@@ -369,6 +370,7 @@
     jsonbegin["chunk_size"] = chunk_size;
     jsonbegin["wav_name"] = "record";
     jsonbegin["wav_format"] = "pcm";
+    jsonbegin["audio_fs"] = sample_rate;
     jsonbegin["is_speaking"] = true;
     jsonbegin["itn"] = true;
     if(use_itn == 0){
@@ -408,7 +410,6 @@
 
     param.suggestedLatency = info->defaultLowInputLatency;
     param.hostApiSpecificStreamInfo = nullptr;
-    float sample_rate = 16000;
 
     PaStream *stream;
     std::vector<float> buffer;
@@ -486,6 +487,7 @@
       "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: "
       "asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
       false, "", "string");
+  TCLAP::ValueArg<std::int32_t> audio_fs_("", "audio-fs", "the sample rate of audio", false, 16000, "int32_t");
   TCLAP::ValueArg<int> record_(
       "", "record",
       "record is 1 means use record", false, 0,
@@ -511,6 +513,7 @@
   cmd.add(server_ip_);
   cmd.add(port_);
   cmd.add(wav_path_);
+  cmd.add(audio_fs_);
   cmd.add(asr_mode_);
   cmd.add(record_);
   cmd.add(chunk_size_);
@@ -558,6 +561,7 @@
       funasr::ExtractHws(hotword_path, hws_map);
   }
 
+  int audio_fs = audio_fs_.getValue();
   if(is_record == 1){
       std::vector<string> tmp_wav_list;
       std::vector<string> tmp_wav_ids;
@@ -567,11 +571,11 @@
 
         c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
 
-        c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, hws_map, true, use_itn);
+        c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn);
       } else {
         WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
 
-        c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, hws_map, true, use_itn);
+        c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, true, use_itn);
       }
 
   }else{
@@ -612,17 +616,17 @@
         tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
 
         client_threads.emplace_back(
-            [uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl, hws_map, use_itn]() {
+            [uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, is_ssl, hws_map, use_itn]() {
               if (is_ssl == 1) {
                 WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
 
                 c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
 
-                c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, hws_map, false, use_itn);
+                c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn);
               } else {
                 WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
 
-                c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, hws_map, false, use_itn);
+                c.run(uri, tmp_wav_list, tmp_wav_ids, audio_fs, asr_mode, chunk_size, hws_map, false, use_itn);
               }
             });
       }

--
Gitblit v1.9.1