COROIO: coroio/ssl.hpp Source File
COROIO
 
Loading...
Searching...
No Matches
ssl.hpp
1#pragma once
2
3#if __has_include(<openssl/bio.h>)
4#define HAVE_OPENSSL
5
6#include <openssl/bio.h>
7#include <openssl/err.h>
8#include <openssl/pem.h>
9#include <openssl/ssl.h>
10
11#include <stdexcept>
12#include <functional>
13
14#include "base.hpp"
15#include "corochain.hpp"
16#include "sockutils.hpp"
17#include "promises.hpp"
18#include "socket.hpp"
19
20namespace NNet {
21
40 SSL_CTX* Ctx;
41 std::function<void(const char*)> LogFunc = {};
42
44 : Ctx(other.Ctx)
45 , LogFunc(other.LogFunc)
46 {
47 other.Ctx = nullptr;
48 }
49
/// Destructor (defined out of line).
/// NOTE(review): presumably frees Ctx; the move constructor nulls other.Ctx
/// so a moved-from context does not release the same SSL_CTX twice.
50 ~TSslContext();
51
/// Creates a TLS client context (no certificate required).
/// @param logFunc Optional logging callback (may be empty); presumably
///                stored in LogFunc — confirm in ssl.cpp.
57 static TSslContext Client(const std::function<void(const char*)>& logFunc = {});
58
/// Creates a TLS server context from PEM files on disk.
/// @param certfile Path of the PEM certificate file.
/// @param keyfile  Path of the PEM private-key file.
/// @param logFunc  Optional logging callback (may be empty).
66 static TSslContext Server(const char* certfile, const char* keyfile, const std::function<void(const char*)>& logFunc = {});
67
/// Creates a TLS server context from PEM data already in memory.
/// @param certfile PEM certificate data (a data pointer, not a path, despite the name).
/// @param keyfile  PEM private-key data (a data pointer, not a path, despite the name).
/// @param logFunc  Optional logging callback (may be empty).
/// NOTE(review): how the buffer lengths are determined (e.g. NUL
/// termination) is not visible here — confirm in ssl.cpp.
77 static TSslContext ServerFromMem(const void* certfile, const void* keyfile, const std::function<void(const char*)>& logFunc = {});
78
79private:
// Default construction is private: use the static factories above.
80 TSslContext();
81};
82
107template<typename TSocket>
109public:
110 using TPoller = typename TSocket::TPoller;
111
119 : Socket(std::move(socket))
120 , Ctx(&ctx)
121 , Ssl(SSL_new(Ctx->Ctx))
122 , Rbio(BIO_new(BIO_s_mem()))
123 , Wbio(BIO_new(BIO_s_mem()))
124 {
125 SSL_set_bio(Ssl, Rbio, Wbio);
126 SSL_set_mode(Ssl, SSL_MODE_ENABLE_PARTIAL_WRITE);
127 SSL_set_mode(Ssl, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
128 }
129
130 TSslSocket(TSslSocket&& other)
131 {
132 *this = std::move(other);
133 }
134
135 TSslSocket& operator=(TSslSocket&& other) {
136 if (this != &other) {
137 Socket = std::move(other.Socket);
138 Ctx = other.Ctx;
139 Ssl = other.Ssl;
140 Rbio = other.Rbio;
141 Wbio = other.Wbio;
142 Handshake = other.Handshake;
143 other.Ssl = nullptr;
144 other.Rbio = other.Wbio = nullptr;
145 other.Handshake = nullptr;
146 }
147 return *this;
148 }
149
150 TSslSocket(const TSslSocket& ) = delete;
151 TSslSocket& operator=(const TSslSocket&) = delete;
152
153 TSslSocket() = default;
154
157 {
158 if (Ssl) { SSL_free(Ssl); }
159 if (Handshake) { Handshake.destroy(); }
160 }
161
170 void SslSetTlsExtHostName(const std::string& host) {
171 SSL_set_tlsext_host_name(Ssl, host.c_str());
172 }
173
183 auto underlying = std::move(co_await Socket.Accept());
184 auto socket = TSslSocket(std::move(underlying), *Ctx);
185 co_await socket.AcceptHandshake();
186 co_return std::move(socket);
187 }
188
196 assert(!Handshake);
197 SSL_set_accept_state(Ssl);
198 co_return co_await DoHandshake();
199 }
200
211 TFuture<void> Connect(const TAddress& address, TTime deadline = TTime::max()) {
212 assert(!Handshake);
213 co_await Socket.Connect(address, deadline);
214 SSL_set_connect_state(Ssl);
215 co_return co_await DoHandshake();
216 }
217
225 TFuture<ssize_t> ReadSome(void* data, size_t size) {
226 co_await WaitHandshake();
227
228 int n = SSL_read(Ssl, data, size);
229 int status;
230 if (n > 0) {
231 co_return n;
232 }
233
234 do {
235 co_await DoIO();
236 n = SSL_read(Ssl, data, size);
237 status = SSL_get_error(Ssl, n);
238 } while (n < 0 && (status == SSL_ERROR_WANT_READ || status == SSL_ERROR_WANT_WRITE));
239
240 co_return n;
241 }
242
250 TFuture<ssize_t> WriteSome(const void* data, size_t size) {
251 co_await WaitHandshake();
252
253 auto r = size;
254 const char* p = (const char*)data;
255 while (size != 0) {
256 auto n = SSL_write(Ssl, p, size);
257 auto status = SSL_get_error(Ssl, n);
258 if (!(status == SSL_ERROR_WANT_READ || status == SSL_ERROR_WANT_WRITE || status == SSL_ERROR_NONE)) {
259 throw std::runtime_error("SSL error: " + std::to_string(status));
260 }
261 if (n <= 0) {
262 throw std::runtime_error("SSL error: " + std::to_string(status));
263 }
264
265 co_await DoIO();
266
267 size -= n;
268 p += n;
269 }
270 co_return r;
271 }
272
274 auto Poller() {
275 return Socket.Poller();
276 }
277
278private:
279 TFuture<void> DoIO() {
280 char buf[1024];
281 int n;
282 while ((n = BIO_read(Wbio, buf, sizeof(buf))) > 0) {
283 co_await TByteWriter(Socket).Write(buf, n);
284 }
285 if (n < 0 && !BIO_should_retry(Wbio)) {
286 throw std::runtime_error("Cannot read Wbio");
287 }
288
289 if (SSL_get_error(Ssl, n) == SSL_ERROR_WANT_READ) {
290 auto size = co_await Socket.ReadSome(buf, sizeof(buf));
291 if (size == 0) {
292 throw std::runtime_error("Connection closed");
293 }
294 const char* p = buf;
295 while (size != 0) {
296 auto n = BIO_write(Rbio, p, size);
297 if (n <= 0) {
298 throw std::runtime_error("Cannot write Rbio");
299 }
300 size -= n;
301 p += n;
302 }
303 }
304
305 co_return;
306 }
307
308 TFuture<void> DoHandshake() {
309 int r;
310 LogState();
311 while ((r = SSL_do_handshake(Ssl)) != 1) {
312 LogState();
313 int status = SSL_get_error(Ssl, r);
314 if (status == SSL_ERROR_WANT_READ || status == SSL_ERROR_WANT_WRITE) {
315 co_await DoIO();
316 } else {
317 throw std::runtime_error("SSL error: " + std::to_string(r));
318 }
319 }
320
321 LogState();
322
323 co_await DoIO();
324
325 if (Ctx->LogFunc) {
326 Ctx->LogFunc("SSL Handshake established\n");
327 }
328
329 for (auto w : Waiters) {
330 w.resume();
331 }
332 Waiters.clear();
333
334 co_return;
335 }
336
337 void StartHandshake() {
338 assert(!Handshake);
339 Handshake = RunHandshake();
340 }
341
342 TVoidSuspendedTask RunHandshake() {
343 // TODO: catch exception
344 co_await DoHandshake();
345 co_return;
346 }
347
348 auto WaitHandshake() {
349 if (!SSL_is_init_finished(Ssl) && !Handshake) {
350 StartHandshake();
351 }
352 struct TAwaitable {
353 bool await_ready() {
354 return !handshake || handshake.done();
355 }
356
357 void await_suspend(std::coroutine_handle<> h) {
358 waiters->push_back(h);
359 }
360
361 void await_resume() { }
362
363 std::coroutine_handle<> handshake;
364 std::vector<std::coroutine_handle<>>* waiters;
365 };
366
367 return TAwaitable { Handshake, &Waiters };
368 };
369
370 void LogState() {
371 if (!Ctx->LogFunc) return;
372
373 char buf[1024];
374
375 const char * state = SSL_state_string_long(Ssl);
376 if (state != LastState) {
377 if (state) {
378 snprintf(buf, sizeof(buf), "SSL-STATE: %s", state);
379 Ctx->LogFunc(buf);
380 }
381 LastState = state;
382 }
383 }
384
385 TSocket Socket;
386 TSslContext* Ctx = nullptr;
387
388 SSL* Ssl = nullptr;
389 BIO* Rbio = nullptr;
390 BIO* Wbio = nullptr;
391
392 const char* LastState = nullptr;
393
394 std::coroutine_handle<> Handshake;
395 std::vector<std::coroutine_handle<>> Waiters;
396};
397
398} // namespace NNet
399
400#endif
401
A class representing an IPv4 or IPv6 address (with port).
Definition address.hpp:38
Backend-independent base for I/O pollers.
Definition poller.hpp:38
auto ReadSome(void *buf, size_t size)
Asynchronously reads data from the socket into the provided buffer.
Definition socket.hpp:139
High-level asynchronous socket for network communication.
Definition socket.hpp:367
auto Connect(const TAddress &addr, TTime deadline=TTime::max())
Asynchronously connects to the specified address.
Definition socket.hpp:405
auto Accept()
Asynchronously accepts an incoming connection.
Definition socket.hpp:456
TLS layer over any connected socket, exposing the same ReadSome/WriteSome interface.
Definition ssl.hpp:108
void SslSetTlsExtHostName(const std::string &host)
Sets the TLS SNI host name sent in the ClientHello.
Definition ssl.hpp:170
auto Poller()
Returns the poller associated with the underlying socket.
Definition ssl.hpp:274
TFuture< ssize_t > WriteSome(const void *data, size_t size)
Encrypts and sends all size bytes from data.
Definition ssl.hpp:250
TFuture< void > Connect(const TAddress &address, TTime deadline=TTime::max())
TCP-connects to address and performs the client-side TLS handshake.
Definition ssl.hpp:211
~TSslSocket()
Frees the SSL instance, associated BIOs, and any in-progress handshake coroutine.
Definition ssl.hpp:156
TFuture< ssize_t > ReadSome(void *data, size_t size)
Reads up to size decrypted bytes into data.
Definition ssl.hpp:225
TSslSocket(TSocket &&socket, TSslContext &ctx)
Constructs a TSslSocket, taking ownership of the underlying socket.
Definition ssl.hpp:118
TFuture< void > AcceptHandshake()
Performs the server-side TLS handshake on an already-accepted TCP socket.
Definition ssl.hpp:195
TFuture< TSslSocket< TSocket > > Accept()
Accepts a TCP connection and performs the server-side TLS handshake.
Definition ssl.hpp:182
Implementation of a promise/future system for coroutines.
A utility for writing data to a socket-like object.
Definition sockutils.hpp:239
TFuture< void > Write(const void *data, size_t size)
Writes exactly size bytes from data to the socket.
Definition sockutils.hpp:262
Owned coroutine handle that carries a result of type T.
Definition corochain.hpp:185
Owns an OpenSSL SSL_CTX and optional log callback.
Definition ssl.hpp:39
static TSslContext ServerFromMem(const void *certfile, const void *keyfile, const std::function< void(const char *)> &logFunc={})
Creates a TLS server context from PEM data already in memory.
Definition ssl.cpp:53
static TSslContext Client(const std::function< void(const char *)> &logFunc={})
Creates a TLS client context (no certificate required).
Definition ssl.cpp:24
static TSslContext Server(const char *certfile, const char *keyfile, const std::function< void(const char *)> &logFunc={})
Creates a TLS server context from PEM files on disk.
Definition ssl.cpp:32
std::function< void(const char *)> LogFunc
Optional logging callback.
Definition ssl.hpp:41
SSL_CTX * Ctx
The underlying OpenSSL context.
Definition ssl.hpp:40