From c2e232451f2f87b1ebdddd6a7f6d8434cb309808 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 07 九月 2023 14:30:12 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/websocket/funasr-wss-client-2pass.cpp |  273 ++++++++++++++++++++++++++++++++++++++++++------------
 1 files changed, 211 insertions(+), 62 deletions(-)

diff --git a/funasr/runtime/websocket/funasr-wss-client-2pass.cpp b/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
index e52e316..9010c86 100644
--- a/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
@@ -17,6 +17,7 @@
 
 #define ASIO_STANDALONE 1
 #include <glog/logging.h>
+#include "portaudio.h" 
 
 #include <atomic>
 #include <fstream>
@@ -30,6 +31,7 @@
 #include "audio.h"
 #include "nlohmann/json.hpp"
 #include "tclap/CmdLine.h"
+#include "microphone.h"
 
 /**
  * Define a semi-cross platform helper method that waits/sleeps for a bit.
@@ -123,7 +125,6 @@
           if (ec) {
             LOG(ERROR) << "Error closing connection " << ec.message();
           }
-       
         }
     }
   }
@@ -131,7 +132,7 @@
   // This method will block until the connection is complete
   void run(const std::string& uri, const std::vector<string>& wav_list,
            const std::vector<string>& wav_ids, std::string asr_mode,
-           std::vector<int> chunk_size) {
+           std::vector<int> chunk_size, bool is_record=false) {
     // Create a new connection to the given URI
     websocketpp::lib::error_code ec;
     typename websocketpp::client<T>::connection_ptr con =
@@ -152,8 +153,11 @@
     // Create a thread to run the ASIO io_service event loop
     websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                          &m_client);
-
-    send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size);
+    if(is_record){
+      send_rec_data(asr_mode, chunk_size);
+    }else{
+      send_wav_data(wav_list[0], wav_ids[0], asr_mode, chunk_size);
+    }
 
     WaitABit();
 
@@ -264,16 +268,11 @@
             send_block = len - offset;
           }
           m_client.send(m_hdl, iArray + offset, send_block * sizeof(short),
-                        websocketpp::frame::opcode::binary, ec);
+                websocketpp::frame::opcode::binary, ec);
           offset += send_block;
         }
 
         LOG(INFO) << "sended data len=" << len * sizeof(short);
-        // The most likely error that we will get is that the connection is
-        // not in the right state. Usually this means we tried to send a
-        // message to a connection that was closed or in the process of
-        // closing. While many errors here can be easily recovered from,
-        // in this simple example, we'll stop the data loop.
         if (ec) {
           m_client.get_alog().write(websocketpp::log::alevel::app,
                                     "Send Error: " + ec.message());
@@ -300,11 +299,6 @@
       }
 
       LOG(INFO) << "sended data len=" << len;
-      // The most likely error that we will get is that the connection is
-      // not in the right state. Usually this means we tried to send a
-      // message to a connection that was closed or in the process of
-      // closing. While many errors here can be easily recovered from,
-      // in this simple example, we'll stop the data loop.
       if (ec) {
         m_client.get_alog().write(websocketpp::log::alevel::app,
                                   "Send Error: " + ec.message());
@@ -317,6 +311,137 @@
                   ec);
     WaitABit();
   }
+
+  static int RecordCallback(const void* inputBuffer, void* outputBuffer,
+      unsigned long framesPerBuffer, const PaStreamCallbackTimeInfo* timeInfo,
+      PaStreamCallbackFlags statusFlags, void* userData)
+  {
+      std::vector<float>* buffer = static_cast<std::vector<float>*>(userData);
+      const float* input = static_cast<const float*>(inputBuffer);
+
+      for (unsigned int i = 0; i < framesPerBuffer; i++)
+      {
+          buffer->push_back(input[i]);
+      }
+
+      return paContinue;
+  }
+
+  void send_rec_data(std::string asr_mode, std::vector<int> chunk_vector) {
+    // first message
+    bool wait = false;
+    while (1) {
+      {
+        scoped_lock guard(m_lock);
+        // If the connection has been closed, stop generating data
+        if (m_done) {
+          break;
+        }
+        // If the connection hasn't been opened yet wait a bit and retry
+        if (!m_open) {
+          wait = true;
+        } else {
+          break;
+        }
+      }
+
+      if (wait) {
+        // LOG(INFO) << "wait.." << m_open;
+        WaitABit();
+        continue;
+      }
+    }
+    websocketpp::lib::error_code ec;
+
+    nlohmann::json jsonbegin;
+    nlohmann::json chunk_size = nlohmann::json::array();
+    chunk_size.push_back(chunk_vector[0]);
+    chunk_size.push_back(chunk_vector[1]);
+    chunk_size.push_back(chunk_vector[2]);
+    jsonbegin["mode"] = asr_mode;
+    jsonbegin["chunk_size"] = chunk_size;
+    jsonbegin["wav_name"] = "record";
+    jsonbegin["wav_format"] = "pcm";
+    jsonbegin["is_speaking"] = true;
+    m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
+                  ec);
+    // mic
+    Microphone mic;
+    PaDeviceIndex num_devices = Pa_GetDeviceCount();
+    LOG(INFO) << "Num devices: " << num_devices;
+
+    PaStreamParameters param;
+
+    param.device = Pa_GetDefaultInputDevice();
+    if (param.device == paNoDevice) {
+      LOG(INFO) << "No default input device found";
+      exit(EXIT_FAILURE);
+    }
+    LOG(INFO) << "Use default device: " << param.device;
+
+    const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
+    LOG(INFO) << "  Name: " << info->name;
+    LOG(INFO) << "  Max input channels: " << info->maxInputChannels;
+
+    param.channelCount = 1;
+    param.sampleFormat = paFloat32;
+
+    param.suggestedLatency = info->defaultLowInputLatency;
+    param.hostApiSpecificStreamInfo = nullptr;
+    float sample_rate = 16000;
+
+    PaStream *stream;
+    std::vector<float> buffer;
+    PaError err =
+        Pa_OpenStream(&stream, &param, nullptr, /* &outputParameters, */
+                      sample_rate,
+                      0,          // frames per buffer
+                      paClipOff,  // we won't output out of range samples
+                                  // so don't bother clipping them
+                      RecordCallback, &buffer);
+    if (err != paNoError) {
+      LOG(ERROR) << "portaudio error: " << Pa_GetErrorText(err);
+      exit(EXIT_FAILURE);
+    }
+
+    err = Pa_StartStream(stream);
+    LOG(INFO) << "Started: ";
+
+    if (err != paNoError) {
+      LOG(ERROR) << "portaudio error: " << Pa_GetErrorText(err);
+      exit(EXIT_FAILURE);
+    }
+
+    while(true){
+      int len = buffer.size();
+      short* iArray = new short[len];
+      for (size_t i = 0; i < len; ++i) {
+        iArray[i] = (short)(buffer[i] * 32768);
+      }
+
+      m_client.send(m_hdl, iArray, len * sizeof(short),
+                    websocketpp::frame::opcode::binary, ec);
+      buffer.clear();
+
+      if (ec) {
+        m_client.get_alog().write(websocketpp::log::alevel::app,
+                                  "Send Error: " + ec.message());
+      }
+      Pa_Sleep(20);  // sleep for 20ms
+    }
+
+    nlohmann::json jsonresult;
+    jsonresult["is_speaking"] = false;
+    m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+                  ec);
+    
+    err = Pa_CloseStream(stream);
+    if (err != paNoError) {
+      LOG(INFO) << "portaudio error: " << Pa_GetErrorText(err);
+      exit(EXIT_FAILURE);
+    }
+  }
+
   websocketpp::client<T> m_client;
 
  private:
