COROIO: coroio/socket.hpp Source File
COROIO
 
Loading...
Searching...
No Matches
socket.hpp
1#pragma once
2
3#ifdef _WIN32
4#include <winsock2.h>
5#include <ws2tcpip.h>
6#include <io.h>
7#else
8#include <arpa/inet.h>
9#include <sys/stat.h>
10#include <sys/socket.h>
11#include <sys/types.h>
12#include <netinet/in.h>
13#include <unistd.h>
14#include <fcntl.h>
15#endif
16
17#include <optional>
18#include <variant>
19
20#include "poller.hpp"
21#include "address.hpp"
22
23namespace NNet {
24
25template<typename T> class TSocketBase;
26
40template<>
41class TSocketBase<void> {
42public:
47 TPollerBase* Poller() { return Poller_; }
48
49protected:
50 TSocketBase(TPollerBase& poller, int domain, int type);
51 TSocketBase(int fd, TPollerBase& poller);
52 TSocketBase() = default;
53
61 int Create(int domain, int type);
70 int Setup(int s);
71
72 TPollerBase* Poller_ = nullptr;
73 int Fd_ = -1;
74};
75
98template<typename TSockOps>
99class TSocketBase: public TSocketBase<void> {
100protected:
108 TSocketBase(TPollerBase& poller, int domain, int type): TSocketBase<void>(poller, domain, type)
109 { }
110
116 TSocketBase(int fd, TPollerBase& poller): TSocketBase<void>(fd, poller)
117 { }
118
119 TSocketBase() = default;
120 TSocketBase(const TSocketBase& other) = delete;
121 TSocketBase& operator=(TSocketBase& other) const = delete;
122
123 ~TSocketBase() {
124 Close();
125 }
126
127public:
138 auto ReadSome(void* buf, size_t size) {
139 struct TAwaitableRead: public TAwaitable<TAwaitableRead> {
140 void run() {
141 this->ret = TSockOps::read(this->fd, this->b, this->s);
142 }
143
144 void await_suspend(std::coroutine_handle<> h) {
145 this->poller->AddRead(this->fd, h);
146 }
147 };
148 return TAwaitableRead{Poller_,Fd_,buf,size};
149 }
150
160 auto ReadSomeYield(void* buf, size_t size) {
161 struct TAwaitableRead: public TAwaitable<TAwaitableRead> {
162 bool await_ready() {
163 return (this->ready = false);
164 }
165
166 void run() {
167 this->ret = TSockOps::read(this->fd, this->b, this->s);
168 }
169
170 void await_suspend(std::coroutine_handle<> h) {
171 this->poller->AddRead(this->fd, h);
172 }
173 };
174 return TAwaitableRead{Poller_,Fd_,buf,size};
175 }
176
186 auto WriteSome(const void* buf, size_t size) {
187 struct TAwaitableWrite: public TAwaitable<TAwaitableWrite> {
188 void run() {
189 this->ret = TSockOps::write(this->fd, this->b, this->s);
190 }
191
192 void await_suspend(std::coroutine_handle<> h) {
193 this->poller->AddWrite(this->fd, h);
194 }
195 };
196 return TAwaitableWrite{Poller_,Fd_,const_cast<void*>(buf),size};
197 }
198
208 auto WriteSomeYield(const void* buf, size_t size) {
209 struct TAwaitableWrite: public TAwaitable<TAwaitableWrite> {
210 bool await_ready() {
211 return (this->ready = false);
212 }
213
214 void run() {
215 this->ret = TSockOps::write(this->fd, this->b, this->s);
216 }
217
218 void await_suspend(std::coroutine_handle<> h) {
219 this->poller->AddWrite(this->fd, h);
220 }
221 };
222 return TAwaitableWrite{Poller_,Fd_,const_cast<void*>(buf),size};
223 }
224
232 auto Monitor() {
233 struct TAwaitableClose: public TAwaitable<TAwaitableClose> {
234 void run() {
235 this->ret = true;
236 }
237
238 void await_suspend(std::coroutine_handle<> h) {
239 this->poller->AddRemoteHup(this->fd, h);
240 }
241 };
242 return TAwaitableClose{Poller_,Fd_};
243 }
244
250 void Close()
251 {
252 if (Fd_ >= 0) {
253 TSockOps::close(Fd_);
254 Poller_->RemoveEvent(Fd_);
255 Fd_ = -1;
256 }
257 }
258
259protected:
260 template<typename T>
261 struct TAwaitable {
262 bool await_ready() {
263 SafeRun();
264 return (ready = (ret >= 0));
265 }
266
267 int await_resume() {
268 if (!ready) {
269 SafeRun();
270 }
271 return ret;
272 }
273
274 void SafeRun() {
275 ((T*)this)->run();
276#ifdef _WIN32
277 if (ret < 0 && WSAGetLastError() != WSAEWOULDBLOCK ) {
278 throw std::system_error(WSAGetLastError(), std::generic_category());
279 }
280#else
281 if (ret < 0 && !(errno==EINTR||errno==EAGAIN||errno==EINPROGRESS)) {
282 throw std::system_error(errno, std::generic_category());
283 }
284#endif
285 }
286
287 TPollerBase* poller = nullptr;
288 int fd = -1;
289 void* b = nullptr; size_t s = 0;
290 int ret = -1;
291 bool ready = false;
292 };
293};
294
/**
 * @brief I/O policy for plain file descriptors.
 *
 * Thin static wrappers around the global `read`, `write` and `close`
 * functions. Each one forwards its arguments unchanged and returns the
 * callee's result unchanged.
 */
class TFileOps {
public:
    /// Reads up to `count` bytes from `fd` into `buf`; returns what `::read` returns.
    static auto read(int fd, void* buf, size_t count) {
        const auto bytesRead = ::read(fd, buf, count);
        return bytesRead;
    }

    /// Writes up to `count` bytes from `buf` to `fd`; returns what `::write` returns.
    static auto write(int fd, const void* buf, size_t count) {
        const auto bytesWritten = ::write(fd, buf, count);
        return bytesWritten;
    }

    /// Closes `fd`; returns what `::close` returns.
    static auto close(int fd) {
        const auto status = ::close(fd);
        return status;
    }
};
309
317class TFileHandle: public TSocketBase<TFileOps> {
318public:
325 TFileHandle(int fd, TPollerBase& poller)
326 : TSocketBase(fd, poller)
327 { }
328
329 TFileHandle(TFileHandle&& other);
330 TFileHandle& operator=(TFileHandle&& other);
331
332 TFileHandle() = default;
333};
334
/**
 * @brief I/O policy for sockets.
 *
 * Static wrappers around `recv`, `send` and the platform's socket close
 * call (`closesocket` on Windows, `close` elsewhere).
 */
class TSockOps {
public:
    /// Receives up to `count` bytes into `buf` with flags 0; returns what `::recv` returns.
    static auto read(int fd, void* buf, size_t count) {
        char* const dst = static_cast<char*>(buf);
        return ::recv(fd, dst, count, 0);
    }

    /// Sends up to `count` bytes from `buf` with flags 0; returns what `::send` returns.
    static auto write(int fd, const void* buf, size_t count) {
        const char* const src = static_cast<const char*>(buf);
        return ::send(fd, src, count, 0);
    }

    /// Closes the socket; negative descriptors are ignored. Returns nothing.
    static auto close(int fd) {
        if (fd < 0) {
            return;
        }
#ifdef _WIN32
        ::closesocket(fd);
#else
        ::close(fd);
#endif
    }
};
355
364class TSocket: public TSocketBase<TSockOps> {
365public:
366 using TPoller = TPollerBase;
367
368 TSocket() = default;
369
377 TSocket(TPollerBase& poller, int domain, int type = SOCK_STREAM);
385 TSocket(const TAddress& addr, int fd, TPollerBase& poller);
386
387 TSocket(TSocket&& other);
388 TSocket& operator=(TSocket&& other);
389
402 auto Connect(const TAddress& addr, TTime deadline = TTime::max()) {
403 if (RemoteAddr_.has_value()) {
404 throw std::runtime_error("Already connected");
405 }
406 RemoteAddr_ = addr;
407 struct TAwaitable {
408 bool await_ready() {
409 int ret = connect(fd, addr.first, addr.second);
410#ifdef _WIN32
411 if (ret < 0 && WSAGetLastError() != WSAEWOULDBLOCK) {
412 throw std::system_error(WSAGetLastError(), std::generic_category(), "connect");
413 }
414#else
415 if (ret < 0 && !(errno == EINTR||errno==EAGAIN||errno==EINPROGRESS)) {
416 throw std::system_error(errno, std::generic_category(), "connect");
417 }
418#endif
419 return ret >= 0;
420 }
421
422 void await_suspend(std::coroutine_handle<> h) {
423 poller->AddWrite(fd, h);
424 if (deadline != TTime::max()) {
425 timerId = poller->AddTimer(deadline, h);
426 }
427 }
428
429 void await_resume() {
430 if (deadline != TTime::max() && poller->RemoveTimer(timerId, deadline)) {
431 throw std::system_error(std::make_error_code(std::errc::timed_out));
432 }
433 }
434
435 TPollerBase* poller;
436 int fd;
437 std::pair<const sockaddr*, int> addr;
438 TTime deadline;
439 unsigned timerId = 0;
440 };
441 return TAwaitable{Poller_, Fd_, RemoteAddr_->RawAddr(), deadline};
442 }
443
453 auto Accept() {
454 struct TAwaitable {
455 bool await_ready() const { return false; }
456 void await_suspend(std::coroutine_handle<> h) {
457 poller->AddRead(fd, h);
458 }
459 TSocket await_resume() {
460 char clientaddr[sizeof(sockaddr_in6)];
461 socklen_t len = sizeof(sockaddr_in6);
462
463 int clientfd = accept(fd, reinterpret_cast<sockaddr*>(&clientaddr[0]), &len);
464 if (clientfd < 0) {
465 throw std::system_error(errno, std::generic_category(), "accept");
466 }
467
468 return TSocket{TAddress{reinterpret_cast<sockaddr*>(&clientaddr[0]), len}, clientfd, *poller};
469 }
470
471 TPollerBase* poller;
472 int fd;
473 };
474
475 return TAwaitable{Poller_, Fd_};
476 }
477
479 void Bind(const TAddress& addr);
481 void Listen(int backlog = 128);
487 const std::optional<TAddress>& RemoteAddr() const;
493 const std::optional<TAddress>& LocalAddr() const;
495 int Fd() const;
496
497protected:
498 std::optional<TAddress> LocalAddr_;
499 std::optional<TAddress> RemoteAddr_;
500};
501
517template<typename T>
518class TPollerDrivenSocket: public TSocket
519{
520public:
521 using TPoller = T;
522
532 TPollerDrivenSocket(T& poller, int domain, int type = SOCK_STREAM)
533 : TSocket(poller, domain, type)
534 , Poller_(&poller)
535 {
536 Poller_->Register(Fd_);
537 }
538
545 TPollerDrivenSocket(const TAddress& addr, int fd, T& poller)
546 : TSocket(addr, fd, poller)
547 , Poller_(&poller)
548 {
549 Poller_->Register(Fd_);
550 }
551
552 TPollerDrivenSocket(int fd, T& poller)
553 : TSocket({}, fd, poller)
554 , Poller_(&poller)
555 {
556 Poller_->Register(Fd_);
557 }
558
559 TPollerDrivenSocket() = default;
560
569 auto Accept() {
570 struct TAwaitable {
571 bool await_ready() const { return false; }
572 void await_suspend(std::coroutine_handle<> h) {
573 poller->Accept(fd, reinterpret_cast<sockaddr*>(&addr[0]), &len, h);
574 }
575
576 TPollerDrivenSocket<T> await_resume() {
577 int clientfd = poller->Result();
578 if (clientfd < 0) {
579 throw std::system_error(-clientfd, std::generic_category(), "accept");
580 }
581
582 return TPollerDrivenSocket<T>{TAddress{reinterpret_cast<sockaddr*>(&addr[0]), len}, clientfd, *poller};
583 }
584
585 T* poller;
586 int fd;
587
588 char addr[2*(sizeof(sockaddr_in6)+16)] = {0}; // use additional memory for windows
589 socklen_t len = sizeof(addr);
590 };
591
592 return TAwaitable{Poller_, Fd_};
593 }
594
607 auto Connect(const TAddress& addr, TTime deadline = TTime::max()) {
608 if (RemoteAddr_.has_value()) {
609 throw std::runtime_error("Already connected");
610 }
611 RemoteAddr_ = addr;
612 struct TAwaitable {
613 bool await_ready() const { return false; }
614
615 void await_suspend(std::coroutine_handle<> h) {
616 poller->Connect(fd, addr.first, addr.second, h);
617 if (deadline != TTime::max()) {
618 timerId = poller->AddTimer(deadline, h);
619 }
620 }
621
622 void await_resume() {
623 if (deadline != TTime::max() && poller->RemoveTimer(timerId, deadline)) {
624 poller->Cancel(fd);
625 throw std::system_error(std::make_error_code(std::errc::timed_out));
626 }
627 int ret = poller->Result();
628 if (ret < 0) {
629 throw std::system_error(-ret, std::generic_category(), "connect");
630 }
631 }
632
633 T* poller;
634 int fd;
635 std::pair<const sockaddr*, int> addr;
636 TTime deadline;
637 unsigned timerId = 0;
638 };
639 return TAwaitable{Poller_, Fd_, RemoteAddr()->RawAddr(), deadline};
640 }
641
652 auto ReadSome(void* buf, size_t size) {
653 struct TAwaitable {
654 bool await_ready() const { return false; }
655 void await_suspend(std::coroutine_handle<> h) {
656 poller->Recv(fd, buf, size, h);
657 }
658
659 ssize_t await_resume() {
660 int ret = poller->Result();
661 if (ret < 0) {
662 throw std::system_error(-ret, std::generic_category());
663 }
664 return ret;
665 }
666
667 T* poller;
668 int fd;
669
670 void* buf;
671 size_t size;
672 };
673
674 return TAwaitable{Poller_, Fd_, buf, size};
675 }
676
687 auto WriteSome(const void* buf, size_t size) {
688 struct TAwaitable {
689 bool await_ready() const { return false; }
690 void await_suspend(std::coroutine_handle<> h) {
691 poller->Send(fd, buf, size, h);
692 }
693
694 ssize_t await_resume() {
695 int ret = poller->Result();
696 if (ret < 0) {
697 throw std::system_error(-ret, std::generic_category());
698 }
699 return ret;
700 }
701
702 T* poller;
703 int fd;
704
705 const void* buf;
706 size_t size;
707 };
708
709 return TAwaitable{Poller_, Fd_, buf, size};
710 }
711
713 auto WriteSomeYield(const void* buf, size_t size) {
714 return WriteSome(buf, size);
715 }
716
718 auto ReadSomeYield(void* buf, size_t size) {
719 return ReadSome(buf, size);
720 }
721
722private:
723 T* Poller_;
724};
725
740template<typename T>
742{
743public:
744 using TPoller = T;
745
754 TPollerDrivenFileHandle(int fd, T& poller)
755 : TFileHandle(fd, poller)
756 , Poller_(&poller)
757 { }
758
769 auto ReadSome(void* buf, size_t size) {
770 struct TAwaitable {
771 bool await_ready() const { return false; }
772 void await_suspend(std::coroutine_handle<> h) {
773 poller->Read(fd, buf, size, h);
774 }
775
776 ssize_t await_resume() {
777 int ret = poller->Result();
778 if (ret < 0) {
779 throw std::system_error(-ret, std::generic_category());
780 }
781 return ret;
782 }
783
784 T* poller;
785 int fd;
786
787 void* buf;
788 size_t size;
789 };
790
791 return TAwaitable{Poller_, Fd_, buf, size};
792 }
793
804 auto WriteSome(const void* buf, size_t size) {
805 struct TAwaitable {
806 bool await_ready() const { return false; }
807 void await_suspend(std::coroutine_handle<> h) {
808 poller->Write(fd, buf, size, h);
809 }
810
811 ssize_t await_resume() {
812 int ret = poller->Result();
813 if (ret < 0) {
814 throw std::system_error(-ret, std::generic_category());
815 }
816 return ret;
817 }
818
819 T* poller;
820 int fd;
821
822 const void* buf;
823 size_t size;
824 };
825
826 return TAwaitable{Poller_, Fd_, buf, size};
827 }
828
830 auto WriteSomeYield(const void* buf, size_t size) {
831 return WriteSome(buf, size);
832 }
833
835 auto ReadSomeYield(void* buf, size_t size) {
836 return ReadSome(buf, size);
837 }
838
839private:
840 T* Poller_;
841};
842
843} // namespace NNet
A class representing an IPv4 or IPv6 address (with port).
Definition address.hpp:30
Asynchronous file handle that owns its file descriptor.
Definition socket.hpp:317
TFileHandle(int fd, TPollerBase &poller)
Constructs a TFileHandle from an existing file descriptor.
Definition socket.hpp:325
Definition socket.hpp:295
Base class for pollers managing asynchronous I/O events and timers.
Definition poller.hpp:44
auto WriteSomeYield(const void *buf, size_t size)
The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
Definition socket.hpp:830
auto ReadSomeYield(void *buf, size_t size)
The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
Definition socket.hpp:835
auto WriteSome(const void *buf, size_t size)
Asynchronously writes data from the provided buffer to the file.
Definition socket.hpp:804
TPollerDrivenFileHandle(int fd, T &poller)
Constructs a TPollerDrivenFileHandle from an existing file descriptor.
Definition socket.hpp:754
auto ReadSome(void *buf, size_t size)
Asynchronously reads data from the file into the provided buffer.
Definition socket.hpp:769
Socket type driven by the poller's implementation.
Definition socket.hpp:519
auto ReadSomeYield(void *buf, size_t size)
The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
Definition socket.hpp:718
auto ReadSome(void *buf, size_t size)
Asynchronously reads data from the socket.
Definition socket.hpp:652
auto WriteSome(const void *buf, size_t size)
Asynchronously writes data to the socket.
Definition socket.hpp:687
auto Accept()
Asynchronously accepts an incoming connection.
Definition socket.hpp:569
TPollerDrivenSocket(T &poller, int domain, int type=SOCK_STREAM)
Constructs a TPollerDrivenSocket from a poller, domain, and socket type.
Definition socket.hpp:532
auto Connect(const TAddress &addr, TTime deadline=TTime::max())
Asynchronously connects to the specified address with an optional deadline.
Definition socket.hpp:607
auto WriteSomeYield(const void *buf, size_t size)
The WriteSomeYield and ReadSomeYield variants behave similarly to WriteSome/ReadSome.
Definition socket.hpp:713
TPollerDrivenSocket(const TAddress &addr, int fd, T &poller)
Constructs a TPollerDrivenSocket from an existing file descriptor.
Definition socket.hpp:545
Definition socket.hpp:335
TPollerBase * Poller()
Returns the poller associated with this socket.
Definition socket.hpp:47
int Setup(int s)
Performs additional setup on the socket descriptor.
int Create(int domain, int type)
Creates a new socket descriptor.
Template base class implementing asynchronous socket I/O operations.
Definition socket.hpp:99
TSocketBase(TPollerBase &poller, int domain, int type)
Constructs a TSocketBase with a new socket descriptor.
Definition socket.hpp:108
auto Monitor()
Monitors the socket for remote hang-up (closure).
Definition socket.hpp:232
auto WriteSomeYield(const void *buf, size_t size)
Forces a write operation on the next event loop iteration.
Definition socket.hpp:208
auto WriteSome(const void *buf, size_t size)
Asynchronously writes data from the provided buffer to the socket.
Definition socket.hpp:186
auto ReadSome(void *buf, size_t size)
Asynchronously reads data from the socket into the provided buffer.
Definition socket.hpp:138
void Close()
Closes the socket.
Definition socket.hpp:250
TSocketBase(int fd, TPollerBase &poller)
Constructs a TSocketBase from an existing socket descriptor.
Definition socket.hpp:116
auto ReadSomeYield(void *buf, size_t size)
Forces a read operation on the next event loop iteration.
Definition socket.hpp:160
High-level asynchronous socket for network communication.
Definition socket.hpp:364
TSocket(TPollerBase &poller, int domain, int type=SOCK_STREAM)
Constructs a TSocket using the given poller, address family, and socket type.
const std::optional< TAddress > & RemoteAddr() const
Returns the remote address of the connected peer.
auto Connect(const TAddress &addr, TTime deadline=TTime::max())
Asynchronously connects to the specified address.
Definition socket.hpp:402
int Fd() const
Returns the underlying socket descriptor.
TSocket(const TAddress &addr, int fd, TPollerBase &poller)
Constructs a TSocket for an already-connected socket.
void Bind(const TAddress &addr)
Binds the socket to the specified local address.
auto Accept()
Asynchronously accepts an incoming connection.
Definition socket.hpp:453
const std::optional< TAddress > & LocalAddr() const
Returns the local address to which the socket is bound.
void Listen(int backlog=128)
Puts the socket in a listening state with an optional backlog (default is 128).
Definition socket.hpp:261