zhaomingwork
2025-05-08 ae013cf597db1c523c9fac21b7e83db62304ae2d
runtime/http/bin/connection.hpp
@@ -2,7 +2,7 @@
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
/* 2023-2025 by zhaomingwork@qq.com */
//
// copy some codes from  http://www.boost.org/
//
@@ -19,86 +19,376 @@
#include "reply.hpp"
#include <fstream>
#include "file_parse.hpp"
#include <boost/beast.hpp>
#include "model-decoder.h"
namespace beast = boost::beast;
namespace beasthttp = beast::http;
extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;
namespace http {
namespace server2 {
namespace http
{
    namespace server2
    {
/// Represents a single connection from a client.
class connection : public std::enable_shared_from_this<connection> {
 public:
  connection(const connection &) = delete;
  connection &operator=(const connection &) = delete;
  ~connection() { std::cout << "one connection is close()" << std::endl; };
        /// Represents a single connection from a client.
        class connection : public std::enable_shared_from_this<connection>
        {
        public:
            connection(const connection &) = delete;
            connection &operator=(const connection &) = delete;
            ~connection()
            {
                std::cout << "one connection is close()" << std::endl;
            };
  /// Construct a connection with the given socket.
  explicit connection(asio::ip::tcp::socket socket,
                      asio::io_context &io_decoder, int connection_id,
                      std::shared_ptr<ModelDecoder> model_decoder);
            /// Construct a connection with the given socket.
            explicit connection(asio::ip::tcp::socket socket,
                                asio::io_context &io_decoder, int connection_id,
                                std::shared_ptr<ModelDecoder> model_decoder);
            /// Start the first asynchronous operation for the connection.
            void start();
            std::shared_ptr<FUNASR_MESSAGE> &get_data_msg();
            void write_back(std::string str);
            // 处理100 Continue逻辑
            void handle_100_continue()
            {
                //  start_timer(5); // 5秒超时
                auto self = shared_from_this();
                const std::string response =
                    "HTTP/1.1 100 Continue\r\n"
                    "Connection: keep-alive\r\n\r\n";
                asio::async_write(socket_, asio::buffer(response),
                                  [this, self](asio::error_code ec, size_t)
                                  {
                                      if (ec)
                                          return handle_error(ec);
                                      state_ = State::ReadingHeaders;
                                      do_read();
                                  });
            }
            // 准备文件存储
            void prepare_body_handling()
            {
                if (!filename_.empty())
                {
                    sanitize_filename(filename_);
                    output_file_.open(filename_, std::ios::binary);
                    if (!output_file_)
                    {
                        std::cerr << "Failed to open: " << filename_ << "\n";
                        socket_.close();
                    }
                }
            }
            void finalize_request()
            {
                std::cout << "finalize_request" << std::endl;
                send_final_response();
            }
            void send_final_response()
            {
                const std::string response =
                    "HTTP/1.1 200 OK\r\n"
                    "Content-Length: 0\r\n\r\n";
                asio::write(socket_, asio::buffer(response));
                socket_.close();
            }
            void send_417_expectation_failed()
            {
                const std::string response =
                    "HTTP/1.1 417 Expectation Failed\r\n"
                    "Connection: close\r\n\r\n";
                asio::write(socket_, asio::buffer(response));
                socket_.close();
            }
            // 安全处理文件名
            static void sanitize_filename(std::string &name)
            {
                std::replace(name.begin(), name.end(), '/', '_');
                std::replace(name.begin(), name.end(), '\\', '_');
                name = name.substr(name.find_last_of(":") + 1); // 移除潜在路径
            }
            // 协议版本解析
            bool parse_http_version(const std::string &headers)
            {
                size_t start = headers.find("HTTP/");
                if (start == std::string::npos)
                    return false;
                start += 5;
                size_t dot = headers.find('.', start);
                if (dot == std::string::npos)
                    return false;
                try
                {
                    http_version_major_ = std::stoi(headers.substr(start, dot - start));
                    http_version_minor_ = std::stoi(headers.substr(dot + 1, 1));
                    return true;
                }
                catch (...)
                {
                    return false;
                }
            }
            // 头部解析
            bool try_parse_headers()
            {
                size_t header_end = received_data_.find("\r\n\r\n");
                if (header_end == std::string::npos)
                {
                    return false;
                }
                std::string headers = received_data_.substr(0, header_end);
                //  解析内容信息
                if (content_length_ <= 0)
                    content_length_ = parse_content_length(headers);
                // 解析HTTP版本
                if (!parse_http_version(headers))
                {
                    return false;
                }
                // 检查Expect头
                std::string continue100 = "Expect: 100-continue";
                size_t pos = headers.find(continue100);
                expect_100_continue_ = pos != std::string::npos;
                // 检查协议兼容性
                if (expect_100_continue_)
                {
                    headers.erase(pos, continue100.length());
                    received_data_ = headers;
                    state_ = State::SendingContinue;
                    if (http_version_minor_ < 1)
                        send_417_expectation_failed();
                    return true;
                }
                filename_ = parse_attachment_filename(headers);
                // 状态转移
                std::string ext = parese_file_ext(filename_);
                if (filename_.find(".wav") != std::string::npos)
                {
                    std::cout << "set wav_format=pcm, file_name=" << filename_
                              << std::endl;
                    data_msg->msg["wav_format"] = "pcm";
                }
                else
                {
                    std::cout << "set wav_format=" << ext << ", file_name=" << filename_
                              << std::endl;
                    data_msg->msg["wav_format"] = ext;
                }
                data_msg->msg["wav_name"] = filename_;
 
  /// Start the first asynchronous operation for the connection.
  void start();
  std::shared_ptr<FUNASR_MESSAGE> &get_data_msg();
  void write_back(std::string str);
                state_ = State::ReadingBody;
                return true;
            }
 private:
  /// Perform an asynchronous read operation.
  void do_read();
            void parse_multipart_boundary()
            {
                size_t content_type_pos = received_data_.find("Content-Type: multipart/form-data");
                if (content_type_pos == std::string::npos)
                    return;
  /// Perform an asynchronous write operation.
  void do_write();
                size_t boundary_pos = received_data_.find("boundary=", content_type_pos);
                if (boundary_pos == std::string::npos)
                    return;
  void do_decoder();
                boundary_pos += 9; // "boundary="长度
                size_t boundary_end = received_data_.find("\r\n", boundary_pos);
                boundary_ = received_data_.substr(boundary_pos, boundary_end - boundary_pos);
  void setup_timer();
                // 清理boundary的引号
                if (boundary_.front() == '"' && boundary_.back() == '"')
                {
                    boundary_ = boundary_.substr(1, boundary_.size() - 2);
                }
            }
            // multipart 数据处理核心
            void process_multipart_data()
            {
                if (boundary_.empty())
                {
                    parse_multipart_boundary();
                    if (boundary_.empty())
                    {
                        std::cerr << "Invalid multipart format\n";
                        return;
                    }
                }
  /// Socket for the connection.
  asio::ip::tcp::socket socket_;
                while (true)
                {
                    if (!in_file_part_)
                    {
                        // 查找boundary起始
                        size_t boundary_pos = received_data_.find("--" + boundary_);
                        if (boundary_pos == std::string::npos)
                            break;
                        // 移动到part头部
                        size_t part_start = received_data_.find("\r\n\r\n", boundary_pos);
                        if (part_start == std::string::npos)
                            break;
                        part_start += 4; // 跳过空行
                        parse_part_headers(received_data_.substr(boundary_pos, part_start - boundary_pos));
                        received_data_.erase(0, part_start);
                        in_file_part_ = true;
                    }
                    else
                    {
                        // 查找boundary结束
                        size_t boundary_end = received_data_.find("\r\n--" + boundary_);
                        if (boundary_end == std::string::npos)
                            break;
                        // 写入内容
                        std::string tmpstr = received_data_.substr(0, boundary_end);
                        data_msg->samples->insert(data_msg->samples->end(), tmpstr.begin(), tmpstr.end());
                        received_data_.erase(0, boundary_end + 2); // 保留\r\n供下次解析
                        in_file_part_ = false;
                    }
                }
            }
            std::string parese_file_ext(std::string file_name)
            {
                int pos = file_name.rfind('.');
                std::string ext = "";
                if (pos != std::string::npos)
                    ext = file_name.substr(pos + 1);
                return ext;
            }
            // 解析part头部信息
            void parse_part_headers(const std::string &headers)
            {
                current_part_filename_.clear();
                expected_part_size_ = 0;
                // 解析文件名
                size_t filename_pos = headers.find("filename=\"");
                if (filename_pos != std::string::npos)
                {
                    filename_pos += 10;
                    size_t filename_end = headers.find('"', filename_pos);
                    current_part_filename_ = headers.substr(filename_pos, filename_end - filename_pos);
                    sanitize_filename(current_part_filename_);
                }
                // 解析Content-Length
                size_t cl_pos = headers.find("Content-Length: ");
                if (cl_pos != std::string::npos)
                {
                    cl_pos += 15;
                    size_t cl_end = headers.find("\r\n", cl_pos);
                    expected_part_size_ = std::stoull(headers.substr(cl_pos, cl_end - cl_pos));
                }
            }
        private:
            /// Perform an asynchronous read operation.
            void do_read();
            void handle_body();
            std::string parse_attachment_filename(const std::string &header);
            size_t parse_content_length(const std::string &header);
            void handle_error(asio::error_code ec);
            /// Perform an asynchronous write operation.
            void do_write();
            void do_decoder();
            void setup_timer();
            /// Socket for the connection.
            asio::ip::tcp::socket socket_;
            /// Buffer for incoming data.
            std::array<char, 8192> buffer_;
            /// for time out
            std::shared_ptr<asio::steady_timer> s_timer;
            std::shared_ptr<ModelDecoder> model_decoder;
            int connection_id = 0;
            /// The reply to be sent back to the client.
            reply reply_;
            asio::io_context &io_decoder;
            std::shared_ptr<FUNASR_MESSAGE> data_msg;
            std::mutex m_lock;
            std::shared_ptr<asio::io_context::strand> strand_;
 