@@ -331,7 +456,7 @@
   google::InitGoogleLogging(argv[0]);
   FLAGS_logtostderr = true;
 
-  TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
+  TCLAP::CmdLine cmd("funasr-wss-client-2pass", ' ', "1.0");
   TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
                                           "127.0.0.1", "string");
   TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095",
@@ -340,7 +465,11 @@
       "", "wav-path",
       "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: "
       "asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
-      true, "", "string");
+      false, "", "string");
+  TCLAP::ValueArg<int> record_(
+      "", "record",
+      "record is 1 means use record", false, 0,
+      "int");
   TCLAP::ValueArg<std::string> asr_mode_("", ASR_MODE, "offline, online, 2pass",
                                          false, "2pass", "string");
   TCLAP::ValueArg<std::string> chunk_size_("", "chunk-size",
@@ -357,6 +486,7 @@
   cmd.add(port_);
   cmd.add(wav_path_);
   cmd.add(asr_mode_);
+  cmd.add(record_);
   cmd.add(chunk_size_);
   cmd.add(thread_num_);
   cmd.add(is_ssl_);
@@ -382,6 +512,7 @@
 
   int threads_num = thread_num_.getValue();
   int is_ssl = is_ssl_.getValue();
+  int is_record = record_.getValue();
 
   std::string uri = "";
   if (is_ssl == 1) {
@@ -390,60 +521,78 @@
     uri = "ws://" + server_ip + ":" + port;
   }
 
-  // read wav_path
-  std::vector<string> wav_list;
-  std::vector<string> wav_ids;
-  string default_id = "wav_default_id";
-  if (IsTargetFile(wav_path, "scp")) {
-    ifstream in(wav_path);
-    if (!in.is_open()) {
-      printf("Failed to open scp file");
-      return 0;
-    }
-    string line;
-    while (getline(in, line)) {
-      istringstream iss(line);
-      string column1, column2;
-      iss >> column1 >> column2;
-      wav_list.emplace_back(column2);
-      wav_ids.emplace_back(column1);
-    }
-    in.close();
-  } else {
-    wav_list.emplace_back(wav_path);
-    wav_ids.emplace_back(default_id);
-  }
-
-  for (size_t wav_i = 0; wav_i < wav_list.size(); wav_i = wav_i + threads_num) {
-    std::vector<websocketpp::lib::thread> client_threads;
-    for (size_t i = 0; i < threads_num; i++) {
-      if (wav_i + i >= wav_list.size()) {
-        break;
-      }
+  if(is_record == 1){
       std::vector<string> tmp_wav_list;
       std::vector<string> tmp_wav_ids;
 
-      tmp_wav_list.emplace_back(wav_list[wav_i + i]);
-      tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
+      if (is_ssl == 1) {
+        WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
 
-      client_threads.emplace_back(
-          [uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl]() {
-            if (is_ssl == 1) {
-              WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
+        c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
 
-              c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
+        c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, true);
+      } else {
+        WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
 
-              c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
-            } else {
-              WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
+        c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, true);
+      }
 
-              c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
-            }
-          });
+  }else{
+    // read wav_path
+    std::vector<string> wav_list;
+    std::vector<string> wav_ids;
+    string default_id = "wav_default_id";
+    if (IsTargetFile(wav_path, "scp")) {
+      ifstream in(wav_path);
+      if (!in.is_open()) {
+        printf("Failed to open scp file");
+        return 0;
+      }
+      string line;
+      while (getline(in, line)) {
+        istringstream iss(line);
+        string column1, column2;
+        iss >> column1 >> column2;
+        wav_list.emplace_back(column2);
+        wav_ids.emplace_back(column1);
+      }
+      in.close();
+    } else {
+      wav_list.emplace_back(wav_path);
+      wav_ids.emplace_back(default_id);
     }
 
-    for (auto& t : client_threads) {
-      t.join();
+    for (size_t wav_i = 0; wav_i < wav_list.size(); wav_i = wav_i + threads_num) {
+      std::vector<websocketpp::lib::thread> client_threads;
+      for (size_t i = 0; i < threads_num; i++) {
+        if (wav_i + i >= wav_list.size()) {
+          break;
+        }
+        std::vector<string> tmp_wav_list;
+        std::vector<string> tmp_wav_ids;
+
+        tmp_wav_list.emplace_back(wav_list[wav_i + i]);
+        tmp_wav_ids.emplace_back(wav_ids[wav_i + i]);
+
+        client_threads.emplace_back(
+            [uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size, is_ssl]() {
+              if (is_ssl == 1) {
+                WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
+
+                c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
+
+                c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
+              } else {
+                WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
+
+                c.run(uri, tmp_wav_list, tmp_wav_ids, asr_mode, chunk_size);
+              }
+            });
+      }
+
+      for (auto& t : client_threads) {
+        t.join();
+      }
     }
   }
 }

--
Gitblit v1.9.1