COROIO: coroio/ssl.hpp Source File
COROIO
 
Loading...
Searching...
No Matches
ssl.hpp
1#pragma once
2
3#if __has_include(<openssl/bio.h>)
4#define HAVE_OPENSSL
5
6#include <openssl/bio.h>
7#include <openssl/err.h>
8#include <openssl/pem.h>
9#include <openssl/ssl.h>
10
11#include <stdexcept>
12#include <functional>
13
14#include "base.hpp"
15#include "corochain.hpp"
16#include "sockutils.hpp"
17#include "promises.hpp"
18#include "socket.hpp"
19
20namespace NNet {
21
39 SSL_CTX* Ctx;
40 std::function<void(const char*)> LogFunc = {};
41
43 : Ctx(other.Ctx)
44 , LogFunc(other.LogFunc)
45 {
46 other.Ctx = nullptr;
47 }
48
49 ~TSslContext();
50
57 static TSslContext Client(const std::function<void(const char*)>& logFunc = {});
66 static TSslContext Server(const char* certfile, const char* keyfile, const std::function<void(const char*)>& logFunc = {});
75 static TSslContext ServerFromMem(const void* certfile, const void* keyfile, const std::function<void(const char*)>& logFunc = {});
76
77private:
78 TSslContext();
79};
80
101template<typename TSocket>
103public:
104 using TPoller = typename TSocket::TPoller;
105
116 : Socket(std::move(socket))
117 , Ctx(&ctx)
118 , Ssl(SSL_new(Ctx->Ctx))
119 , Rbio(BIO_new(BIO_s_mem()))
120 , Wbio(BIO_new(BIO_s_mem()))
121 {
122 SSL_set_bio(Ssl, Rbio, Wbio);
123 SSL_set_mode(Ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
124 SSL_set_mode(Ssl, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
125 }
126
127 TSslSocket(TSslSocket&& other)
128 {
129 *this = std::move(other);
130 }
131
132 TSslSocket& operator=(TSslSocket&& other) {
133 if (this != &other) {
134 Socket = std::move(other.Socket);
135 Ctx = other.Ctx;
136 Ssl = other.Ssl;
137 Rbio = other.Rbio;
138 Wbio = other.Wbio;
139 Handshake = other.Handshake;
140 other.Ssl = nullptr;
141 other.Rbio = other.Wbio = nullptr;
142 other.Handshake = nullptr;
143 }
144 return *this;
145 }
146
147 TSslSocket(const TSslSocket& ) = delete;
148 TSslSocket& operator=(const TSslSocket&) = delete;
149
150 TSslSocket() = default;
151
158 {
159 if (Ssl) { SSL_free(Ssl); }
160 if (Handshake) { Handshake.destroy(); }
161 }
162
170 void SslSetTlsExtHostName(const std::string& host) {
171 SSL_set_tlsext_host_name(Ssl, host.c_str());
172 }
173
183 auto underlying = std::move(co_await Socket.Accept());
184 auto socket = TSslSocket(std::move(underlying), *Ctx);
185 co_await socket.AcceptHandshake();
186 co_return std::move(socket);
187 }
188
197 assert(!Handshake);
198 SSL_set_accept_state(Ssl);
199 co_return co_await DoHandshake();
200 }
201
211 TFuture<void> Connect(const TAddress& address, TTime deadline = TTime::max()) {
212 assert(!Handshake);
213 co_await Socket.Connect(address, deadline);
214 SSL_set_connect_state(Ssl);
215 co_return co_await DoHandshake();
216 }
217
227 TFuture<ssize_t> ReadSome(void* data, size_t size) {
228 co_await WaitHandshake();
229
230 int n = SSL_read(Ssl, data, size);
231 int status;
232 if (n > 0) {
233 co_return n;
234 }
235
236 do {
237 co_await DoIO();
238 n = SSL_read(Ssl, data, size);
239 status = SSL_get_error(Ssl, n);
240 } while (n < 0 && (status == SSL_ERROR_WANT_READ || status == SSL_ERROR_WANT_WRITE));
241
242 co_return n;
243 }
244
254 TFuture<ssize_t> WriteSome(const void* data, size_t size) {
255 co_await WaitHandshake();
256
257 auto r = size;
258 const char* p = (const char*)data;
259 while (size != 0) {
260 auto n = SSL_write(Ssl, p, size);
261 auto status = SSL_get_error(Ssl, n);
262 if (!(status == SSL_ERROR_WANT_READ || status == SSL_ERROR_WANT_WRITE || status == SSL_ERROR_NONE)) {
263 throw std::runtime_error("SSL error: " + std::to_string(status));
264 }
265 if (n <= 0) {
266 throw std::runtime_error("SSL error: " + std::to_string(status));
267 }
268
269 co_await DoIO();
270
271 size -= n;
272 p += n;
273 }
274 co_return r;
275 }
276
282 auto Poller() {
283 return Socket.Poller();
284 }
285
286private:
287 TFuture<void> DoIO() {
288 char buf[1024];
289 int n;
290 while ((n = BIO_read(Wbio, buf, sizeof(buf))) > 0) {
291 co_await TByteWriter(Socket).Write(buf, n);
292 }
293 if (n < 0 && !BIO_should_retry(Wbio)) {
294 throw std::runtime_error("Cannot read Wbio");
295 }
296
297 if (SSL_get_error(Ssl, n) == SSL_ERROR_WANT_READ) {
298 auto size = co_await Socket.ReadSome(buf, sizeof(buf));
299 if (size == 0) {
300 throw std::runtime_error("Connection closed");
301 }
302 const char* p = buf;
303 while (size != 0) {
304 auto n = BIO_write(Rbio, p, size);
305 if (n <= 0) {
306 throw std::runtime_error("Cannot write Rbio");
307 }
308 size -= n;
309 p += n;
310 }
311 }
312
313 co_return;
314 }
315
316 TFuture<void> DoHandshake() {
317 int r;
318 LogState();
319 while ((r = SSL_do_handshake(Ssl)) != 1) {
320 LogState();
321 int status = SSL_get_error(Ssl, r);
322 if (status == SSL_ERROR_WANT_READ || status == SSL_ERROR_WANT_WRITE) {
323 co_await DoIO();
324 } else {
325 throw std::runtime_error("SSL error: " + std::to_string(r));
326 }
327 }
328
329 LogState();
330
331 co_await DoIO();
332
333 if (Ctx->LogFunc) {
334 Ctx->LogFunc("SSL Handshake established\n");
335 }
336
337 for (auto w : Waiters) {
338 w.resume();
339 }
340 Waiters.clear();
341
342 co_return;
343 }
344
345 void StartHandshake() {
346 assert(!Handshake);
347 Handshake = RunHandshake();
348 }
349
350 TVoidSuspendedTask RunHandshake() {
351 // TODO: catch exception
352 co_await DoHandshake();
353 co_return;
354 }
355
356 auto WaitHandshake() {
357 if (!SSL_is_init_finished(Ssl) && !Handshake) {
358 StartHandshake();
359 }
360 struct TAwaitable {
361 bool await_ready() {
362 return !handshake || handshake.done();
363 }
364
365 void await_suspend(std::coroutine_handle<> h) {
366 waiters->push_back(h);
367 }
368
369 void await_resume() { }
370
371 std::coroutine_handle<> handshake;
372 std::vector<std::coroutine_handle<>>* waiters;
373 };
374
375 return TAwaitable { Handshake, &Waiters };
376 };
377
378 void LogState() {
379 if (!Ctx->LogFunc) return;
380
381 char buf[1024];
382
383 const char * state = SSL_state_string_long(Ssl);
384 if (state != LastState) {
385 if (state) {
386 snprintf(buf, sizeof(buf), "SSL-STATE: %s", state);
387 Ctx->LogFunc(buf);
388 }
389 LastState = state;
390 }
391 }
392
393 TSocket Socket;
394 TSslContext* Ctx = nullptr;
395
396 SSL* Ssl = nullptr;
397 BIO* Rbio = nullptr;
398 BIO* Wbio = nullptr;
399
400 const char* LastState = nullptr;
401
402 std::coroutine_handle<> Handshake;
403 std::vector<std::coroutine_handle<>> Waiters;
404};
405
406} // namespace NNet
407
408#endif
409
A class representing an IPv4 or IPv6 address (with port).
Definition address.hpp:38
Base class for pollers managing asynchronous I/O events and timers.
Definition poller.hpp:52
auto ReadSome(void *buf, size_t size)
Asynchronously reads data from the socket into the provided buffer.
Definition socket.hpp:138
High-level asynchronous socket for network communication.
Definition socket.hpp:364
auto Connect(const TAddress &addr, TTime deadline=TTime::max())
Asynchronously connects to the specified address.
Definition socket.hpp:402
auto Accept()
Asynchronously accepts an incoming connection.
Definition socket.hpp:453
Implements an SSL/TLS layer on top of an underlying connection.
Definition ssl.hpp:102
void SslSetTlsExtHostName(const std::string &host)
Sets the TLS SNI (Server Name Indication) extension host name.
Definition ssl.hpp:170
auto Poller()
Returns the underlying poller.
Definition ssl.hpp:282
TFuture< ssize_t > WriteSome(const void *data, size_t size)
Asynchronously writes data to the SSL connection.
Definition ssl.hpp:254
TFuture< void > Connect(const TAddress &address, TTime deadline=TTime::max())
Initiates the client-side SSL handshake.
Definition ssl.hpp:211
~TSslSocket()
Destructor.
Definition ssl.hpp:157
TFuture< ssize_t > ReadSome(void *data, size_t size)
Asynchronously reads data from the SSL connection.
Definition ssl.hpp:227
TSslSocket(TSocket &&socket, TSslContext &ctx)
Constructs a TSslSocket from an underlying socket and an SSL context.
Definition ssl.hpp:115
TFuture< void > AcceptHandshake()
Performs the server-side SSL handshake.
Definition ssl.hpp:196
TFuture< TSslSocket< TSocket > > Accept()
Asynchronously accepts an incoming SSL connection.
Definition ssl.hpp:182
Implementation of a promise/future system for coroutines.
A utility for writing data to a socket-like object.
Definition sockutils.hpp:199
TFuture< void > Write(const void *data, size_t size)
Writes exactly size bytes from data to the socket.
Definition sockutils.hpp:222
Future type for coroutines returning a value of type T.
Definition corochain.hpp:182
Encapsulates an OpenSSL context (SSL_CTX) with optional logging.
Definition ssl.hpp:38
static TSslContext ServerFromMem(const void *certfile, const void *keyfile, const std::function< void(const char *)> &logFunc={})
Creates a server SSL context from in-memory certificate and key data.
Definition ssl.cpp:53
static TSslContext Client(const std::function< void(const char *)> &logFunc={})
Creates a client SSL context.
Definition ssl.cpp:24
static TSslContext Server(const char *certfile, const char *keyfile, const std::function< void(const char *)> &logFunc={})
Creates a server SSL context using certificate and key files.
Definition ssl.cpp:32
std::function< void(const char *)> LogFunc
Optional logging callback.
Definition ssl.hpp:40
SSL_CTX * Ctx
The underlying OpenSSL context.
Definition ssl.hpp:39