/** * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights * Reserved. MIT License (https://opensource.org/licenses/MIT) */ /* 2023-2025 by zhaomingwork@qq.com */ // // copy some codes from http://www.boost.org/ // #ifndef HTTP_SERVER2_CONNECTION_HPP #define HTTP_SERVER2_CONNECTION_HPP #include #include #include #include #include #include "reply.hpp" #include #include #include "model-decoder.h" namespace beast = boost::beast; namespace beasthttp = beast::http; extern std::unordered_map hws_map_; extern int fst_inc_wts_; extern float global_beam_, lattice_beam_, am_scale_; namespace http { namespace server2 { /// Represents a single connection from a client. class connection : public std::enable_shared_from_this { public: connection(const connection &) = delete; connection &operator=(const connection &) = delete; ~connection() { std::cout << "one connection is close()" << std::endl; }; /// Construct a connection with the given socket. explicit connection(asio::ip::tcp::socket socket, asio::io_context &io_decoder, int connection_id, std::shared_ptr model_decoder); /// Start the first asynchronous operation for the connection. void start(); std::shared_ptr &get_data_msg(); void write_back(std::string str); // 处理100 Continue逻辑 void handle_100_continue() { // start_timer(5); // 5秒超时 auto self = shared_from_this(); const std::string response = "HTTP/1.1 100 Continue\r\n" "Connection: keep-alive\r\n\r\n"; asio::async_write(socket_, asio::buffer(response), [this, self](asio::error_code ec, size_t) { if (ec) return handle_error(ec); state_ = State::ReadingHeaders; do_read(); }); } // 准备文件存储 void prepare_body_handling() { if (!filename_.empty()) { sanitize_filename(filename_); output_file_.open(filename_, std::ios::binary); if (!output_file_) { std::cerr << "Failed to open: " << filename_ << "\n"; socket_.close(); } } } void finalize_request() { std::cout << "finalize_request" << std::endl; send_final_response(); } void send_final_response() { const std::string response = "HTTP/1.1 200 OK\r\n" "Content-Length: 0\r\n\r\n"; asio::write(socket_, asio::buffer(response)); socket_.close(); } void send_417_expectation_failed() { const std::string response = "HTTP/1.1 417 Expectation Failed\r\n" "Connection: close\r\n\r\n"; asio::write(socket_, asio::buffer(response)); socket_.close(); } // 安全处理文件名 static void sanitize_filename(std::string &name) { std::replace(name.begin(), name.end(), '/', '_'); std::replace(name.begin(), name.end(), '\\', '_'); name = name.substr(name.find_last_of(":") + 1); // 移除潜在路径 } // 协议版本解析 bool parse_http_version(const std::string &headers) { size_t start = headers.find("HTTP/"); if (start == std::string::npos) return false; start += 5; size_t dot = headers.find('.', start); if (dot == std::string::npos) return false; try { http_version_major_ = std::stoi(headers.substr(start, dot - start)); http_version_minor_ = std::stoi(headers.substr(dot + 1, 1)); return true; } catch (...) { return false; } } // 头部解析 bool try_parse_headers() { size_t header_end = received_data_.find("\r\n\r\n"); if (header_end == std::string::npos) { return false; } std::string headers = received_data_.substr(0, header_end); // 解析内容信息 if (content_length_ <= 0) content_length_ = parse_content_length(headers); // 解析HTTP版本 if (!parse_http_version(headers)) { return false; } // 检查Expect头 std::string continue100 = "Expect: 100-continue"; size_t pos = headers.find(continue100); expect_100_continue_ = pos != std::string::npos; // 检查协议兼容性 if (expect_100_continue_) { headers.erase(pos, continue100.length()); received_data_ = headers; state_ = State::SendingContinue; if (http_version_minor_ < 1) send_417_expectation_failed(); return true; } filename_ = parse_attachment_filename(headers); // 状态转移 std::string ext = parese_file_ext(filename_); if (filename_.find(".wav") != std::string::npos) { std::cout << "set wav_format=pcm, file_name=" << filename_ << std::endl; data_msg->msg["wav_format"] = "pcm"; } else { std::cout << "set wav_format=" << ext << ", file_name=" << filename_ << std::endl; data_msg->msg["wav_format"] = ext; } data_msg->msg["wav_name"] = filename_; state_ = State::ReadingBody; return true; } void parse_multipart_boundary() { size_t content_type_pos = received_data_.find("Content-Type: multipart/form-data"); if (content_type_pos == std::string::npos) return; size_t boundary_pos = received_data_.find("boundary=", content_type_pos); if (boundary_pos == std::string::npos) return; boundary_pos += 9; // "boundary="长度 size_t boundary_end = received_data_.find("\r\n", boundary_pos); boundary_ = received_data_.substr(boundary_pos, boundary_end - boundary_pos); // 清理boundary的引号 if (boundary_.front() == '"' && boundary_.back() == '"') { boundary_ = boundary_.substr(1, boundary_.size() - 2); } } // multipart 数据处理核心 void process_multipart_data() { if (boundary_.empty()) { parse_multipart_boundary(); if (boundary_.empty()) { std::cerr << "Invalid multipart format\n"; return; } } while (true) { if (!in_file_part_) { // 查找boundary起始 size_t boundary_pos = received_data_.find("--" + boundary_); if (boundary_pos == std::string::npos) break; // 移动到part头部 size_t part_start = received_data_.find("\r\n\r\n", boundary_pos); if (part_start == std::string::npos) break; part_start += 4; // 跳过空行 parse_part_headers(received_data_.substr(boundary_pos, part_start - boundary_pos)); received_data_.erase(0, part_start); in_file_part_ = true; } else { // 查找boundary结束 size_t boundary_end = received_data_.find("\r\n--" + boundary_); if (boundary_end == std::string::npos) break; // 写入内容 std::string tmpstr = received_data_.substr(0, boundary_end); data_msg->samples->insert(data_msg->samples->end(), tmpstr.begin(), tmpstr.end()); received_data_.erase(0, boundary_end + 2); // 保留\r\n供下次解析 in_file_part_ = false; } } } std::string parese_file_ext(std::string file_name) { int pos = file_name.rfind('.'); std::string ext = ""; if (pos != std::string::npos) ext = file_name.substr(pos + 1); return ext; } // 解析part头部信息 void parse_part_headers(const std::string &headers) { current_part_filename_.clear(); expected_part_size_ = 0; // 解析文件名 size_t filename_pos = headers.find("filename=\""); if (filename_pos != std::string::npos) { filename_pos += 10; size_t filename_end = headers.find('"', filename_pos); current_part_filename_ = headers.substr(filename_pos, filename_end - filename_pos); sanitize_filename(current_part_filename_); } // 解析Content-Length size_t cl_pos = headers.find("Content-Length: "); if (cl_pos != std::string::npos) { cl_pos += 15; size_t cl_end = headers.find("\r\n", cl_pos); expected_part_size_ = std::stoull(headers.substr(cl_pos, cl_end - cl_pos)); } } private: /// Perform an asynchronous read operation. void do_read(); void handle_body(); std::string parse_attachment_filename(const std::string &header); size_t parse_content_length(const std::string &header); void handle_error(asio::error_code ec); /// Perform an asynchronous write operation. void do_write(); void do_decoder(); void setup_timer(); /// Socket for the connection. asio::ip::tcp::socket socket_; /// Buffer for incoming data. std::array buffer_; /// for time out std::shared_ptr s_timer; std::shared_ptr model_decoder; int connection_id = 0; /// The reply to be sent back to the client. reply reply_; asio::io_context &io_decoder; std::shared_ptr data_msg; std::mutex m_lock; std::shared_ptr strand_; beasthttp::response_parser parser_; // 渐进式解析器 std::string received_data_; // 累积接收的数据 bool header_parsed_ = false; // 头部解析状态标记 size_t content_length_ = 0; // Content-Length 值 enum class State { ReadingHeaders, SendingContinue, ReadingBody }; bool expect_100_continue_ = false; State state_ = State::ReadingHeaders; std::string filename_; std::ofstream output_file_; int http_version_major_ = 1; int http_version_minor_ = 1; std::string boundary_ = ""; bool in_file_part_ = false; std::string current_part_filename_; size_t expected_part_size_ = 0; }; typedef std::shared_ptr connection_ptr; } // namespace server2 } // namespace http #endif // HTTP_SERVER2_CONNECTION_HPP