COROIO: coroio/ws.hpp Source File
COROIO
 
All Classes Files Functions Variables Typedefs Pages
Loading...
Searching...
No Matches
ws.hpp
1#pragma once
2
3#include "sockutils.hpp"
4
5#if defined(__linux__)
6#include <arpa/inet.h>
7#include <endian.h>
8#define htonll(x) htobe64(x)
9#define ntohll(x) be64toh(x)
10#elif defined(__APPLE__)
11#include <arpa/inet.h>
12#elif defined(_WIN32)
13#include <WinSock2.h>
14#endif
15
16#include <random>
17
18namespace NNet
19{
20
21std::string GenerateWebSocketKey(std::random_device& rd);
22void CheckSecWebSocketAccept(const std::string& allServerHeaders, const std::string& clientKeyBase64);
23
60template<typename TSocket>
62public:
68 explicit TWebSocket(TSocket& socket)
69 : Socket(socket)
70 , Reader(socket)
71 , Writer(socket)
72 { }
73
86 TFuture<void> Connect(const std::string& host, const std::string& path) {
87 auto key = GenerateWebSocketKey(Rd);
88 std::string request =
89 "GET " + path + " HTTP/1.1\r\n"
90 "Host: " + host + "\r\n"
91 "User-Agent: coroio\r\n"
92 "Accept: */*\r\n"
93 "Connection: Upgrade\r\n"
94 "Upgrade: websocket\r\n"
95 "Sec-WebSocket-Key: " + key + "\r\n"
96 "Sec-WebSocket-Version: 13\r\n\r\n";
97
98 co_await Writer.Write(request.data(), request.size());
99
100 auto response = co_await Reader.ReadUntil("\r\n\r\n");
101
102 CheckSecWebSocketAccept(response, key);
103
104 if (response.find("101 Switching Protocols") == std::string::npos) {
105 throw std::runtime_error("Failed to establish WebSocket connection");
106 }
107
108 co_return;
109 }
110
117 TFuture<void> SendText(std::string_view message) {
118 co_await SendFrame(0x1, message);
119 }
120
130 auto [opcode, payload] = co_await ReceiveFrame();
131 if (opcode != 0x1) {
132 throw std::runtime_error(
133 "Unexpected opcode: " +
134 std::to_string(opcode) +
135 " , expected text frame, got: '" +
136 std::string(payload) + "'");
137 }
138 co_return payload;
139 }
140
141private:
142 TSocket& Socket;
145 std::random_device Rd;
146 std::string Payload;
147 std::vector<uint8_t> Frame;
148
149 TFuture<void> SendFrame(uint8_t opcode, std::string_view payload) {
150 Frame.clear();
151 Frame.push_back(0x80 | opcode);
152
153 uint8_t maskingKey[4];
154 for (int i = 0; i < 4; ++i) {
155 maskingKey[i] = static_cast<uint8_t>(Rd());
156 }
157
158 if (payload.size() <= 125) {
159 Frame.push_back(0x80 | static_cast<uint8_t>(payload.size()));
160 } else if (payload.size() <= 0xFFFF) {
161 Frame.push_back(0x80 | 126);
162 uint16_t length = htons(static_cast<uint16_t>(payload.size()));
163 Frame.insert(Frame.end(), reinterpret_cast<uint8_t*>(&length), reinterpret_cast<uint8_t*>(&length) + 2);
164 } else {
165 Frame.push_back(0x80 | 127);
166 uint64_t length = htonll(payload.size());
167 Frame.insert(Frame.end(), reinterpret_cast<uint8_t*>(&length), reinterpret_cast<uint8_t*>(&length) + 8);
168 }
169
170 Frame.insert(Frame.end(), std::begin(maskingKey), std::end(maskingKey));
171
172 for (size_t i = 0; i < payload.size(); ++i) {
173 Frame.push_back(payload[i] ^ maskingKey[i % 4]);
174 }
175
176 co_await Writer.Write(Frame.data(), Frame.size());
177 co_return;
178 }
179
180 TFuture<std::pair<uint8_t, std::string_view>> ReceiveFrame() {
181 uint8_t header[2];
182 co_await Reader.Read(header, sizeof(header));
183
184 uint8_t opcode = header[0] & 0x0F;
185 bool masked = header[1] & 0x80;
186 uint64_t payloadLength = header[1] & 0x7F;
187
188 if (payloadLength == 126) {
189 uint16_t extendedLength;
190 co_await Reader.Read(&extendedLength, sizeof(extendedLength));
191 payloadLength = ntohs(extendedLength);
192 } else if (payloadLength == 127) {
193 uint64_t extendedLength;
194 co_await Reader.Read(&extendedLength, sizeof(extendedLength));
195 payloadLength = ntohll(extendedLength);
196 }
197
198 uint8_t mask[4] = {0};
199 if (masked) {
200 co_await Reader.Read(mask, sizeof(mask));
201 }
202
203 Payload.resize(payloadLength);
204 co_await Reader.Read(Payload.data(), Payload.size());
205
206 if (masked) {
207 for (size_t i = 0; i < Payload.size(); ++i) {
208 Payload[i] ^= mask[i % 4];
209 }
210 }
211
212 co_return {opcode, Payload};
213 }
214};
215
216} // namespace NNet {
High-level asynchronous socket for network communication.
Definition socket.hpp:364
TWebSocket(TSocket &socket)
Constructs a WebSocket instance wrapping the provided socket.
Definition ws.hpp:68
TFuture< void > Connect(const std::string &host, const std::string &path)
Initiates the WebSocket handshake.
Definition ws.hpp:86
TFuture< void > SendText(std::string_view message)
Sends a text message as a WebSocket frame.
Definition ws.hpp:117
TFuture< std::string_view > ReceiveText()
Receives a text message from the WebSocket.
Definition ws.hpp:129
A utility for reading data from a socket-like object, either a fixed number of bytes or until a specified delimiter is found.
Definition sockutils.hpp:55
A utility for writing data to a socket-like object.
Definition sockutils.hpp:190
Future type for coroutines returning a value of type T.
Definition corochain.hpp:177