COROIO: coroio/ssl.hpp Source File
COROIO
 
Loading...
Searching...
No Matches
ssl.hpp
1#pragma once
2
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>

#include <algorithm>
#include <cassert>
#include <climits>
#include <cstdio>
#include <functional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "base.hpp"
#include "corochain.hpp"
#include "sockutils.hpp"
#include "promises.hpp"
#include "socket.hpp"
16
17namespace NNet {
18
// Holds an OpenSSL SSL_CTX together with an optional logging callback that
// TSslSocket uses for handshake diagnostics. Instances are created through the
// Client() / Server() / ServerFromMem() factories declared below.
35struct TSslContext {
// The underlying OpenSSL context; set to nullptr in an object that was moved from.
// NOTE(review): presumably released by ~TSslContext() — its definition is not shown here.
36 SSL_CTX* Ctx;
// Optional logging callback; TSslSocket skips all logging when it is empty.
37 std::function<void(const char*)> LogFunc = {};
38
// NOTE(review): the signature line of this constructor (source line 39) is missing from
// this listing; from the body it is presumably `TSslContext(TSslContext&& other)`.
// If so, copying is implicitly deleted and the type is move-only.
// Takes over other's Ctx and nulls it so only one object refers to it afterwards.
// NOTE(review): LogFunc is copied rather than moved from `other`.
40 : Ctx(other.Ctx)
41 , LogFunc(other.LogFunc)
42 {
43 other.Ctx = nullptr;
44 }
45
46 ~TSslContext();
47
// Creates a client SSL context.
54 static TSslContext Client(const std::function<void(const char*)>& logFunc = {});
// Creates a server SSL context using certificate and key files.
63 static TSslContext Server(const char* certfile, const char* keyfile, const std::function<void(const char*)>& logFunc = {});
// Creates a server SSL context from in-memory certificate and key data.
72 static TSslContext ServerFromMem(const void* certfile, const void* keyfile, const std::function<void(const char*)>& logFunc = {});
73
74private:
// NOTE(review): a private member (source line 75, presumably a constructor) is missing from this listing.
76};
77
// NOTE(review): the `class TSslSocket {` line (source line 99) is missing from this listing.
// TSslSocket<TSocket> implements an SSL/TLS layer on top of an underlying TSocket.
// Ciphertext is exchanged with OpenSSL through two memory BIOs: DoIO() feeds bytes
// read from the socket into Rbio and writes whatever OpenSSL put into Wbio to the socket.
98template<typename TSocket>
100public:
// Poller type of the wrapped socket.
101 using TPoller = typename TSocket::TPoller;
102
// NOTE(review): the constructor signature (source line 112) is missing from this listing;
// per the page tooltip it is `TSslSocket(TSocket&& socket, TSslContext& ctx)`.
// Takes ownership of `socket`, creates a fresh SSL object from ctx.Ctx and attaches two
// memory BIOs to it. Only a pointer to `ctx` is stored, so `ctx` must outlive this socket.
// NOTE(review): the results of SSL_new()/BIO_new() are not checked for nullptr.
113 : Socket(std::move(socket))
114 , Ctx(&ctx)
115 , Ssl(SSL_new(Ctx->Ctx))
116 , Rbio(BIO_new(BIO_s_mem()))
117 , Wbio(BIO_new(BIO_s_mem()))
118 {
// SSL_set_bio hands both BIOs to the SSL object; SSL_free() in the destructor releases them.
119 SSL_set_bio(Ssl, Rbio, Wbio);
// Let SSL_write() report partial writes and accept a retry from a different buffer address.
120 SSL_set_mode(Ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
121 SSL_set_mode(Ssl, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
122 }
123
124 TSslSocket(TSslSocket&& other)
125 {
126 *this = std::move(other);
127 }
128
129 TSslSocket& operator=(TSslSocket&& other) {
130 if (this != &other) {
131 Socket = std::move(other.Socket);
132 Ctx = other.Ctx;
133 Ssl = other.Ssl;
134 Rbio = other.Rbio;
135 Wbio = other.Wbio;
136 Handshake = other.Handshake;
137 other.Ssl = nullptr;
138 other.Rbio = other.Wbio = nullptr;
139 other.Handshake = nullptr;
140 }
141 return *this;
142 }
143
144 TSslSocket(const TSslSocket& ) = delete;
145 TSslSocket& operator=(const TSslSocket&) = delete;
146
147 TSslSocket() = default;
148
// NOTE(review): the `~TSslSocket()` signature line (source line 154) is missing from this listing.
// Frees the SSL object (which also frees Rbio/Wbio, attached via SSL_set_bio) and
// destroys the background handshake coroutine, if one was started.
155 {
156 if (Ssl) { SSL_free(Ssl); }
157 if (Handshake) { Handshake.destroy(); }
158 }
159
167 void SslSetTlsExtHostName(const std::string& host) {
168 SSL_set_tlsext_host_name(Ssl, host.c_str());
169 }
170
// NOTE(review): the signature line (source line 179) is missing from this listing; per the
// page tooltip it is `TFuture<TSslSocket<TSocket>> Accept()`.
// Accepts a connection on the underlying socket, wraps it in a new TSslSocket that uses
// this socket's SSL context, completes the server-side handshake, then returns it.
// NOTE(review): the std::move around the co_await result is likely redundant.
180 auto underlying = std::move(co_await Socket.Accept());
181 auto socket = TSslSocket(std::move(underlying), *Ctx);
182 co_await socket.AcceptHandshake();
183 co_return std::move(socket);
184 }
185
// NOTE(review): the signature line (source line 193) is missing from this listing; per the
// page tooltip it is `TFuture<void> AcceptHandshake()`.
// Performs the server-side handshake: puts the SSL object into accept (server) mode and
// runs DoHandshake(). Must not be called once a background handshake coroutine exists.
194 assert(!Handshake);
195 SSL_set_accept_state(Ssl);
196 co_return co_await DoHandshake();
197 }
198
208 TFuture<void> Connect(const TAddress& address, TTime deadline = TTime::max()) {
209 assert(!Handshake);
210 co_await Socket.Connect(address, deadline);
211 SSL_set_connect_state(Ssl);
212 co_return co_await DoHandshake();
213 }
214
224 TFuture<ssize_t> ReadSome(void* data, size_t size) {
225 co_await WaitHandshake();
226
227 int n = SSL_read(Ssl, data, size);
228 int status;
229 if (n > 0) {
230 co_return n;
231 }
232
233 do {
234 co_await DoIO();
235 n = SSL_read(Ssl, data, size);
236 status = SSL_get_error(Ssl, n);
237 } while (n < 0 && (status == SSL_ERROR_WANT_READ || status == SSL_ERROR_WANT_WRITE));
238
239 co_return n;
240 }
241
251 TFuture<ssize_t> WriteSome(const void* data, size_t size) {
252 co_await WaitHandshake();
253
254 auto r = size;
255 const char* p = (const char*)data;
256 while (size != 0) {
257 auto n = SSL_write(Ssl, p, size);
258 auto status = SSL_get_error(Ssl, n);
259 if (!(status == SSL_ERROR_WANT_READ || status == SSL_ERROR_WANT_WRITE || status == SSL_ERROR_NONE)) {
260 throw std::runtime_error("SSL error: " + std::to_string(status));
261 }
262 if (n <= 0) {
263 throw std::runtime_error("SSL error: " + std::to_string(status));
264 }
265
266 co_await DoIO();
267
268 size -= n;
269 p += n;
270 }
271 co_return r;
272 }
273
279 auto Poller() {
280 return Socket.Poller();
281 }
282
283private:
284 TFuture<void> DoIO() {
285 char buf[1024];
286 int n;
287 while ((n = BIO_read(Wbio, buf, sizeof(buf))) > 0) {
288 co_await TByteWriter(Socket).Write(buf, n);
289 }
290 if (n < 0 && !BIO_should_retry(Wbio)) {
291 throw std::runtime_error("Cannot read Wbio");
292 }
293
294 if (SSL_get_error(Ssl, n) == SSL_ERROR_WANT_READ) {
295 auto size = co_await Socket.ReadSome(buf, sizeof(buf));
296 if (size == 0) {
297 throw std::runtime_error("Connection closed");
298 }
299 const char* p = buf;
300 while (size != 0) {
301 auto n = BIO_write(Rbio, p, size);
302 if (n <= 0) {
303 throw std::runtime_error("Cannot write Rbio");
304 }
305 size -= n;
306 p += n;
307 }
308 }
309
310 co_return;
311 }
312
313 TFuture<void> DoHandshake() {
314 int r;
315 LogState();
316 while ((r = SSL_do_handshake(Ssl)) != 1) {
317 LogState();
318 int status = SSL_get_error(Ssl, r);
319 if (status == SSL_ERROR_WANT_READ || status == SSL_ERROR_WANT_WRITE) {
320 co_await DoIO();
321 } else {
322 throw std::runtime_error("SSL error: " + std::to_string(r));
323 }
324 }
325
326 LogState();
327
328 co_await DoIO();
329
330 if (Ctx->LogFunc) {
331 Ctx->LogFunc("SSL Handshake established\n");
332 }
333
334 for (auto w : Waiters) {
335 w.resume();
336 }
337 Waiters.clear();
338
339 co_return;
340 }
341
// Launches the background handshake coroutine (RunHandshake()) and keeps its handle in
// Handshake, where WaitHandshake() checks it and the destructor destroys it.
// Precondition: no handshake coroutine has been started yet.
342 void StartHandshake() {
343 assert(!Handshake);
344 Handshake = RunHandshake();
345 }
346
347 TVoidSuspendedTask RunHandshake() {
348 // TODO: catch exception
349 co_await DoHandshake();
350 co_return;
351 }
352
353 auto WaitHandshake() {
354 if (!SSL_is_init_finished(Ssl) && !Handshake) {
355 StartHandshake();
356 }
357 struct TAwaitable {
358 bool await_ready() {
359 return !handshake || handshake.done();
360 }
361
362 void await_suspend(std::coroutine_handle<> h) {
363 waiters->push_back(h);
364 }
365
366 void await_resume() { }
367
368 std::coroutine_handle<> handshake;
369 std::vector<std::coroutine_handle<>>* waiters;
370 };
371
372 return TAwaitable { Handshake, &Waiters };
373 };
374
375 void LogState() {
376 if (!Ctx->LogFunc) return;
377
378 char buf[1024];
379
380 const char * state = SSL_state_string_long(Ssl);
381 if (state != LastState) {
382 if (state) {
383 snprintf(buf, sizeof(buf), "SSL-STATE: %s", state);
384 Ctx->LogFunc(buf);
385 }
386 LastState = state;
387 }
388 }
389
// Wrapped transport socket; all ciphertext flows through it (see DoIO()).
390 TSocket Socket;
// Non-owning pointer to the SSL context passed to the constructor.
391 TSslContext* Ctx = nullptr;
392
// OpenSSL session object; SSL_free() in the destructor also frees the attached BIOs.
393 SSL* Ssl = nullptr;
// Memory BIO that DoIO() fills with ciphertext read from Socket.
394 BIO* Rbio = nullptr;
// Memory BIO that DoIO() drains to Socket.
395 BIO* Wbio = nullptr;
396
// Last state string reported by LogState(); compared by pointer to suppress repeats.
397 const char* LastState = nullptr;
398
// Handle of the background handshake coroutine from StartHandshake(); null if none was started.
399 std::coroutine_handle<> Handshake;
// Coroutines suspended in WaitHandshake(); DoHandshake() resumes them on success.
400 std::vector<std::coroutine_handle<>> Waiters;
401};
402
403} // namespace NNet
A class representing an IPv4 or IPv6 address (with port).
Definition address.hpp:30
High-level asynchronous socket for network communication.
Definition socket.hpp:364
Implements an SSL/TLS layer on top of an underlying connection.
Definition ssl.hpp:99
void SslSetTlsExtHostName(const std::string &host)
Sets the TLS SNI (Server Name Indication) extension host name.
Definition ssl.hpp:167
auto Poller()
Returns the underlying poller.
Definition ssl.hpp:279
TFuture< ssize_t > WriteSome(const void *data, size_t size)
Asynchronously writes data to the SSL connection.
Definition ssl.hpp:251
TFuture< void > Connect(const TAddress &address, TTime deadline=TTime::max())
Initiates the client-side SSL handshake.
Definition ssl.hpp:208
~TSslSocket()
Destructor.
Definition ssl.hpp:154
TFuture< ssize_t > ReadSome(void *data, size_t size)
Asynchronously reads data from the SSL connection.
Definition ssl.hpp:224
TSslSocket(TSocket &&socket, TSslContext &ctx)
Constructs a TSslSocket from an underlying socket and an SSL context.
Definition ssl.hpp:112
TFuture< void > AcceptHandshake()
Performs the server-side SSL handshake.
Definition ssl.hpp:193
TFuture< TSslSocket< TSocket > > Accept()
Asynchronously accepts an incoming SSL connection.
Definition ssl.hpp:179
Implementation of a promise/future system for coroutines.
A utility for writing data to a socket-like object.
Definition sockutils.hpp:190
TFuture< void > Write(const void *data, size_t size)
Writes exactly size bytes from data to the socket.
Definition sockutils.hpp:213
Future type for coroutines returning a value of type T.
Definition corochain.hpp:177
Encapsulates an OpenSSL context (SSL_CTX) with optional logging.
Definition ssl.hpp:35
static TSslContext Client(const std::function< void(const char *)> &logFunc={})
Creates a client SSL context.
static TSslContext ServerFromMem(const void *certfile, const void *keyfile, const std::function< void(const char *)> &logFunc={})
Creates a server SSL context from in-memory certificate and key data.
static TSslContext Server(const char *certfile, const char *keyfile, const std::function< void(const char *)> &logFunc={})
Creates a server SSL context using certificate and key files.
std::function< void(const char *)> LogFunc
Optional logging callback.
Definition ssl.hpp:37
SSL_CTX * Ctx
The underlying OpenSSL context.
Definition ssl.hpp:36