COROIO: coroio/ws/ws.hpp Source File
COROIO
 
Loading...
Searching...
No Matches
ws.hpp
1#pragma once
2
3#include <coroio/sockutils.hpp>
4
#include <cstdint>

// Byte-order helpers for the 64-bit extended payload length of WebSocket frames.
// Every branch below must leave htons()/ntohs() and htonll()/ntohll() (as
// functions or macros) usable by the framing code.
#if defined(__linux__)
#include <arpa/inet.h>
#include <endian.h>
#define htonll(x) htobe64(x)
#define ntohll(x) be64toh(x)
#elif defined(__APPLE__)
// Darwin: the system headers pulled in here are relied on to provide htonll/ntohll.
#include <arpa/inet.h>
#elif defined(__FreeBSD__)
#include <arpa/inet.h>
#include <sys/endian.h>
#define htonll(x) htobe64(x)
#define ntohll(x) be64toh(x)
#elif defined(_WIN32)
#include <WinSock2.h>
// For _WIN32_WINNT >= 0x0A00 WinSock2.h is relied on for htonll/ntohll;
// older targets get the portable fallback below.
#if !defined(_WIN32_WINNT) || (_WIN32_WINNT < 0x0A00)
#define COROIO_WS_PORTABLE_HTONLL
#endif
#else
// Any other POSIX-like system: htons/ntohs come from <arpa/inet.h>; use the
// portable 64-bit fallback unless the platform already defines htonll as a macro.
#include <arpa/inet.h>
#if !defined(htonll)
#define COROIO_WS_PORTABLE_HTONLL
#endif
#endif

#if defined(COROIO_WS_PORTABLE_HTONLL)
#include <cstring>

// Endianness-independent fallbacks. They build/read the big-endian byte sequence
// explicitly instead of testing BYTE_ORDER/LITTLE_ENDIAN, which may be undefined
// (an undefined identifier evaluates to 0 inside #if, so such a test degenerates
// to `0 == 0`).

/// Converts a 64-bit value from host to network (big-endian) byte order.
inline uint64_t htonll(uint64_t value) {
    unsigned char bytes[8];
    for (int i = 0; i < 8; ++i) {
        bytes[i] = static_cast<unsigned char>(value >> (56 - 8 * i));
    }
    uint64_t result;
    std::memcpy(&result, bytes, sizeof(result));
    return result;
}

/// Converts a 64-bit value from network (big-endian) to host byte order.
inline uint64_t ntohll(uint64_t value) {
    unsigned char bytes[8];
    std::memcpy(bytes, &value, sizeof(bytes));
    uint64_t result = 0;
    for (unsigned char b : bytes) {
        result = (result << 8) | b;
    }
    return result;
}