  /// Buffer for incoming data.
  std::array<char, 8192> buffer_;
  /// for time out
  std::shared_ptr<asio::steady_timer> s_timer;
            beasthttp::response_parser<beasthttp::string_body> parser_; // 渐进式解析器
            std::string received_data_;  // 累积接收的数据
            bool header_parsed_ = false; // 头部解析状态标记
            size_t content_length_ = 0;  // Content-Length 值
            enum class State
            {
                ReadingHeaders,
                SendingContinue,
                ReadingBody
            };
            bool expect_100_continue_ = false;
            State state_ = State::ReadingHeaders;
            std::string filename_;
            std::ofstream output_file_;
            int http_version_major_ = 1;
            int http_version_minor_ = 1;
            std::string boundary_ = "";
            bool in_file_part_ = false;
            std::string current_part_filename_;
            size_t expected_part_size_ = 0;
        };
  std::shared_ptr<ModelDecoder> model_decoder;
        typedef std::shared_ptr<connection> connection_ptr;
    } // namespace server2
} // namespace http
  int connection_id = 0;
  /// The reply to be sent back to the client.
  reply reply_;
  asio::io_context &io_decoder;
  std::shared_ptr<FUNASR_MESSAGE> data_msg;
  std::mutex m_lock;
  std::shared_ptr<asio::io_context::strand> strand_;
  std::shared_ptr<http::server2::file_parser> file_parse;
};
typedef std::shared_ptr<connection> connection_ptr;
}  // namespace server2
}  // namespace http
#endif  // HTTP_SERVER2_CONNECTION_HPP
#endif // HTTP_SERVER2_CONNECTION_HPP