zhaomingwork
2025-05-08 ae013cf597db1c523c9fac21b7e83db62304ae2d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License  (https://opensource.org/licenses/MIT)
 */
/* 2023-2025 by zhaomingwork@qq.com */
//
// Portions of this code are adapted from Boost (http://www.boost.org/).
//
 
#ifndef HTTP_SERVER2_CONNECTION_HPP
#define HTTP_SERVER2_CONNECTION_HPP
 
#include <algorithm>
#include <array>
#include <atomic>
#include <fstream>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

#include <asio.hpp>
#include <boost/beast.hpp>

#include "reply.hpp"
#include "model-decoder.h"
 
// Short aliases for the Boost.Beast namespaces used in this header.
namespace beast = boost::beast;
namespace beasthttp = beast::http;

// Globals defined in another translation unit and only declared here.
// NOTE(review): judging by the names these look like decoder settings (hotword
// map, FST weight, beam/scale values) -- confirm against the defining file.
extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;
 
namespace http
{
    namespace server2
    {
 
        /// Represents a single connection from a client.
        class connection : public std::enable_shared_from_this<connection>
        {
        public:
            // Connections are neither copy-constructible nor copy-assignable.
            connection(const connection &) = delete;
            connection &operator=(const connection &) = delete;

            /// Logs a message to stdout when the connection object is destroyed.
            ~connection()
            {
                std::cout << "one connection is close()" << std::endl;
            };

            /// Construct a connection with the given socket.
            /// Defined out of line; the remaining arguments are an io_context
            /// reference, an integer connection id and a shared ModelDecoder.
            explicit connection(asio::ip::tcp::socket socket,
                                asio::io_context &io_decoder, int connection_id,
                                std::shared_ptr<ModelDecoder> model_decoder);

            /// Start the first asynchronous operation for the connection.
            void start();
            /// Returns a reference to a shared FUNASR_MESSAGE pointer (defined out
            /// of line). NOTE(review): presumably the data_msg member -- confirm.
            std::shared_ptr<FUNASR_MESSAGE> &get_data_msg();
            /// Defined out of line. NOTE(review): presumably sends `str` back to
            /// the client -- confirm against the implementation.
            void write_back(std::string str);
 
            // 处理100 Continue逻辑
            void handle_100_continue()
            {
                //  start_timer(5); // 5秒超时
 
                auto self = shared_from_this();
                const std::string response =
                    "HTTP/1.1 100 Continue\r\n"
                    "Connection: keep-alive\r\n\r\n";
 
                asio::async_write(socket_, asio::buffer(response),
                                  [this, self](asio::error_code ec, size_t)
                                  {
                                      if (ec)
                                          return handle_error(ec);
 
                                      state_ = State::ReadingHeaders;
 
                                      do_read();
                                  });
            }
 
            // 准备文件存储
            void prepare_body_handling()
            {
                if (!filename_.empty())
                {
                    sanitize_filename(filename_);
                    output_file_.open(filename_, std::ios::binary);
                    if (!output_file_)
                    {
                        std::cerr << "Failed to open: " << filename_ << "\n";
                        socket_.close();
                    }
                }
            }
 
            /// Finish the current request: log the event to stdout, then send the
            /// final response via send_final_response().
            void finalize_request()
            {
                std::cout << "finalize_request" << std::endl;
                send_final_response();
            }
 
            void send_final_response()
            {
                const std::string response =
                    "HTTP/1.1 200 OK\r\n"
                    "Content-Length: 0\r\n\r\n";
                asio::write(socket_, asio::buffer(response));
                socket_.close();
            }
 
            void send_417_expectation_failed()
            {
                const std::string response =
                    "HTTP/1.1 417 Expectation Failed\r\n"
                    "Connection: close\r\n\r\n";
                asio::write(socket_, asio::buffer(response));
                socket_.close();
            }
 
            // 安全处理文件名
            static void sanitize_filename(std::string &name)
            {
                std::replace(name.begin(), name.end(), '/', '_');
                std::replace(name.begin(), name.end(), '\\', '_');
                name = name.substr(name.find_last_of(":") + 1); // 移除潜在路径
            }
 
            // 协议版本解析
            bool parse_http_version(const std::string &headers)
            {
                size_t start = headers.find("HTTP/");
                if (start == std::string::npos)
                    return false;
 
                start += 5;
                size_t dot = headers.find('.', start);
                if (dot == std::string::npos)
                    return false;
 
                try
                {
                    http_version_major_ = std::stoi(headers.substr(start, dot - start));
                    http_version_minor_ = std::stoi(headers.substr(dot + 1, 1));
                    return true;
                }
                catch (...)
                {
                    return false;
                }
            }
            /// Header parsing.
            ///
            /// Tries to parse the request headers accumulated in received_data_.
            ///
            /// @return false when the "\r\n\r\n" header terminator has not been
            ///         received yet or the HTTP version cannot be parsed; true once
            ///         the headers were handled and state_ was advanced to either
            ///         State::SendingContinue or State::ReadingBody.
            bool try_parse_headers()
            {

                size_t header_end = received_data_.find("\r\n\r\n");
                if (header_end == std::string::npos)
                {

                    return false;
                }

                // Header text without the terminating blank line.
                std::string headers = received_data_.substr(0, header_end);

                // Parse the content information; content_length_ is unsigned, so
                // "<= 0" means "not set yet".
                if (content_length_ <= 0)
                    content_length_ = parse_content_length(headers);
                // Parse the HTTP version.
                if (!parse_http_version(headers))
                {
                    return false;
                }

                // Check for the Expect header (exact, case-sensitive match).
                std::string continue100 = "Expect: 100-continue";
                size_t pos = headers.find(continue100);
                expect_100_continue_ = pos != std::string::npos;

                // Check protocol compatibility.
                if (expect_100_continue_)
                {

                    // Remove the Expect text so it is not seen again on the next pass.
                    headers.erase(pos, continue100.length());

                    // received_data_ now holds only the edited header text: the
                    // "\r\n\r\n" terminator and any bytes received after it are dropped.
                    // NOTE(review): confirm that discarding already-received body
                    // bytes here is intended.
                    received_data_ = headers;
                    state_ = State::SendingContinue;
                    // A minor version below 1 gets a 417 reply, which also closes the
                    // socket. NOTE(review): state_ stays SendingContinue and true is
                    // still returned afterwards -- confirm the caller copes with that.
                    if (http_version_minor_ < 1)
                        send_417_expectation_failed();
                    return true;
                }

                filename_ = parse_attachment_filename(headers);

                // State transition: derive the audio format from the file name.
                std::string ext = parese_file_ext(filename_);

                // Any name containing ".wav" is reported as "pcm"; otherwise the
                // file extension itself is used as the format.
                if (filename_.find(".wav") != std::string::npos)
                {
                    std::cout << "set wav_format=pcm, file_name=" << filename_
                              << std::endl;
                    data_msg->msg["wav_format"] = "pcm";
                }
                else
                {
                    std::cout << "set wav_format=" << ext << ", file_name=" << filename_
                              << std::endl;
                    data_msg->msg["wav_format"] = ext;
                }
                data_msg->msg["wav_name"] = filename_;

                state_ = State::ReadingBody;
                return true;
            }
 
            void parse_multipart_boundary()
            {
                size_t content_type_pos = received_data_.find("Content-Type: multipart/form-data");
                if (content_type_pos == std::string::npos)
                    return;
 
                size_t boundary_pos = received_data_.find("boundary=", content_type_pos);
                if (boundary_pos == std::string::npos)
                    return;
 
                boundary_pos += 9; // "boundary="长度
                size_t boundary_end = received_data_.find("\r\n", boundary_pos);
                boundary_ = received_data_.substr(boundary_pos, boundary_end - boundary_pos);
 
                // 清理boundary的引号
                if (boundary_.front() == '"' && boundary_.back() == '"')
                {
                    boundary_ = boundary_.substr(1, boundary_.size() - 2);
                }
            }
            // multipart 数据处理核心
            void process_multipart_data()
            {
                if (boundary_.empty())
                {
                    parse_multipart_boundary();
                    if (boundary_.empty())
                    {
                        std::cerr << "Invalid multipart format\n";
                        return;
                    }
                }
 
                while (true)
                {
                    if (!in_file_part_)
                    {
 
                        // 查找boundary起始
                        size_t boundary_pos = received_data_.find("--" + boundary_);
                        if (boundary_pos == std::string::npos)
                            break;
 
                        // 移动到part头部
                        size_t part_start = received_data_.find("\r\n\r\n", boundary_pos);
                        if (part_start == std::string::npos)
                            break;
 
                        part_start += 4; // 跳过空行
                        parse_part_headers(received_data_.substr(boundary_pos, part_start - boundary_pos));
 
                        received_data_.erase(0, part_start);
 
                        in_file_part_ = true;
                    }
                    else
                    {
                        // 查找boundary结束
 
                        size_t boundary_end = received_data_.find("\r\n--" + boundary_);
 
                        if (boundary_end == std::string::npos)
 
                            break;
 
                        // 写入内容
 
                        std::string tmpstr = received_data_.substr(0, boundary_end);
                        data_msg->samples->insert(data_msg->samples->end(), tmpstr.begin(), tmpstr.end());
 
                        received_data_.erase(0, boundary_end + 2); // 保留\r\n供下次解析
 
                        in_file_part_ = false;
                    }
                }
            }
            std::string parese_file_ext(std::string file_name)
            {
                int pos = file_name.rfind('.');
                std::string ext = "";
                if (pos != std::string::npos)
                    ext = file_name.substr(pos + 1);
 
                return ext;
            }
            // 解析part头部信息
            void parse_part_headers(const std::string &headers)
            {
                current_part_filename_.clear();
                expected_part_size_ = 0;
 
                // 解析文件名
                size_t filename_pos = headers.find("filename=\"");
                if (filename_pos != std::string::npos)
                {
                    filename_pos += 10;
                    size_t filename_end = headers.find('"', filename_pos);
                    current_part_filename_ = headers.substr(filename_pos, filename_end - filename_pos);
                    sanitize_filename(current_part_filename_);
                }
 
                // 解析Content-Length
                size_t cl_pos = headers.find("Content-Length: ");
                if (cl_pos != std::string::npos)
                {
                    cl_pos += 15;
                    size_t cl_end = headers.find("\r\n", cl_pos);
                    expected_part_size_ = std::stoull(headers.substr(cl_pos, cl_end - cl_pos));
                }
            }
 
        private:
            // The member functions below are defined out of line; their bodies are
            // not visible in this header.

            /// Perform an asynchronous read operation.
            void do_read();
            void handle_body();
            /// Return value is stored in filename_ by try_parse_headers().
            std::string parse_attachment_filename(const std::string &header);
            /// Return value is stored in content_length_ by try_parse_headers().
            size_t parse_content_length(const std::string &header);

            /// Called with the error code when an asynchronous write fails.
            void handle_error(asio::error_code ec);
            /// Perform an asynchronous write operation.
            void do_write();

            void do_decoder();

            void setup_timer();

            /// Socket for the connection.
            asio::ip::tcp::socket socket_;

            /// Buffer for incoming data.
            std::array<char, 8192> buffer_;
            /// Timer used for timeouts.
            std::shared_ptr<asio::steady_timer> s_timer;

            std::shared_ptr<ModelDecoder> model_decoder;

            int connection_id = 0;

            /// The reply to be sent back to the client.
            reply reply_;

            asio::io_context &io_decoder;

            /// Message object that receives the parsed request: the inline parsers
            /// write msg["wav_format"], msg["wav_name"] and append to samples.
            std::shared_ptr<FUNASR_MESSAGE> data_msg;

            // NOTE(review): what this mutex guards is not visible in this header.
            std::mutex m_lock;

            std::shared_ptr<asio::io_context::strand> strand_;



            // Incremental parser.
            // NOTE(review): this is a *response* parser although the class reads
            // requests, and it is not referenced by the inline code in this header --
            // confirm it is intended and used.
            beasthttp::response_parser<beasthttp::string_body> parser_;

            std::string received_data_;  // data accumulated from the socket so far
            bool header_parsed_ = false; // header-parsing status flag
            size_t content_length_ = 0;  // Content-Length value (0 until parsed)
            /// Phases of request handling, as set by the inline code in this header.
            enum class State
            {
                ReadingHeaders,  // initial state; also restored after "100 Continue" is sent
                SendingContinue, // set by try_parse_headers() when "Expect: 100-continue" is seen
                ReadingBody      // set by try_parse_headers() once the headers were handled
            };
            bool expect_100_continue_ = false; // result of the last Expect-header check
            State state_ = State::ReadingHeaders;
            std::string filename_;      // set from parse_attachment_filename()
            std::ofstream output_file_; // opened by prepare_body_handling()
            int http_version_major_ = 1; // written by parse_http_version()
            int http_version_minor_ = 1; // written by parse_http_version()
            std::string boundary_ = "";  // multipart boundary, without the leading "--"
            bool in_file_part_ = false;  // true while process_multipart_data() is inside a part body
            std::string current_part_filename_; // set by parse_part_headers()
            size_t expected_part_size_ = 0;     // set by parse_part_headers()
        };
 
        typedef std::shared_ptr<connection> connection_ptr;
 
    } // namespace server2
} // namespace http
 
#endif // HTTP_SERVER2_CONNECTION_HPP