#undef COROIO_WS_PORTABLE_HTONLL
#endif
39
40#include <random>
41
42namespace NNet
43{
44
namespace NDetail {

/// Produces the value the client sends in its `Sec-WebSocket-Key` request header
/// (Connect() places it in the request verbatim), drawing randomness from `rd`.
std::string GenerateWebSocketKey(std::random_device& rd);

/// Validates the server's handshake response against the key the client sent.
/// @param allServerHeaders Raw response text; Connect() passes everything read up to "\r\n\r\n".
/// @param clientKeyBase64  The key previously produced by GenerateWebSocketKey()
///                         (base64-encoded, per the parameter name).
/// NOTE(review): presumably throws when `Sec-WebSocket-Accept` is missing or does not
/// match; the definition is not in this header — confirm against its implementation.
void CheckSecWebSocketAccept(const std::string& allServerHeaders, const std::string& clientKeyBase64);

} // namespace NDetail
51
78template<typename TSocket>
80public:
86 explicit TWebSocket(TSocket& socket)
87 : Socket(socket)
88 , Reader(socket)
89 , Writer(socket)
90 { }
91
105 TFuture<void> Connect(const std::string& host, const std::string& path) {
106 auto key = NDetail::GenerateWebSocketKey(Rd);
107 std::string request =
108 "GET " + path + " HTTP/1.1\r\n"
109 "Host: " + host + "\r\n"
110 "User-Agent: coroio\r\n"
111 "Accept: */*\r\n"
112 "Connection: Upgrade\r\n"
113 "Upgrade: websocket\r\n"
114 "Sec-WebSocket-Key: " + key + "\r\n"
115 "Sec-WebSocket-Version: 13\r\n\r\n";
116
117 co_await Writer.Write(request.data(), request.size());
118
119 auto response = co_await Reader.ReadUntil("\r\n\r\n");
120
121 NDetail::CheckSecWebSocketAccept(response, key);
122
123 if (response.find("101 Switching Protocols") == std::string::npos) {
124 throw std::runtime_error("Failed to establish WebSocket connection");
125 }
126
127 co_return;
128 }
129
135 TFuture<void> SendText(std::string_view message) {
136 co_await SendFrame(0x1, message);
137 }
138
150 auto [opcode, payload] = co_await ReceiveFrame();
151 if (opcode != 0x1) {
152 throw std::runtime_error(
153 "Unexpected opcode: " +
154 std::to_string(opcode) +
155 " , expected text frame, got: '" +
156 std::string(payload) + "'");
157 }
158 co_return payload;
159 }
160
161private:
162 TSocket& Socket;
165 std::random_device Rd;
166 std::string Payload;
167 std::vector<uint8_t> Frame;
168
169 TFuture<void> SendFrame(uint8_t opcode, std::string_view payload) {
170 Frame.clear();
171 Frame.push_back(0x80 | opcode);
172
173 uint8_t maskingKey[4];
174 for (int i = 0; i < 4; ++i) {
175 maskingKey[i] = static_cast<uint8_t>(Rd());
176 }
177
178 if (payload.size() <= 125) {
179 Frame.push_back(0x80 | static_cast<uint8_t>(payload.size()));
180 } else if (payload.size() <= 0xFFFF) {
181 Frame.push_back(0x80 | 126);
182 uint16_t length = htons(static_cast<uint16_t>(payload.size()));
183 Frame.insert(Frame.end(), reinterpret_cast<uint8_t*>(&length), reinterpret_cast<uint8_t*>(&length) + 2);
184 } else {
185 Frame.push_back(0x80 | 127);
186 uint64_t length = htonll(payload.size());
187 Frame.insert(Frame.end(), reinterpret_cast<uint8_t*>(&length), reinterpret_cast<uint8_t*>(&length) + 8);
188 }
189
190 Frame.insert(Frame.end(), std::begin(maskingKey), std::end(maskingKey));
191
192 for (size_t i = 0; i < payload.size(); ++i) {
193 Frame.push_back(payload[i] ^ maskingKey[i % 4]);
194 }
195
196 co_await Writer.Write(Frame.data(), Frame.size());
197 co_return;
198 }
199
200 TFuture<std::pair<uint8_t, std::string_view>> ReceiveFrame() {
201 uint8_t header[2];
202 co_await Reader.Read(header, sizeof(header));
203
204 uint8_t opcode = header[0] & 0x0F;
205 bool masked = header[1] & 0x80;
206 uint64_t payloadLength = header[1] & 0x7F;
207
208 if (payloadLength == 126) {
209 uint16_t extendedLength;
210 co_await Reader.Read(&extendedLength, sizeof(extendedLength));
211 payloadLength = ntohs(extendedLength);
212 } else if (payloadLength == 127) {
213 uint64_t extendedLength;
214 co_await Reader.Read(&extendedLength, sizeof(extendedLength));
215 payloadLength = ntohll(extendedLength);
216 }
217
218 uint8_t mask[4] = {0};
219 if (masked) {
220 co_await Reader.Read(mask, sizeof(mask));
221 }
222
223 Payload.resize(payloadLength);
224 co_await Reader.Read(Payload.data(), Payload.size());
225
226 if (masked) {
227 for (size_t i = 0; i < Payload.size(); ++i) {
228 Payload[i] ^= mask[i % 4];
229 }
230 }
231
232 co_return {opcode, Payload};
233 }
234};
235
236} // namespace NNet {
High-level asynchronous socket for network communication.
Definition socket.hpp:367
Client-side WebSocket framing layer over an already-connected socket.
Definition ws.hpp:79
TWebSocket(TSocket &socket)
Wraps an already-connected socket with WebSocket framing.
Definition ws.hpp:86
TFuture< void > Connect(const std::string &host, const std::string &path)
Performs the HTTP → WebSocket upgrade handshake.
Definition ws.hpp:105
TFuture< void > SendText(std::string_view message)
Sends message as a masked WebSocket text frame (opcode 0x1).
Definition ws.hpp:135
TFuture< std::string_view > ReceiveText()
Receives the next WebSocket text frame.
Definition ws.hpp:149
A utility for reading data from a socket-like object, either a fixed number of bytes or until a speci...
Definition sockutils.hpp:76
TFuture< void > Read(void *data, size_t size)
Reads exactly size bytes and stores them into data.
Definition sockutils.hpp:101
A utility for writing data to a socket-like object.
Definition sockutils.hpp:239
TFuture< void > Write(const void *data, size_t size)
Writes exactly size bytes from data to the socket.
Definition sockutils.hpp:262
Owned coroutine handle that carries a result of type T.
Definition corochain.hpp:185