COROIO: coroio/socket.hpp Source File
COROIO
 
Loading...
Searching...
No Matches
socket.hpp
1#pragma once
2
3#ifdef _WIN32
4#include <winsock2.h>
5#include <ws2tcpip.h>
6#include <io.h>
7#else
8#include <arpa/inet.h>
9#include <sys/stat.h>
10#include <sys/socket.h>
11#include <sys/types.h>
12#include <netinet/in.h>
13#include <unistd.h>
14#include <fcntl.h>
15#endif
16
17#include <optional>
18#include <variant>
19
20#include "poller.hpp"
21#include "address.hpp"
22
23namespace NNet {
24
// Primary template, declared up front so the <void> specialization below can be
// provided before the generic (operations-policy) definition.
template <class T>
class TSocketBase;
26
40template<>
41class TSocketBase<void> {
42public:
47 TPollerBase* Poller() { return Poller_; }
48
49protected:
50 TSocketBase(TPollerBase& poller, int domain, int type);
51 TSocketBase(int fd, TPollerBase& poller);
52 TSocketBase() = default;
53
61 int Create(int domain, int type);
70 int Setup(int s);
71
72 TPollerBase* Poller_ = nullptr;
73 int Fd_ = -1;
74};
75
98template<typename TSockOps>
99class TSocketBase: public TSocketBase<void> {
100protected:
108 TSocketBase(TPollerBase& poller, int domain, int type): TSocketBase<void>(poller, domain, type)
109 { }
116 TSocketBase(int fd, TPollerBase& poller): TSocketBase<void>(fd, poller)
117 { }
118
119 TSocketBase() = default;
120 TSocketBase(const TSocketBase& other) = delete;
121 TSocketBase& operator=(TSocketBase& other) const = delete;
122
123 ~TSocketBase() {
124 Close();
125 }
126
127public:
139 auto ReadSome(void* buf, size_t size) {
140 struct TAwaitableRead: public TAwaitable<TAwaitableRead> {
141 void run() {
142 this->ret = TSockOps::read(this->fd, this->b, this->s);
143 }
144
145 void await_suspend(std::coroutine_handle<> h) {
146 this->poller->AddRead(this->fd, h);
147 }
148 };
149 return TAwaitableRead{Poller_,Fd_,buf,size};
150 }
162 auto ReadSomeYield(void* buf, size_t size) {
163 struct TAwaitableRead: public TAwaitable<TAwaitableRead> {
164 bool await_ready() {
165 return (this->ready = false);
166 }
167
168 void run() {
169 this->ret = TSockOps::read(this->fd, this->b, this->s);
170 }
171
172 void await_suspend(std::coroutine_handle<> h) {
173 this->poller->AddRead(this->fd, h);
174 }
175 };
176 return TAwaitableRead{Poller_,Fd_,buf,size};
177 }
189 auto WriteSome(const void* buf, size_t size) {
190 struct TAwaitableWrite: public TAwaitable<TAwaitableWrite> {
191 void run() {
192 this->ret = TSockOps::write(this->fd, this->b, this->s);
193 }
194
195 void await_suspend(std::coroutine_handle<> h) {
196 this->poller->AddWrite(this->fd, h);
197 }
198 };
199 return TAwaitableWrite{Poller_,Fd_,const_cast<void*>(buf),size};
200 }
211 auto WriteSomeYield(const void* buf, size_t size) {
212 struct TAwaitableWrite: public TAwaitable<TAwaitableWrite> {
213 bool await_ready() {
214 return (this->ready = false);
215 }
216
217 void run() {
218 this->ret = TSockOps::write(this->fd, this->b, this->s);
219 }
220
221 void await_suspend(std::coroutine_handle<> h) {
222 this->poller->AddWrite(this->fd, h);
223 }
224 };
225 return TAwaitableWrite{Poller_,Fd_,const_cast<void*>(buf),size};
226 }
235 auto Monitor() {
236 struct TAwaitableClose: public TAwaitable<TAwaitableClose> {
237 void run() {
238 this->ret = true;
239 }
240
241 void await_suspend(std::coroutine_handle<> h) {
242 this->poller->AddRemoteHup(this->fd, h);
243 }
244 };
245 return TAwaitableClose{Poller_,Fd_};
246 }
253 void Close()
254 {
255 if (Fd_ >= 0) {
256 TSockOps::close(Fd_);
257 Poller_->RemoveEvent(Fd_);
258 Fd_ = -1;
259 }
260 }
261
262protected:
263 template<typename T>
264 struct TAwaitable {
265 bool await_ready() {
266 SafeRun();
267 return (ready = (ret >= 0));
268 }
269
270 int await_resume() {
271 if (!ready) {
272 SafeRun();
273 }
274 return ret;
275 }
276
277 void SafeRun() {
278 ((T*)this)->run();
279#ifdef _WIN32
280 if (ret < 0 && WSAGetLastError() != WSAEWOULDBLOCK ) {
281 throw std::system_error(WSAGetLastError(), std::generic_category());
282 }
283#else
284 if (ret < 0 && !(errno==EINTR||errno==EAGAIN||errno==EINPROGRESS)) {
285 throw std::system_error(errno, std::generic_category());
286 }
287#endif
288 }
289
290 TPollerBase* poller = nullptr;
291 int fd = -1;
292 void* b = nullptr; size_t s = 0;
293 int ret = -1;
294 bool ready = false;
295 };
296};
297
/// Operations policy for plain file descriptors: forwards directly to the
/// ::read / ::write / ::close system calls and returns their results unchanged.
class TFileOps {
public:
    static auto read(int fd, void* buf, size_t count) -> decltype(::read(fd, buf, count)) {
        return ::read(fd, buf, count);
    }

    static auto write(int fd, const void* buf, size_t count) -> decltype(::write(fd, buf, count)) {
        return ::write(fd, buf, count);
    }

    static auto close(int fd) -> decltype(::close(fd)) {
        return ::close(fd);
    }
};
312
320class TFileHandle: public TSocketBase<TFileOps> {
321public:
328 TFileHandle(int fd, TPollerBase& poller)
329 : TSocketBase(fd, poller)
330 { }
331
332 TFileHandle(TFileHandle&& other);
333 TFileHandle& operator=(TFileHandle&& other);
334
335 TFileHandle() = default;
336};
337
/// Operations policy for sockets: recv/send for I/O, and close/closesocket
/// (depending on platform) for release.
class TSockOps {
public:
    static auto read(int fd, void* buf, size_t count) {
        return ::recv(fd, static_cast<char*>(buf), count, 0);
    }

    static auto write(int fd, const void* buf, size_t count) {
#ifdef MSG_NOSIGNAL
        // Report a peer that has closed as EPIPE, rather than raising SIGPIPE,
        // which terminates the process by default.
        constexpr int flags = MSG_NOSIGNAL;
#else
        constexpr int flags = 0;
#endif
        return ::send(fd, static_cast<const char*>(buf), count, flags);
    }

    /// Closes `fd` if it is a valid (non-negative) descriptor; otherwise does nothing.
    static auto close(int fd) {
        if (fd >= 0) {
#ifdef _WIN32
            ::closesocket(fd);
#else
            ::close(fd);
#endif
        }
    }
};
358
367class TSocket: public TSocketBase<TSockOps> {
368public:
369 using TPoller = TPollerBase;
370
371 TSocket() = default;
372
380 TSocket(TPollerBase& poller, int domain, int type = SOCK_STREAM);
388 TSocket(const TAddress& addr, int fd, TPollerBase& poller);
389
390 TSocket(TSocket&& other);
391 TSocket& operator=(TSocket&& other);
392
405 auto Connect(const TAddress& addr, TTime deadline = TTime::max()) {
406 if (RemoteAddr_.has_value()) {
407 throw std::runtime_error("Already connected");
408 }
409 RemoteAddr_ = addr;
410 struct TAwaitable {
411 bool await_ready() {
412 int ret = connect(fd, addr.first, addr.second);
413#ifdef _WIN32
414 if (ret < 0 && WSAGetLastError() != WSAEWOULDBLOCK) {
415 throw std::system_error(WSAGetLastError(), std::generic_category(), "connect");
416 }
417#else
418 if (ret < 0 && !(errno == EINTR||errno==EAGAIN||errno==EINPROGRESS)) {
419 throw std::system_error(errno, std::generic_category(), "connect");
420 }
421#endif
422 return ret >= 0;
423 }
424
425 void await_suspend(std::coroutine_handle<> h) {
426 poller->AddWrite(fd, h);
427 if (deadline != TTime::max()) {
428 timerId = poller->AddTimer(deadline, h);
429 }
430 }
431
432 void await_resume() {
433 if (deadline != TTime::max() && poller->RemoveTimer(timerId, deadline)) {
434 throw std::system_error(std::make_error_code(std::errc::timed_out));
435 }
436 }
437
438 TPollerBase* poller;
439 int fd;
440 std::pair<const sockaddr*, int> addr;
441 TTime deadline;
442 unsigned timerId = 0;
443 };
444 return TAwaitable{Poller_, Fd_, RemoteAddr_->RawAddr(), deadline};
445 }
456 auto Accept() {
457 struct TAwaitable {
458 bool await_ready() const { return false; }
459 void await_suspend(std::coroutine_handle<> h) {
460 poller->AddRead(fd, h);
461 }
462 TSocket await_resume() {
463 char clientaddr[sizeof(sockaddr_in6)];
464 socklen_t len = static_cast<socklen_t>(sizeof(sockaddr_in6));
465
466 int clientfd = accept(fd, reinterpret_cast<sockaddr*>(&clientaddr[0]), &len);
467 if (clientfd < 0) {
468 throw std::system_error(errno, std::generic_category(), "accept");
469 }
470
471 return TSocket{TAddress{reinterpret_cast<sockaddr*>(&clientaddr[0]), len}, clientfd, *poller};
472 }
473
474 TPollerBase* poller;
475 int fd;
476 };
477
478 return TAwaitable{Poller_, Fd_};
479 }
480
482 void Bind(const TAddress& addr);
484 void Listen(int backlog = 128);
490 const std::optional<TAddress>& RemoteAddr() const;
496 const std::optional<TAddress>& LocalAddr() const;
498 int Fd() const;
499
500protected:
501 std::optional<TAddress> LocalAddr_;
502 std::optional<TAddress> RemoteAddr_;
503};
504
520template<typename T>
522{
523public:
524 using TPoller = T;
525
535 TPollerDrivenSocket(T& poller, int domain, int type = SOCK_STREAM)
536 : TSocket(poller, domain, type)
537 , Poller_(&poller)
538 {
539 Poller_->Register(Fd_);
540 }
541
548 TPollerDrivenSocket(const TAddress& addr, int fd, T& poller)
549 : TSocket(addr, fd, poller)
550 , Poller_(&poller)
551 {
552 Poller_->Register(Fd_);
553 }
554
555 TPollerDrivenSocket(int fd, T& poller)
556 : TSocket({}, fd, poller)
557 , Poller_(&poller)
558 {
559 Poller_->Register(Fd_);
560 }
561
562 TPollerDrivenSocket() = default;
563
572 auto Accept() {
573 struct TAwaitable {
574 bool await_ready() const { return false; }
575 void await_suspend(std::coroutine_handle<> h) {
576 poller->Accept(fd, reinterpret_cast<sockaddr*>(&addr[0]), &len, h);
577 }
578
579 TPollerDrivenSocket<T> await_resume() {
580 int clientfd = poller->Result();
581 if (clientfd < 0) {
582 throw std::system_error(-clientfd, std::generic_category(), "accept");
583 }
584
585 return TPollerDrivenSocket<T>{TAddress{reinterpret_cast<sockaddr*>(&addr[0]), len}, clientfd, *poller};
586 }
587
588 T* poller;
589 int fd;
590
591 char addr[2*(sizeof(sockaddr_in6)+16)] = {0}; // use additional memory for windows
592 socklen_t len = static_cast<socklen_t>(sizeof(addr));
593 };
594
595 return TAwaitable{Poller_, Fd_};
596 }
597
610 auto Connect(const TAddress& addr, TTime deadline = TTime::max()) {
611 if (RemoteAddr_.has_value()) {
612 throw std::runtime_error("Already connected");
613 }
614 RemoteAddr_ = addr;
615 struct TAwaitable {
616 bool await_ready() const { return false; }
617
618 void await_suspend(std::coroutine_handle<> h) {
619 poller->Connect(fd, addr.first, addr.second, h);
620 if (deadline != TTime::max()) {
621 timerId = poller->AddTimer(deadline, h);
622 }
623 }
624
625 void await_resume() {
626 if (deadline != TTime::max() && poller->RemoveTimer(timerId, deadline)) {
627 poller->Cancel(fd);
628 throw std::system_error(std::make_error_code(std::errc::timed_out));
629 }
630 int ret = poller->Result();
631 if (ret < 0) {
632 throw std::system_error(-ret, std::generic_category(), "connect");
633 }
634 }
635
636 T* poller;
637 int fd;
638 std::pair<const sockaddr*, int> addr;
639 TTime deadline;
640 unsigned timerId = 0;
641 };
642 return TAwaitable{Poller_, Fd_, RemoteAddr()->RawAddr(), deadline};
643 }
644
655 auto ReadSome(void* buf, size_t size) {
656 struct TAwaitable {
657 bool await_ready() const { return false; }
658 void await_suspend(std::coroutine_handle<> h) {
659 poller->Recv(fd, buf, size, h);
660 }
661
662 auto await_resume() {
663 auto ret = poller->Result();
664 if (ret < 0) {
665#ifdef _WIN32
666 int err = -ret;
667 if (err == WSAEWOULDBLOCK || err == WSAEINTR || err == WSAEINPROGRESS) {
668 return ret; // retry hint
669 }
670#else
671 int err = -ret;
672 if (err == EINTR || err == EAGAIN || err == EINPROGRESS) {
673 return ret; // retry hint
674 }
675#endif
676 throw std::system_error(-ret, std::generic_category());
677 }
678 return ret;
679 }
680
681 T* poller;
682 int fd;
683
684 void* buf;
685 size_t size;
686 };
687
688 return TAwaitable{Poller_, Fd_, buf, size};
689 }
690
701 auto WriteSome(const void* buf, size_t size) {
702 struct TAwaitable {
703 bool await_ready() const { return false; }
704 void await_suspend(std::coroutine_handle<> h) {
705 poller->Send(fd, buf, size, h);
706 }
707
708 auto await_resume() {
709 auto ret = poller->Result();
710 if (ret < 0) {
711#ifdef _WIN32
712 int err = -ret;
713 if (err == WSAEWOULDBLOCK || err == WSAEINTR || err == WSAEINPROGRESS) {
714 return ret; // retry hint
715 }
716#else
717 int err = -ret;
718 if (err == EINTR || err == EAGAIN || err == EINPROGRESS) {
719 return ret; // retry hint
720 }
721#endif
722 throw std::system_error(-ret, std::generic_category());
723 }
724 return ret;
725 }
726
727 T* poller;
728 int fd;
729
730 const void* buf;
731 size_t size;
732 };
733
734 return TAwaitable{Poller_, Fd_, buf, size};
735 }
736
738 auto WriteSomeYield(const void* buf, size_t size) {
739 return WriteSome(buf, size);
740 }
741
743 auto ReadSomeYield(void* buf, size_t size) {
744 return ReadSome(buf, size);
745 }
746
747private:
748 T* Poller_;
749};
750
765template<typename T>
767{
768public:
769 using TPoller = T;
770
779 TPollerDrivenFileHandle(int fd, T& poller)
780 : TFileHandle(fd, poller)
781 , Poller_(&poller)
782 { }
783
794 auto ReadSome(void* buf, size_t size) {
795 struct TAwaitable {
796 bool await_ready() const { return false; }
797 void await_suspend(std::coroutine_handle<> h) {
798 poller->Read(fd, buf, size, h);
799 }
800
801 auto await_resume() {
802 auto ret = poller->Result();
803 if (ret < 0) {
804#ifdef _WIN32
805 int err = -ret;
806 if (err == WSAEWOULDBLOCK || err == WSAEINTR || err == WSAEINPROGRESS) {
807 return ret; // retry hint
808 }
809#else
810 int err = -ret;
811 if (err == EINTR || err == EAGAIN || err == EINPROGRESS) {
812 return ret; // retry hint
813 }
814#endif
815 throw std::system_error(-ret, std::generic_category());
816 }
817 return ret;
818 }
819
820 T* poller;
821 int fd;
822
823 void* buf;
824 size_t size;
825 };
826
827 return TAwaitable{Poller_, Fd_, buf, size};
828 }
829
840 auto WriteSome(const void* buf, size_t size) {
841 struct TAwaitable {
842 bool await_ready() const { return false; }
843 void await_suspend(std::coroutine_handle<> h) {
844 poller->Write(fd, buf, size, h);
845 }
846
847 auto await_resume() {
848 auto ret = poller->Result();
849 if (ret < 0) {
850#ifdef _WIN32
851 int err = -ret;
852 if (err == WSAEWOULDBLOCK || err == WSAEINTR || err == WSAEINPROGRESS) {
853 return ret; // retry hint
854 }
855#else
856 int err = -ret;
857 if (err == EINTR || err == EAGAIN || err == EINPROGRESS) {
858 return ret; // retry hint
859 }
860#endif
861 throw std::system_error(-ret, std::generic_category());
862 }
863 return ret;
864 }
865
866 T* poller;
867 int fd;
868
869 const void* buf;
870 size_t size;
871 };
872
873 return TAwaitable{Poller_, Fd_, buf, size};
874 }
875
877 auto WriteSomeYield(const void* buf, size_t size) {
878 return WriteSome(buf, size);
879 }
880
882 auto ReadSomeYield(void* buf, size_t size) {
883 return ReadSome(buf, size);
884 }
885
886private:
887 T* Poller_;
888};
889
890} // namespace NNet
A class representing an IPv4 or IPv6 address (with port).
Definition address.hpp:38
Asynchronous file handle that owns its file descriptor.
Definition socket.hpp:320
TFileHandle(int fd, TPollerBase &poller)
Constructs a TFileHandle from an existing file descriptor.
Definition socket.hpp:328
Definition socket.hpp:298
Base class for pollers managing asynchronous I/O events and timers.
Definition poller.hpp:52
unsigned AddTimer(TTime deadline, THandle h)
Schedules a timer.
Definition poller.hpp:70
void AddWrite(int fd, THandle h)
Registers a write event on a file descriptor.
Definition poller.hpp:109
bool RemoveTimer(unsigned timerId, TTime deadline)
Removes or cancels a timer.
Definition poller.hpp:85
void AddRead(int fd, THandle h)
Registers a read event on a file descriptor.
Definition poller.hpp:99
Asynchronous file handle driven by the poller's implementation.
Definition socket.hpp:767
auto WriteSomeYield(const void *buf, size_t size)
The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
Definition socket.hpp:877
auto ReadSomeYield(void *buf, size_t size)
The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
Definition socket.hpp:882
auto WriteSome(const void *buf, size_t size)
Asynchronously writes data from the provided buffer to the file.
Definition socket.hpp:840
TPollerDrivenFileHandle(int fd, T &poller)
Constructs a TPollerDrivenFileHandle from an existing file descriptor.
Definition socket.hpp:779
auto ReadSome(void *buf, size_t size)
Asynchronously reads data from the file into the provided buffer.
Definition socket.hpp:794
Socket type driven by the poller's implementation.
Definition socket.hpp:522
auto ReadSomeYield(void *buf, size_t size)
The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
Definition socket.hpp:743
auto ReadSome(void *buf, size_t size)
Asynchronously reads data from the socket.
Definition socket.hpp:655
auto WriteSome(const void *buf, size_t size)
Asynchronously writes data to the socket.
Definition socket.hpp:701
auto Accept()
Asynchronously accepts an incoming connection.
Definition socket.hpp:572
TPollerDrivenSocket(T &poller, int domain, int type=SOCK_STREAM)
Constructs a TPollerDrivenSocket from a poller, domain, and socket type.
Definition socket.hpp:535
auto Connect(const TAddress &addr, TTime deadline=TTime::max())
Asynchronously connects to the specified address with an optional deadline.
Definition socket.hpp:610
auto WriteSomeYield(const void *buf, size_t size)
The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
Definition socket.hpp:738
TPollerDrivenSocket(const TAddress &addr, int fd, T &poller)
Constructs a TPollerDrivenSocket from an existing file descriptor.
Definition socket.hpp:548
Definition socket.hpp:338
TPollerBase * Poller()
Returns the poller associated with this socket.
Definition socket.hpp:47
Template base class implementing asynchronous socket I/O operations.
Definition socket.hpp:99
TSocketBase(TPollerBase &poller, int domain, int type)
Constructs a TSocketBase with a new socket descriptor.
Definition socket.hpp:108
auto Monitor()
Monitors the socket for remote hang-up (closure).
Definition socket.hpp:235
auto WriteSomeYield(const void *buf, size_t size)
Always suspends first and performs the write only after the poller reports the socket writable (it never writes immediately).
Definition socket.hpp:211
auto WriteSome(const void *buf, size_t size)
Asynchronously writes data from the provided buffer to the socket.
Definition socket.hpp:189
auto ReadSome(void *buf, size_t size)
Asynchronously reads data from the socket into the provided buffer.
Definition socket.hpp:139
void Close()
Closes the socket.
Definition socket.hpp:253
TSocketBase(int fd, TPollerBase &poller)
Constructs a TSocketBase from an existing socket descriptor.
Definition socket.hpp:116
auto ReadSomeYield(void *buf, size_t size)
Always suspends first and performs the read only after the poller reports the socket readable (it never reads immediately).
Definition socket.hpp:162
High-level asynchronous socket for network communication.
Definition socket.hpp:367
const std::optional< TAddress > & RemoteAddr() const
Returns the remote address of the connected peer.
Definition socket.cpp:109
auto Connect(const TAddress &addr, TTime deadline=TTime::max())
Asynchronously connects to the specified address.
Definition socket.hpp:405
int Fd() const
Returns the underlying socket descriptor.
Definition socket.cpp:113
void Bind(const TAddress &addr)
Binds the socket to the specified local address.
Definition socket.cpp:83
auto Accept()
Asynchronously accepts an incoming connection.
Definition socket.hpp:456
const std::optional< TAddress > & LocalAddr() const
Returns the local address to which the socket is bound.
Definition socket.cpp:105
void Listen(int backlog=128)
Puts the socket in a listening state with an optional backlog (default is 128).
Definition socket.cpp:99
Definition socket.hpp:264