diff --git a/Makefile b/Makefile
index 14d4658..593f922 100644
--- a/Makefile
+++ b/Makefile
@@ -114,7 +114,7 @@ EXTRA_OBJS = $(TARGETDIR)/md_linux.o $(TARGETDIR)/md_linux2.o
 SFLAGS = -fPIC
 LDFLAGS = -shared -soname=$(SONAME) -lc
 OTHER_FLAGS = -Wall
-DEFINES += -DMD_HAVE_EPOLL -DMD_HAVE_SELECT
+DEFINES += -DMD_HAVE_EPOLL -DMD_HAVE_SELECT -DMD_HAVE_IO_URING
 endif
 
 ifeq ($(OS), CYGWIN64)
@@ -310,9 +310,20 @@ darwin-optimized:
 	$(MAKE) OS="DARWIN" BUILD="OPT"
 
 linux-debug:
-	$(MAKE) OS="LINUX" BUILD="DBG"
+	$(MAKE) clean
+	$(MAKE) CFLAGS="-g -O0 -Wall -Werror -DDEBUG" \
+		DEFINES="-DMD_HAVE_EPOLL -DMD_HAVE_SELECT -DMD_HAVE_IO_URING" \
+		LDFLAGS="-L/usr/local/lib -L/usr/lib" \
+		LIBS="-luring" \
+		all
+
 linux-optimized:
-	$(MAKE) OS="LINUX" BUILD="OPT"
+	$(MAKE) clean
+	$(MAKE) CFLAGS="-O2 -Wall -Werror" \
+		DEFINES="-DMD_HAVE_EPOLL -DMD_HAVE_SELECT -DMD_HAVE_IO_URING" \
+		LDFLAGS="-L/usr/local/lib -L/usr/lib" \
+		LIBS="-luring" \
+		all
 
 cygwin64-debug:
 	$(MAKE) OS="CYGWIN64" BUILD="DBG"
diff --git a/event.c b/event.c
index 00a951c..6af5f8e 100644
--- a/event.c
+++ b/event.c
@@ -40,6 +40,7 @@
 #include
 #include
 #include
+#include <sys/mman.h>
 #include "common.h"
 
 #ifdef MD_HAVE_KQUEUE
@@ -48,6 +49,10 @@
 #ifdef MD_HAVE_EPOLL
     #include <sys/epoll.h>
 #endif
+#ifdef MD_HAVE_IO_URING
+    #include <liburing.h>
+    #include <linux/io_uring.h>
+#endif
 
 // Global stat.
 #if defined(DEBUG) && defined(DEBUG_STATS)
@@ -55,10 +60,14 @@
 __thread unsigned long long _st_stat_epoll = 0;
 __thread unsigned long long _st_stat_epoll_zero = 0;
 __thread unsigned long long _st_stat_epoll_shake = 0;
 __thread unsigned long long _st_stat_epoll_spin = 0;
+__thread unsigned long long _st_stat_io_uring = 0;
+__thread unsigned long long _st_stat_io_uring_zero = 0;
+__thread unsigned long long _st_stat_io_uring_shake = 0;
+__thread unsigned long long _st_stat_io_uring_spin = 0;
 #endif
 
-#if !defined(MD_HAVE_KQUEUE) && !defined(MD_HAVE_EPOLL) && !defined(MD_HAVE_SELECT)
-    #error Only support epoll(for Linux), kqueue(for Darwin) or select(for Cygwin)
+#if !defined(MD_HAVE_KQUEUE) && !defined(MD_HAVE_EPOLL) && !defined(MD_HAVE_SELECT) && !defined(MD_HAVE_IO_URING)
+    #error Only supports epoll (for Linux), kqueue (for Darwin), select (for Cygwin) or io_uring (for Linux)
 #endif
 
@@ -147,6 +155,46 @@ static __thread struct _st_epolldata {
 
 #endif  /* MD_HAVE_EPOLL */
 
+#ifdef MD_HAVE_IO_URING
+typedef struct _io_uring_fd_data {
+    int rd_ref_cnt;
+    int wr_ref_cnt;
+    int ex_ref_cnt;
+    int revents;
+    struct io_uring_sqe *sqe;
+    struct io_uring_cqe *cqe;
+} _io_uring_fd_data_t;
+
+static __thread struct _st_io_uringdata {
+    _io_uring_fd_data_t *fd_data;
+    struct io_uring *ring;
+    struct io_uring_sqe *sqes;      /* staging list of poll submissions */
+    struct io_uring_cqe *cqes;
+    int fd_data_size;
+    int ring_size;
+    int ring_cnt;
+    int fd_hint;
+    int ringfd;
+    pid_t pid;
+    struct io_uring_params params;
+} *_st_io_uring_data;
+
+#ifndef ST_IO_URING_RING_SIZE
+    #define ST_IO_URING_RING_SIZE 4096
+#endif
+
+#define _ST_IO_URING_READ_CNT(fd)   (_st_io_uring_data->fd_data[fd].rd_ref_cnt)
+#define _ST_IO_URING_WRITE_CNT(fd)  (_st_io_uring_data->fd_data[fd].wr_ref_cnt)
+#define _ST_IO_URING_EXCEP_CNT(fd)  (_st_io_uring_data->fd_data[fd].ex_ref_cnt)
+#define _ST_IO_URING_REVENTS(fd)    (_st_io_uring_data->fd_data[fd].revents)
+
+/* The poll mask a descriptor is currently registered for, derived from the reference counts */
+#define _ST_IO_URING_READ_BIT(fd)   (_ST_IO_URING_READ_CNT(fd) ? POLLIN : 0)
+#define _ST_IO_URING_WRITE_BIT(fd)  (_ST_IO_URING_WRITE_CNT(fd) ? POLLOUT : 0)
+#define _ST_IO_URING_EXCEP_BIT(fd)  (_ST_IO_URING_EXCEP_CNT(fd) ? POLLPRI : 0)
+#define _ST_IO_URING_EVENTS(fd) \
+    (_ST_IO_URING_READ_BIT(fd)|_ST_IO_URING_WRITE_BIT(fd)|_ST_IO_URING_EXCEP_BIT(fd))
+#endif  /* MD_HAVE_IO_URING */
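The macros above mirror the epoll backend: per-descriptor reference counts are folded into a poll mask that the backend registers with the kernel as `IORING_OP_POLL_ADD` requests. For readers new to the API, here is a minimal, self-contained sketch of that readiness-polling pattern with liburing (the function `wait_readable` and its shape are illustrative only, not part of this patch; build with `-luring`):

```c
#include <liburing.h>
#include <poll.h>
#include <stdint.h>
#include <stdio.h>

/* Block until fd is readable, using a one-shot IORING_OP_POLL_ADD. */
static int wait_readable(int fd)
{
    struct io_uring ring;
    struct io_uring_sqe *sqe;
    struct io_uring_cqe *cqe;
    int rv;

    if (io_uring_queue_init(8, &ring, 0) < 0)
        return -1;

    sqe = io_uring_get_sqe(&ring);                     /* grab a free submission slot */
    io_uring_prep_poll_add(sqe, fd, POLLIN);           /* one-shot readiness poll */
    io_uring_sqe_set_data(sqe, (void *)(uintptr_t)fd); /* tag for the completion */
    io_uring_submit(&ring);

    rv = io_uring_wait_cqe(&ring, &cqe);               /* returns 0 or -errno */
    if (rv == 0) {
        /* cqe->res carries the returned poll mask, or -errno on failure */
        printf("fd %d ready, mask=0x%x\n", (int)(uintptr_t)io_uring_cqe_get_data(cqe), cqe->res);
        io_uring_cqe_seen(&ring, cqe);
    }
    io_uring_queue_exit(&ring);
    return rv;
}
```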
+
 
 __thread _st_eventsys_t *_st_eventsys = NULL;
 
@@ -653,7 +701,7 @@ ST_HIDDEN void _st_kq_pollset_del(struct pollfd *pds, int npds)
 
     /*
      * It's OK if deleting fails because a descriptor will either be
-     * closed or fire only once (we set EV_ONESHOT flag).
+     * closed or deleted in the dispatch function after it fires.
      */
     _st_kq_data->dellist_cnt = 0;
     for (pd = pds; pd < epd; pd++) {
@@ -758,28 +806,11 @@ ST_HIDDEN void _st_kq_dispatch(void)
             if (notify) {
                 ST_REMOVE_LINK(&pq->links);
                 pq->on_ioq = 0;
-                for (pds = pq->pds; pds < epds; pds++) {
-                    osfd = pds->fd;
-                    events = pds->events;
-                    /*
-                     * We set EV_ONESHOT flag so we only need to delete
-                     * descriptor if it didn't fire.
-                     */
-                    if ((events & POLLIN) && (--_ST_KQ_READ_CNT(osfd) == 0) && ((_ST_KQ_REVENTS(osfd) & POLLIN) == 0)) {
-                        memset(&kev, 0, sizeof(kev));
-                        kev.ident = osfd;
-                        kev.filter = EVFILT_READ;
-                        kev.flags = EV_DELETE;
-                        _st_kq_dellist_add(&kev);
-                    }
-                    if ((events & POLLOUT) && (--_ST_KQ_WRITE_CNT(osfd) == 0) && ((_ST_KQ_REVENTS(osfd) & POLLOUT) == 0)) {
-                        memset(&kev, 0, sizeof(kev));
-                        kev.ident = osfd;
-                        kev.filter = EVFILT_WRITE;
-                        kev.flags = EV_DELETE;
-                        _st_kq_dellist_add(&kev);
-                    }
-                }
+                /*
+                 * Here we will only delete/modify descriptors that
+                 * didn't fire (see comments in _st_kq_pollset_del()).
+                 */
+                _st_kq_pollset_del(pq->pds, pq->npds);
 
                 if (pq->thread->flags & _ST_FL_ON_SLEEPQ)
                     _ST_DEL_SLEEPQ(pq->thread);
@@ -788,17 +819,17 @@ ST_HIDDEN void _st_kq_dispatch(void)
             }
         }
 
-        if (_st_kq_data->dellist_cnt > 0) {
-            int rv;
-            do {
-                /* This kevent() won't block since result list size is 0 */
-                rv = kevent(_st_kq_data->kq, _st_kq_data->dellist, _st_kq_data->dellist_cnt, NULL, 0, NULL);
-            } while (rv < 0 && errno == EINTR);
-        }
-
         for (i = 0; i < nfd; i++) {
+            /*
+             * Delete descriptors that fired: _st_kq_pollset_del() above
+             * only removed the filters that did not fire, and the counts
+             * it decremented tell us whether anyone still waits on this one.
+             */
             osfd = _st_kq_data->evtlist[i].ident;
             _ST_KQ_REVENTS(osfd) = 0;
+            memset(&kev, 0, sizeof(kev));
+            kev.ident = osfd;
+            kev.filter = _st_kq_data->evtlist[i].filter;
+            if ((kev.filter == EVFILT_READ && _ST_KQ_READ_CNT(osfd) == 0) ||
+                (kev.filter == EVFILT_WRITE && _ST_KQ_WRITE_CNT(osfd) == 0)) {
+                kev.flags = EV_DELETE;
+                (void) kevent(_st_kq_data->kq, &kev, 1, NULL, 0, NULL);
+            }
         }
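For comparison, deletion on the kqueue side goes through the same `kevent()` call that registers filters; passing a zero-sized event list applies the changelist without blocking. A small illustrative sketch (`kq_delete_read` is not part of the patch):

```c
#include <sys/event.h>
#include <stddef.h>

/* Remove the read filter for fd from a kqueue; returns 0 on success. */
static int kq_delete_read(int kq, int fd)
{
    struct kevent kev;
    EV_SET(&kev, fd, EVFILT_READ, EV_DELETE, 0, 0, NULL);
    /* nevents == 0: apply the changelist without waiting for events */
    return kevent(kq, &kev, 1, NULL, 0, NULL);
}
```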
    } else if (nfd < 0) {
         if (errno == EBADF && _st_kq_data->pid != getpid()) {
@@ -1122,13 +1153,6 @@ ST_HIDDEN void _st_epoll_dispatch(void)
                         revents |= POLLIN;
                     if ((events & POLLOUT) && (_ST_EPOLL_REVENTS(osfd) & EPOLLOUT))
                         revents |= POLLOUT;
-                    if ((events & POLLPRI) && (_ST_EPOLL_REVENTS(osfd) & EPOLLPRI))
-                        revents |= POLLPRI;
-                    if (_ST_EPOLL_REVENTS(osfd) & EPOLLERR)
-                        revents |= POLLERR;
-                    if (_ST_EPOLL_REVENTS(osfd) & EPOLLHUP)
-                        revents |= POLLHUP;
-
                     pds->revents = revents;
                     if (revents) {
                         notify = 1;
@@ -1230,6 +1254,459 @@ static _st_eventsys_t _st_epoll_eventsys = {
 
 #endif  /* MD_HAVE_EPOLL */
 
+#ifdef MD_HAVE_IO_URING
+/*****************************************
+ * io_uring event system
+ */
+ST_HIDDEN int _st_io_uring_init(void)
+{
+    int err = 0;
+    int rv = 0;
+
+    _st_io_uring_data = (struct _st_io_uringdata *) calloc(1, sizeof(*_st_io_uring_data));
+    if (!_st_io_uring_data)
+        return -1;
+    _st_io_uring_data->ringfd = -1;
+
+    _st_io_uring_data->ring = (struct io_uring *) calloc(1, sizeof(struct io_uring));
+    if (!_st_io_uring_data->ring) {
+        err = ENOMEM;
+        rv = -1;
+        goto cleanup_io_uring;
+    }
+
+    /* Let liburing set up and map the rings for us (it returns -errno on failure) */
+    _st_io_uring_data->ring_size = ST_IO_URING_RING_SIZE;
+    if ((rv = io_uring_queue_init(_st_io_uring_data->ring_size, _st_io_uring_data->ring, 0)) < 0) {
+        err = -rv;
+        rv = -1;
+        goto cleanup_io_uring;
+    }
+    _st_io_uring_data->ringfd = _st_io_uring_data->ring->ring_fd;
+    fcntl(_st_io_uring_data->ringfd, F_SETFD, FD_CLOEXEC);
+    _st_io_uring_data->pid = getpid();
+
+    /*
+     * Allocate file descriptor data array.
+     * FD_SETSIZE looks like a good initial size.
+     */
+    _st_io_uring_data->fd_data_size = FD_SETSIZE;
+    _st_io_uring_data->fd_data = (_io_uring_fd_data_t *) calloc(_st_io_uring_data->fd_data_size, sizeof(_io_uring_fd_data_t));
+    if (!_st_io_uring_data->fd_data) {
+        err = errno;
+        rv = -1;
+        goto cleanup_io_uring;
+    }
+
+    /* Allocate the staging list of submissions */
+    _st_io_uring_data->sqes = (struct io_uring_sqe *) malloc(_st_io_uring_data->ring_size * sizeof(struct io_uring_sqe));
+    if (!_st_io_uring_data->sqes) {
+        err = ENOMEM;
+        rv = -1;
+    }
+
+  cleanup_io_uring:
+    if (rv < 0) {
+        if (_st_io_uring_data->ringfd >= 0)
+            io_uring_queue_exit(_st_io_uring_data->ring);
+        free(_st_io_uring_data->ring);
+        free(_st_io_uring_data->fd_data);
+        free(_st_io_uring_data->sqes);
+        free(_st_io_uring_data);
+        _st_io_uring_data = NULL;
+        errno = err;
+    }
+
+    return rv;
+}
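`io_uring_queue_init()` hides the raw `io_uring_setup()` syscall plus the three `mmap()` calls for the SQ ring, CQ ring and SQE array that the first version of this function performed by hand. Should the manual route ever be needed, it looks roughly like this (sketch only; `map_rings` is illustrative, offsets per io_uring(7), and a full implementation must unmap earlier regions on failure):

```c
#include <linux/io_uring.h>
#include <sys/mman.h>
#include <stddef.h>

/* Manual ring mapping, as io_uring_queue_init() does internally (sketch). */
static int map_rings(int ringfd, const struct io_uring_params *p,
                     void **sq_ring, void **cq_ring, struct io_uring_sqe **sqes)
{
    size_t sq_sz = p->sq_off.array + p->sq_entries * sizeof(unsigned);
    size_t cq_sz = p->cq_off.cqes + p->cq_entries * sizeof(struct io_uring_cqe);

    *sq_ring = mmap(NULL, sq_sz, PROT_READ | PROT_WRITE,
                    MAP_SHARED | MAP_POPULATE, ringfd, IORING_OFF_SQ_RING);
    if (*sq_ring == MAP_FAILED)
        return -1;
    *cq_ring = mmap(NULL, cq_sz, PROT_READ | PROT_WRITE,
                    MAP_SHARED | MAP_POPULATE, ringfd, IORING_OFF_CQ_RING);
    if (*cq_ring == MAP_FAILED)
        return -1; /* a full implementation would munmap *sq_ring here */
    *sqes = mmap(NULL, p->sq_entries * sizeof(struct io_uring_sqe),
                 PROT_READ | PROT_WRITE, MAP_SHARED | MAP_POPULATE,
                 ringfd, IORING_OFF_SQES);
    return *sqes == MAP_FAILED ? -1 : 0;
}
```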
+
+ST_HIDDEN int _st_io_uring_fd_data_expand(int maxfd)
+{
+    _io_uring_fd_data_t *ptr;
+    int n = _st_io_uring_data->fd_data_size;
+
+    while (maxfd >= n)
+        n <<= 1;
+
+    ptr = (_io_uring_fd_data_t *) realloc(_st_io_uring_data->fd_data, n * sizeof(_io_uring_fd_data_t));
+    if (!ptr)
+        return -1;
+
+    memset(ptr + _st_io_uring_data->fd_data_size, 0, (n - _st_io_uring_data->fd_data_size) * sizeof(_io_uring_fd_data_t));
+
+    _st_io_uring_data->fd_data = ptr;
+    _st_io_uring_data->fd_data_size = n;
+
+    return 0;
+}
+
+ST_HIDDEN int _st_io_uring_addlist_expand(int avail)
+{
+    struct io_uring_sqe *ptr;
+    int n = _st_io_uring_data->ring_size;
+
+    while (avail > n - _st_io_uring_data->ring_cnt)
+        n <<= 1;
+
+    ptr = (struct io_uring_sqe *) realloc(_st_io_uring_data->sqes, n * sizeof(struct io_uring_sqe));
+    if (!ptr)
+        return -1;
+
+    _st_io_uring_data->sqes = ptr;
+    _st_io_uring_data->ring_size = n;
+
+    /*
+     * Note that only the staging list grows here: the kernel sizes the
+     * SQ/CQ rings at setup time, so they cannot be realloc()ed.
+     */
+    return 0;
+}
+
+ST_HIDDEN void _st_io_uring_addlist_add(const struct io_uring_sqe *sqe)
+{
+    ST_ASSERT(_st_io_uring_data->ring_cnt < _st_io_uring_data->ring_size);
+    memcpy(_st_io_uring_data->sqes + _st_io_uring_data->ring_cnt, sqe, sizeof(struct io_uring_sqe));
+    _st_io_uring_data->ring_cnt++;
+}
+
+ST_HIDDEN void _st_io_uring_dellist_add(const struct io_uring_sqe *sqe)
+{
+    int n = _st_io_uring_data->ring_size;
+
+    if (_st_io_uring_data->ring_cnt >= n) {
+        struct io_uring_sqe *ptr;
+
+        n <<= 1;
+        ptr = (struct io_uring_sqe *) realloc(_st_io_uring_data->sqes, n * sizeof(struct io_uring_sqe));
+        if (!ptr) {
+            /* See comment in _st_io_uring_pollset_del() */
+            return;
+        }
+
+        _st_io_uring_data->sqes = ptr;
+        _st_io_uring_data->ring_size = n;
+    }
+
+    memcpy(_st_io_uring_data->sqes + _st_io_uring_data->ring_cnt, sqe, sizeof(struct io_uring_sqe));
+    _st_io_uring_data->ring_cnt++;
+}
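Note that the staging list above lives purely in user memory; entries only reach the kernel once they are copied into the mapped submission ring and submitted. The pollset functions below inline that step; as a standalone sketch it would look like this (the helper name `_st_io_uring_flush` is hypothetical, not part of the patch, and assumes the `_st_io_uring_data` state defined earlier):

```c
/* Copy staged SQEs into the real submission ring and submit them (sketch). */
static int _st_io_uring_flush(void)
{
    int i;

    for (i = 0; i < _st_io_uring_data->ring_cnt; i++) {
        struct io_uring_sqe *sqe = io_uring_get_sqe(_st_io_uring_data->ring);
        if (!sqe) {
            /* SQ ring full: push what we have, then retry once */
            if (io_uring_submit(_st_io_uring_data->ring) < 0)
                return -1;
            sqe = io_uring_get_sqe(_st_io_uring_data->ring);
            if (!sqe)
                return -1;
        }
        *sqe = _st_io_uring_data->sqes[i];
    }
    _st_io_uring_data->ring_cnt = 0;
    return io_uring_submit(_st_io_uring_data->ring);
}
```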
+
+ST_HIDDEN int _st_io_uring_pollset_add(struct pollfd *pds, int npds)
+{
+    struct io_uring_sqe sqe;
+    struct pollfd *pd;
+    struct pollfd *epd = pds + npds;
+    int i;
+
+    /*
+     * Pollset adding is "atomic". That is, either it succeeds for
+     * all descriptors in the set or it fails. This means we need to
+     * do all the checks up front, so we don't have to "unwind" if
+     * adding one of the descriptors fails.
+     */
+    for (pd = pds; pd < epd; pd++) {
+        /* POLLIN and/or POLLOUT must be set, but nothing else */
+        if (pd->fd < 0 || !pd->events || (pd->events & ~(POLLIN | POLLOUT))) {
+            errno = EINVAL;
+            return -1;
+        }
+        if (pd->fd >= _st_io_uring_data->fd_data_size &&
+            _st_io_uring_fd_data_expand(pd->fd) < 0)
+            return -1;
+    }
+
+    /*
+     * Make sure we have enough room in the addlist for twice as many
+     * descriptors as in the pollset (for both READ and WRITE poll requests).
+     */
+    npds <<= 1;
+    if (npds > _st_io_uring_data->ring_size - _st_io_uring_data->ring_cnt && _st_io_uring_addlist_expand(npds) < 0)
+        return -1;
+
+    for (pd = pds; pd < epd; pd++) {
+        if ((pd->events & POLLIN) && (_ST_IO_URING_READ_CNT(pd->fd)++ == 0)) {
+            memset(&sqe, 0, sizeof(sqe));
+            sqe.opcode = IORING_OP_POLL_ADD;
+            sqe.fd = pd->fd;
+            sqe.poll32_events = POLLIN;
+            sqe.user_data = (uint64_t) pd->fd;
+            _st_io_uring_addlist_add(&sqe);
+        }
+        if ((pd->events & POLLOUT) && (_ST_IO_URING_WRITE_CNT(pd->fd)++ == 0)) {
+            memset(&sqe, 0, sizeof(sqe));
+            sqe.opcode = IORING_OP_POLL_ADD;
+            sqe.fd = pd->fd;
+            sqe.poll32_events = POLLOUT;
+            sqe.user_data = (uint64_t) pd->fd;
+            _st_io_uring_addlist_add(&sqe);
+        }
+    }
+
+    /* Move the staged requests into the submission ring and submit them */
+    for (i = 0; i < _st_io_uring_data->ring_cnt; i++) {
+        struct io_uring_sqe *rsqe = io_uring_get_sqe(_st_io_uring_data->ring);
+        if (!rsqe)
+            break;
+        *rsqe = _st_io_uring_data->sqes[i];
+    }
+    _st_io_uring_data->ring_cnt = 0;
+    if (io_uring_submit(_st_io_uring_data->ring) < 0)
+        return -1;
+
+    return 0;
+}
+
+ST_HIDDEN void _st_io_uring_pollset_del(struct pollfd *pds, int npds)
+{
+    struct io_uring_sqe sqe;
+    struct pollfd *pd;
+    struct pollfd *epd = pds + npds;
+
+    /*
+     * It's OK if deleting fails because a descriptor will either be
+     * closed or deleted in the dispatch function after it fires.
+     */
+    _st_io_uring_data->ring_cnt = 0;
+    for (pd = pds; pd < epd; pd++) {
+        if ((pd->events & POLLIN) && (--_ST_IO_URING_READ_CNT(pd->fd) == 0)) {
+            memset(&sqe, 0, sizeof(sqe));
+            sqe.opcode = IORING_OP_POLL_REMOVE;
+            sqe.fd = -1;
+            /* For POLL_REMOVE, addr holds the user_data of the poll request to cancel */
+            sqe.addr = (uint64_t) pd->fd;
+            _st_io_uring_dellist_add(&sqe);
+        }
+        if ((pd->events & POLLOUT) && (--_ST_IO_URING_WRITE_CNT(pd->fd) == 0)) {
+            memset(&sqe, 0, sizeof(sqe));
+            sqe.opcode = IORING_OP_POLL_REMOVE;
+            sqe.fd = -1;
+            sqe.addr = (uint64_t) pd->fd;
+            _st_io_uring_dellist_add(&sqe);
+        }
+    }
+
+    if (_st_io_uring_data->ring_cnt > 0) {
+        int i, rv;
+
+        /* Move the staged deletes into the submission ring */
+        for (i = 0; i < _st_io_uring_data->ring_cnt; i++) {
+            struct io_uring_sqe *rsqe = io_uring_get_sqe(_st_io_uring_data->ring);
+            if (!rsqe)
+                break;
+            *rsqe = _st_io_uring_data->sqes[i];
+        }
+        _st_io_uring_data->ring_cnt = 0;
+
+        /*
+         * We do "synchronous" io_uring deletes to avoid deleting
+         * closed descriptors and other possible problems.
+         */
+        do {
+            /* io_uring_submit() only submits; it does not wait for completions */
+            rv = io_uring_submit(_st_io_uring_data->ring);
+        } while (rv == -EINTR);
+    }
+}
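The add/remove pairing relies on `user_data` as the matching tag: `IORING_OP_POLL_REMOVE` cancels the earlier `IORING_OP_POLL_ADD` whose `user_data` equals the value in `addr`. A compact illustrative sketch of that contract (`arm_then_cancel` is not part of the patch; error checks elided, and `io_uring_get_sqe()` may return NULL when the ring is full):

```c
#include <liburing.h>
#include <poll.h>
#include <stdint.h>
#include <string.h>

/* Pair a poll request with its cancellation: both are matched by user_data. */
static void arm_then_cancel(struct io_uring *ring, int fd)
{
    struct io_uring_sqe *sqe;

    sqe = io_uring_get_sqe(ring);
    memset(sqe, 0, sizeof(*sqe));
    sqe->opcode = IORING_OP_POLL_ADD;
    sqe->fd = fd;
    sqe->poll32_events = POLLIN;
    sqe->user_data = (uint64_t) fd;   /* the tag */

    sqe = io_uring_get_sqe(ring);
    memset(sqe, 0, sizeof(*sqe));
    sqe->opcode = IORING_OP_POLL_REMOVE;
    sqe->fd = -1;
    sqe->addr = (uint64_t) fd;        /* cancel the request tagged above */

    io_uring_submit(ring);
}
```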
+
+ST_HIDDEN void _st_io_uring_dispatch(void)
+{
+    st_utime_t min_timeout;
+    _st_clist_t *q;
+    _st_pollq_t *pq;
+    struct pollfd *pds, *epds;
+    struct io_uring_cqe *cqe;
+    struct __kernel_timespec ts;
+    int timeout, rv, osfd, notify;
+    int events;
+    short revents;
+
+    #if defined(DEBUG) && defined(DEBUG_STATS)
+    ++_st_stat_io_uring;
+    #endif
+
+    if (_ST_SLEEPQ == NULL) {
+        timeout = -1;
+    } else {
+        min_timeout = (_ST_SLEEPQ->due <= _ST_LAST_CLOCK) ? 0 : (_ST_SLEEPQ->due - _ST_LAST_CLOCK);
+        timeout = (int) (min_timeout / 1000);
+
+        // Wait at least 1ms when the timeout is below 1ms, to avoid an io_uring spin loop.
+        if (timeout == 0) {
+            #if defined(DEBUG) && defined(DEBUG_STATS)
+            ++_st_stat_io_uring_zero;
+            #endif
+
+            if (min_timeout > 0) {
+                #if defined(DEBUG) && defined(DEBUG_STATS)
+                ++_st_stat_io_uring_shake;
+                #endif
+
+                timeout = 1;
+            }
+        }
+    }
+
+    /* Check for I/O operations, honoring the sleep-queue timeout */
+    if (timeout < 0) {
+        rv = io_uring_wait_cqe(_st_io_uring_data->ring, &cqe);
+    } else {
+        ts.tv_sec = timeout / 1000;
+        ts.tv_nsec = (timeout % 1000) * 1000000;
+        rv = io_uring_wait_cqe_timeout(_st_io_uring_data->ring, &cqe, &ts);
+    }
+
+    #if defined(DEBUG) && defined(DEBUG_STATS)
+    if (rv < 0) {
+        ++_st_stat_io_uring_spin;
+    }
+    #endif
+
+    if (rv == 0) {
+        /* Handle a single completion event */
+        osfd = (int) cqe->user_data;
+        if (cqe->res < 0) {
+            /* Also set I/O bits on error */
+            _ST_IO_URING_REVENTS(osfd) = POLLERR | _ST_IO_URING_EVENTS(osfd);
+        } else {
+            _ST_IO_URING_REVENTS(osfd) = (short) cqe->res;
+            if (_ST_IO_URING_REVENTS(osfd) & (POLLERR | POLLHUP)) {
+                /* Also set I/O bits on error */
+                _ST_IO_URING_REVENTS(osfd) |= _ST_IO_URING_EVENTS(osfd);
+            }
+        }
+
+        for (q = _ST_IOQ.next; q != &_ST_IOQ; q = q->next) {
+            pq = _ST_POLLQUEUE_PTR(q);
+            notify = 0;
+            epds = pq->pds + pq->npds;
+
+            for (pds = pq->pds; pds < epds; pds++) {
+                if (_ST_IO_URING_REVENTS(pds->fd) == 0) {
+                    pds->revents = 0;
+                    continue;
+                }
+                osfd = pds->fd;
+                events = pds->events;
+                revents = 0;
+                if ((events & POLLIN) && (_ST_IO_URING_REVENTS(osfd) & POLLIN))
+                    revents |= POLLIN;
+                if ((events & POLLOUT) && (_ST_IO_URING_REVENTS(osfd) & POLLOUT))
+                    revents |= POLLOUT;
+                pds->revents = revents;
+                if (revents) {
+                    notify = 1;
+                }
+            }
+            if (notify) {
+                ST_REMOVE_LINK(&pq->links);
+                pq->on_ioq = 0;
+                /*
+                 * Here we will only delete/modify descriptors that
+                 * didn't fire (see comments in _st_io_uring_pollset_del()).
+                 */
+                _st_io_uring_pollset_del(pq->pds, pq->npds);
+
+                if (pq->thread->flags & _ST_FL_ON_SLEEPQ)
+                    _ST_DEL_SLEEPQ(pq->thread);
+                pq->thread->state = _ST_ST_RUNNABLE;
+                _ST_ADD_RUNQ(pq->thread);
+            }
+        }
+
+        /*
+         * The completed poll request is one-shot: mark the CQE as seen
+         * and re-arm the descriptor if some thread still polls it.
+         */
+        osfd = (int) cqe->user_data;
+        io_uring_cqe_seen(_st_io_uring_data->ring, cqe);
+        _ST_IO_URING_REVENTS(osfd) = 0;
+        events = _ST_IO_URING_EVENTS(osfd);
+        if (events) {
+            struct io_uring_sqe *rsqe = io_uring_get_sqe(_st_io_uring_data->ring);
+            if (rsqe) {
+                memset(rsqe, 0, sizeof(*rsqe));
+                rsqe->opcode = IORING_OP_POLL_ADD;
+                rsqe->fd = osfd;
+                rsqe->poll32_events = events;
+                rsqe->user_data = (uint64_t) osfd;
+                io_uring_submit(_st_io_uring_data->ring);
+            }
+        }
+    }
+}
+
+ST_HIDDEN int _st_io_uring_fd_new(int osfd)
+{
+    if (osfd >= _st_io_uring_data->fd_data_size && _st_io_uring_fd_data_expand(osfd) < 0)
+        return -1;
+
+    return 0;
+}
+
+ST_HIDDEN int _st_io_uring_fd_close(int osfd)
+{
+    if (_ST_IO_URING_READ_CNT(osfd) || _ST_IO_URING_WRITE_CNT(osfd)) {
+        errno = EBUSY;
+        return -1;
+    }
+
+    return 0;
+}
+
+ST_HIDDEN int _st_io_uring_fd_getlimit(void)
+{
+    /* zero means no specific limit */
+    return 0;
+}
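The dispatch function above consumes a single completion per call, which is simple but can lag under load; a natural extension is to drain the whole completion queue without blocking. A hedged sketch of that variant (`drain_completions` is illustrative only):

```c
#include <liburing.h>

/* Drain every pending completion without blocking (sketch). */
static void drain_completions(struct io_uring *ring)
{
    struct io_uring_cqe *cqe;

    /* io_uring_peek_cqe() returns -EAGAIN once the CQ ring is empty */
    while (io_uring_peek_cqe(ring, &cqe) == 0) {
        int fd = (int) cqe->user_data;
        int res = cqe->res;          /* poll mask, or -errno */
        io_uring_cqe_seen(ring, cqe);
        (void) fd; (void) res;       /* ...update per-fd state here... */
    }
}
```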
+
+/*
+ * Check whether the kernel actually supports io_uring
+ * (on older kernels the syscall is a stub that fails with ENOSYS).
+ */
+ST_HIDDEN int _st_io_uring_is_supported(void)
+{
+    struct io_uring_params params;
+    int fd;
+
+    memset(&params, 0, sizeof(params));
+    fd = io_uring_setup(1, &params);
+    if (fd < 0) {
+        /* Any error other than ENOSYS still means the syscall exists */
+        return (errno != ENOSYS);
+    }
+    close(fd);
+    return 1;
+}
+
+ST_HIDDEN void _st_io_uring_destroy(void)
+{
+    if (_st_io_uring_data->ringfd >= 0)
+        io_uring_queue_exit(_st_io_uring_data->ring);
+    free(_st_io_uring_data->ring);
+    free(_st_io_uring_data->fd_data);
+    free(_st_io_uring_data->sqes);
+    free(_st_io_uring_data);
+    _st_io_uring_data = NULL;
+}
+
+static _st_eventsys_t _st_io_uring_eventsys = {
+    "io_uring",
+    ST_EVENTSYS_ALT,
+    _st_io_uring_init,
+    _st_io_uring_dispatch,
+    _st_io_uring_pollset_add,
+    _st_io_uring_pollset_del,
+    _st_io_uring_fd_new,
+    _st_io_uring_fd_close,
+    _st_io_uring_fd_getlimit,
+    _st_io_uring_destroy
+};
+#endif  /* MD_HAVE_IO_URING */
+
+
 /*****************************************
  * Public functions
  */
@@ -1241,26 +1718,45 @@ int st_set_eventsys(int eventsys)
         return -1;
     }
 
-    if (eventsys == ST_EVENTSYS_SELECT || eventsys == ST_EVENTSYS_DEFAULT) {
-#if defined (MD_HAVE_SELECT)
+    if (eventsys == ST_EVENTSYS_SELECT) {
+#ifdef MD_HAVE_SELECT
         _st_eventsys = &_st_select_eventsys;
         return 0;
 #endif
+        errno = EINVAL;
+        return -1;
     }
 
-    if (eventsys == ST_EVENTSYS_ALT) {
-#if defined (MD_HAVE_KQUEUE)
+#ifdef MD_HAVE_IO_URING
+    /* An explicit io_uring request does not silently fall back */
+    if (eventsys == ST_EVENTSYS_IO_URING) {
+        if (_st_io_uring_is_supported()) {
+            _st_eventsys = &_st_io_uring_eventsys;
+            return 0;
+        }
+        errno = ENOSYS;
+        return -1;
+    }
+#endif
+
+    /* For ST_EVENTSYS_DEFAULT and ST_EVENTSYS_ALT, try each event system in order of preference */
+#ifdef MD_HAVE_KQUEUE
+    if (eventsys == ST_EVENTSYS_DEFAULT || eventsys == ST_EVENTSYS_ALT) {
         _st_eventsys = &_st_kq_eventsys;
         return 0;
-#elif defined (MD_HAVE_EPOLL)
-    if (_st_epoll_is_supported()) {
-        _st_eventsys = &_st_epoll_eventsys;
-            return 0;
-        }
+    }
 #endif
+
+#ifdef MD_HAVE_EPOLL
+    if ((eventsys == ST_EVENTSYS_DEFAULT || eventsys == ST_EVENTSYS_ALT) && _st_epoll_is_supported()) {
+        _st_eventsys = &_st_epoll_eventsys;
+        return 0;
     }
+#endif
+
+#ifdef MD_HAVE_IO_URING
+    if ((eventsys == ST_EVENTSYS_DEFAULT || eventsys == ST_EVENTSYS_ALT) && _st_io_uring_is_supported()) {
+        _st_eventsys = &_st_io_uring_eventsys;
+        return 0;
+    }
+#endif
+
+#ifdef MD_HAVE_SELECT
+    if (eventsys == ST_EVENTSYS_DEFAULT || eventsys == ST_EVENTSYS_ALT) {
+        _st_eventsys = &_st_select_eventsys;
+        return 0;
+    }
+#endif
 
     errno = EINVAL;
     return -1;
 }
diff --git a/md.h b/md.h
index d3158db..02bd552 100644
--- a/md.h
+++ b/md.h
@@ -170,6 +170,13 @@ extern void _st_md_cxt_restore(_st_jmp_buf_t env, int val);
         MD_GET_SP(_thread) = (long) (_sp); \
         ST_END_MACRO
 
+    /*
+     * Check for io_uring support (liburing headers present).
+     */
+    #if defined(__has_include) && __has_include(<liburing.h>) && !defined(MD_HAVE_IO_URING)
+        #define MD_HAVE_IO_URING
+    #endif
+
 #elif defined (CYGWIN64)
 
     // For CYGWIN64, build SRS on Windows.
diff --git a/public.h b/public.h
index 97da146..38c4291 100644
--- a/public.h
+++ b/public.h
@@ -69,6 +69,7 @@
 #define ST_EVENTSYS_DEFAULT 0
 #define ST_EVENTSYS_SELECT  1
 #define ST_EVENTSYS_ALT     3
+#define ST_EVENTSYS_IO_URING 4
 
 #ifdef __cplusplus
 extern "C" {
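With `ST_EVENTSYS_IO_URING` now handled by `st_set_eventsys()`, a typical application selects io_uring and falls back to the platform default, remembering that the event system must be chosen before `st_init()`. An illustrative usage sketch (assumes the patched headers; `select_eventsys` is not part of the patch):

```c
#include <stdio.h>
#include "public.h"   /* installed as st.h */

static int select_eventsys(void)
{
    /* Prefer io_uring, but fall back to the platform default (epoll/kqueue/select) */
    if (st_set_eventsys(ST_EVENTSYS_IO_URING) < 0) {
        if (st_set_eventsys(ST_EVENTSYS_DEFAULT) < 0)
            return -1;
    }
    printf("using %s\n", st_get_eventsys_name());
    return st_init();   /* the event system cannot be changed after this */
}
```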
diff --git a/utest/Makefile b/utest/Makefile
index 2cd304b..4e43484 100644
--- a/utest/Makefile
+++ b/utest/Makefile
@@ -1,4 +1,3 @@
-
 # The main dir of st.
 ST_DIR = ..
 # The main dir of st utest.
@@ -12,6 +11,11 @@ CXXFLAGS += -DGTEST_USE_OWN_TR1_TUPLE=1
 
 # Flags for warnings.
 WARNFLAGS += -Wall -Wno-deprecated-declarations -Wno-unused-private-field -Wno-unused-command-line-argument
 
+# Add liburing support
+CXXFLAGS += -DMD_HAVE_IO_URING
+LDFLAGS += -L/usr/local/lib -L/usr/lib
+LIBS += -luring
+
 # House-keeping build targets.
 all : $(ST_DIR)/obj/st_utest
 
@@ -69,5 +73,5 @@ $(ST_DIR)/obj/%.o : %.cpp $(ST_UTEST_DEPS) $(UTEST_DEPS)
 
 # Generate the utest binary
 $(ST_DIR)/obj/st_utest : $(OBJECTS) $(ST_DIR)/obj/gtest.a $(ST_UTEST_DEPS)
-	$(CXX) -o $@ $(CXXFLAGS) $(UTEST_FLAGS) \
-		-lpthread -ldl $^
+	$(CXX) -o $@ $(CXXFLAGS) $(UTEST_FLAGS) $(LDFLAGS) \
+		$(OBJECTS) $(ST_DIR)/obj/gtest.a $(ST_UTEST_DEPS) -lpthread -ldl $(LIBS)
diff --git a/utest/st_utest_tcp_uring.cpp b/utest/st_utest_tcp_uring.cpp
new file mode 100644
index 0000000..e81a3a1
--- /dev/null
+++ b/utest/st_utest_tcp_uring.cpp
@@ -0,0 +1,635 @@
+/* SPDX-License-Identifier: MIT */
+/* Copyright (c) 2013-2024 The SRS Authors */
+
+#include <st_utest.hpp>
+
+#include <stdio.h>
+#include <string.h>
+#include <stdint.h>
+#include <unistd.h>
+#include <time.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+#include <vector>
+#include <string>
+#include <utility>
+#include <sys/resource.h> // For rlimit
+#include <algorithm>      // For std::min
+
+// Client thread function
+static void* client_thread_func(void* arg) {
+    printf("[CLIENT] Starting client thread...\n");
+    uint16_t port = *(uint16_t*)arg;
+    delete (uint16_t*)arg;
+    printf("[CLIENT] Connecting to port %d\n", port);
+
+    // Create client socket
+    int client_fd = socket(AF_INET, SOCK_STREAM, 0);
+    if (client_fd == -1) {
+        perror("[CLIENT] socket failed");
+        return nullptr;
+    }
+    printf("[CLIENT] Socket created successfully\n");
+
+    // Create state-threads file descriptor
+    st_netfd_t client_nfd = st_netfd_open_socket(client_fd);
+    if (!client_nfd) {
+        perror("[CLIENT] st_netfd_open_socket failed");
+        close(client_fd);
+        return nullptr;
+    }
+    printf("[CLIENT] State-threads file descriptor created successfully\n");
+
+    struct sockaddr_in server_addr;
+    memset(&server_addr, 0, sizeof(server_addr));
+    server_addr.sin_family = AF_INET;
+    server_addr.sin_addr.s_addr = inet_addr("127.0.0.1");
+    server_addr.sin_port = htons(port);
+
+    // Connect to server using state-threads
+    printf("[CLIENT] Attempting to connect to server...\n");
+    if (st_connect(client_nfd, (struct sockaddr*)&server_addr, sizeof(server_addr), ST_UTIME_NO_TIMEOUT) != 0) {
+        perror("[CLIENT] connect failed");
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+    printf("[CLIENT] Connected to server successfully\n");
+
+    // Send data using state-threads
+    const char* msg = "Hello from client!";
+    printf("[CLIENT] Sending message: %s\n", msg);
+    ssize_t n = st_write(client_nfd, msg, strlen(msg), ST_UTIME_NO_TIMEOUT);
+    if (n != static_cast<ssize_t>(strlen(msg))) {
+        perror("[CLIENT] send failed");
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+    printf("[CLIENT] Message sent successfully\n");
+
+    // Receive response using state-threads
+    char buf[1024];
+    printf("[CLIENT] Waiting for server response...\n");
+    n = st_read(client_nfd, buf, sizeof(buf) - 1, ST_UTIME_NO_TIMEOUT);
+    if (n <= 0) {
+        perror("[CLIENT] recv failed");
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+    buf[n] = '\0';
+    printf("[CLIENT] Received response: %s\n", buf);
+    if (strcmp(buf, "Hello from server!") != 0) {
+        printf("[CLIENT] Unexpected response received\n");
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+    printf("[CLIENT] Response verified successfully\n");
+
+    st_netfd_close(client_nfd);
+    printf("[CLIENT] Client thread finished\n");
+    return nullptr;
+}
+
+// Test io_uring TCP server and client
+VOID TEST(IoUringTest, TcpServerClient)
+{
+    printf("\n[TEST] Starting TcpServerClient test...\n");
+
+    // Try to set event system to io_uring; this must happen before st_init().
+    printf("[TEST] Attempting to set event system to io_uring...\n");
+    int rv = st_set_eventsys(ST_EVENTSYS_IO_URING);
+    if (rv != 0) {
+        printf("[TEST] io_uring event system not available, using default event system\n");
+        // st_init() will select the default event system itself.
+    }
+
+    // Initialize state-threads
+    rv = st_init();
+    EXPECT_EQ(rv, 0);
+
+    printf("[TEST] State-threads initialized with %s event system\n", st_get_eventsys_name());
+
+    // Create server socket
+    int server_fd = socket(AF_INET, SOCK_STREAM, 0);
+    EXPECT_NE(server_fd, -1);
+    printf("[TEST] Server socket created successfully\n");
+
+    // Create state-threads file descriptor
+    st_netfd_t server_nfd = st_netfd_open_socket(server_fd);
+    EXPECT_NE(server_nfd, nullptr);
+    printf("[TEST] State-threads file descriptor created successfully\n");
+
+    struct sockaddr_in server_addr;
+    memset(&server_addr, 0, sizeof(server_addr));
+    server_addr.sin_family = AF_INET;
+    server_addr.sin_addr.s_addr = INADDR_ANY;
+    server_addr.sin_port = htons(0); // Let the system choose a port
+
+    EXPECT_EQ(bind(st_netfd_fileno(server_nfd), (struct sockaddr*)&server_addr, sizeof(server_addr)), 0);
+    EXPECT_EQ(listen(st_netfd_fileno(server_nfd), 1), 0);
+    printf("[TEST] Server bound and listening\n");
+
+    // Get the port number
+    socklen_t addr_len = sizeof(server_addr);
+    EXPECT_EQ(getsockname(st_netfd_fileno(server_nfd), (struct sockaddr*)&server_addr, &addr_len), 0);
+    uint16_t port = ntohs(server_addr.sin_port);
+    printf("[TEST] Server listening on port %d\n", port);
+
+    // Create a joinable client thread
+    printf("[TEST] Creating client thread...\n");
+    st_thread_t client_thread = st_thread_create(client_thread_func, new uint16_t(port), 1, 0);
+    EXPECT_NE(client_thread, nullptr);
+    printf("[TEST] Client thread created successfully\n");
+
+    // Accept client connection using state-threads
+    printf("[TEST] Waiting for client connection...\n");
+    st_netfd_t client_nfd = st_accept(server_nfd, nullptr, nullptr, ST_UTIME_NO_TIMEOUT);
+    EXPECT_NE(client_nfd, nullptr);
+    printf("[TEST] Client connection accepted\n");
+
+    // Receive data using state-threads
+    char buf[1024];
+    printf("[TEST] Waiting for client data...\n");
+    ssize_t n = st_read(client_nfd, buf, sizeof(buf) - 1, ST_UTIME_NO_TIMEOUT);
+    EXPECT_GT(n, 0);
+    buf[n] = '\0';
+    EXPECT_STREQ(buf, "Hello from client!");
+    printf("[TEST] Received message: %s\n", buf);
+
+    // Send response using state-threads
+    const char* msg = "Hello from server!";
+    EXPECT_EQ(st_write(client_nfd, msg, strlen(msg), ST_UTIME_NO_TIMEOUT), static_cast<ssize_t>(strlen(msg)));
+    printf("[TEST] Sent response: %s\n", msg);
+
+    // Wait for client thread to finish
+    printf("[TEST] Waiting for client thread to finish...\n");
+    st_thread_join(client_thread, nullptr);
+    printf("[TEST] Client thread finished\n");
+
+    // Cleanup
+    st_netfd_close(client_nfd);
+    st_netfd_close(server_nfd);
+    printf("[TEST] Test completed successfully\n");
+}
+
+// Client thread function for multiple clients
+static void* multi_client_thread_func(void* arg) {
+    uint16_t port = *(uint16_t*)arg;
+    delete (uint16_t*)arg;
+
+    // Create client socket
+    int client_fd = socket(AF_INET, SOCK_STREAM, 0);
+    if (client_fd == -1) {
+        return nullptr;
+    }
+
+    // Create state-threads file descriptor; raw blocking sockets would stall every coroutine
+    st_netfd_t client_nfd = st_netfd_open_socket(client_fd);
+    if (!client_nfd) {
+        close(client_fd);
+        return nullptr;
+    }
+
+    struct sockaddr_in server_addr;
+    memset(&server_addr, 0, sizeof(server_addr));
+    server_addr.sin_family = AF_INET;
+    server_addr.sin_addr.s_addr = inet_addr("127.0.0.1");
+    server_addr.sin_port = htons(port);
+
+    // Connect to server using state-threads
+    if (st_connect(client_nfd, (struct sockaddr*)&server_addr, sizeof(server_addr), ST_UTIME_NO_TIMEOUT) != 0) {
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+
+    // Send data
+    char msg[64];
+    snprintf(msg, sizeof(msg), "Hello from client %d!", getpid());
+    if (st_write(client_nfd, msg, strlen(msg), ST_UTIME_NO_TIMEOUT) != static_cast<ssize_t>(strlen(msg))) {
+        perror("send failed");
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+
+    // Receive response
+    char buf[1024];
+    ssize_t n = st_read(client_nfd, buf, sizeof(buf) - 1, ST_UTIME_NO_TIMEOUT);
+    if (n <= 0) {
+        perror("recv failed");
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+    buf[n] = '\0';
+    if (strcmp(buf, "Hello from server!") != 0) {
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+
+    st_netfd_close(client_nfd);
+    return nullptr;
+}
+
+// Test io_uring TCP server with multiple clients
+VOID TEST(IoUringTest, TcpServerMultiClients)
+{
+    printf("\n[TEST] Starting TcpServerMultiClients test...\n");
+
+    // Try to set event system to io_uring; this must happen before st_init().
+    printf("[TEST] Attempting to set event system to io_uring...\n");
+    int rv = st_set_eventsys(ST_EVENTSYS_IO_URING);
+    if (rv != 0) {
+        printf("[TEST] io_uring event system not available, using default event system\n");
+        // st_init() will select the default event system itself.
+    }
+
+    // Initialize state-threads
+    rv = st_init();
+    EXPECT_EQ(rv, 0);
+
+    printf("[TEST] State-threads initialized with %s event system\n", st_get_eventsys_name());
+
+    // Create server socket
+    int server_fd = socket(AF_INET, SOCK_STREAM, 0);
+    EXPECT_NE(server_fd, -1);
+    printf("[TEST] Server socket created successfully\n");
+
+    // Create state-threads file descriptor
+    st_netfd_t server_nfd = st_netfd_open_socket(server_fd);
+    EXPECT_NE(server_nfd, nullptr);
+    printf("[TEST] State-threads file descriptor created successfully\n");
+
+    struct sockaddr_in server_addr;
+    memset(&server_addr, 0, sizeof(server_addr));
+    server_addr.sin_family = AF_INET;
+    server_addr.sin_addr.s_addr = INADDR_ANY;
+    server_addr.sin_port = htons(0); // Let the system choose a port
+
+    EXPECT_EQ(bind(st_netfd_fileno(server_nfd), (struct sockaddr*)&server_addr, sizeof(server_addr)), 0);
+    EXPECT_EQ(listen(st_netfd_fileno(server_nfd), 5), 0);
+    printf("[TEST] Server bound and listening\n");
+
+    // Get the port number
+    socklen_t addr_len = sizeof(server_addr);
+    EXPECT_EQ(getsockname(st_netfd_fileno(server_nfd), (struct sockaddr*)&server_addr, &addr_len), 0);
+    uint16_t port = ntohs(server_addr.sin_port);
+    printf("[TEST] Server listening on port %d\n", port);
+
+    // Create multiple joinable client threads
+    const int num_clients = 3;
+    std::vector<st_thread_t> client_threads;
+    printf("[TEST] Creating %d client threads...\n", num_clients);
+    for (int i = 0; i < num_clients; i++) {
+        st_thread_t thread = st_thread_create(multi_client_thread_func, new uint16_t(port), 1, 0);
+        EXPECT_NE(thread, nullptr);
+        client_threads.push_back(thread);
+    }
+    printf("[TEST] All client threads created successfully\n");
+
+    // Accept all client connections
+    std::vector<st_netfd_t> client_nfds;
+    for (int i = 0; i < num_clients; i++) {
+        printf("[TEST] Waiting for client %d connection...\n", i + 1);
+        st_netfd_t client_nfd = st_accept(server_nfd, nullptr, nullptr, ST_UTIME_NO_TIMEOUT);
+        EXPECT_NE(client_nfd, nullptr);
+        client_nfds.push_back(client_nfd);
+        printf("[TEST] Client %d connection accepted\n", i + 1);
+
+        // Receive data from client
+        char buf[1024];
+        ssize_t n = st_read(client_nfd, buf, sizeof(buf) - 1, ST_UTIME_NO_TIMEOUT);
+        EXPECT_GT(n, 0);
+        buf[n] = '\0';
+        printf("[TEST] Received from client %d: %s\n", i + 1, buf);
+
+        // Send response to client
+        const char* msg = "Hello from server!";
+        EXPECT_EQ(st_write(client_nfd, msg, strlen(msg), ST_UTIME_NO_TIMEOUT), static_cast<ssize_t>(strlen(msg)));
+        printf("[TEST] Sent response to client %d\n", i + 1);
+    }
+
+    // Wait for all client threads to finish
+    printf("[TEST] Waiting for all client threads to finish...\n");
+    for (st_thread_t thread : client_threads) {
+        st_thread_join(thread, nullptr);
+    }
+    printf("[TEST] All client threads finished\n");
+
+    // Cleanup
+    for (st_netfd_t client_nfd : client_nfds) {
+        st_netfd_close(client_nfd);
+    }
+    st_netfd_close(server_nfd);
+    printf("[TEST] Test completed successfully\n");
+}
+
+// Client thread function for stress test
+static void* stress_client_thread_func(void* arg) {
+    uint16_t port = *(uint16_t*)arg;
+    delete (uint16_t*)arg;
+    printf("[STRESS_CLIENT] Starting client thread for port %d\n", port);
+
+    // Create client socket
+    int client_fd = socket(AF_INET, SOCK_STREAM, 0);
+    if (client_fd == -1) {
+        perror("[STRESS_CLIENT] socket failed");
+        return nullptr;
+    }
+    printf("[STRESS_CLIENT] Socket created successfully\n");
+
+    // Create state-threads file descriptor
+    st_netfd_t client_nfd = st_netfd_open_socket(client_fd);
+    if (!client_nfd) {
+        perror("[STRESS_CLIENT] st_netfd_open_socket failed");
+        close(client_fd);
+        return nullptr;
+    }
+    printf("[STRESS_CLIENT] State-threads file descriptor created successfully\n");
+
+    struct sockaddr_in server_addr;
+    memset(&server_addr, 0, sizeof(server_addr));
+    server_addr.sin_family = AF_INET;
+    server_addr.sin_addr.s_addr = inet_addr("127.0.0.1");
+    server_addr.sin_port = htons(port);
+
+    // Connect to server using state-threads
+    printf("[STRESS_CLIENT] Attempting to connect to server...\n");
+    if (st_connect(client_nfd, (struct sockaddr*)&server_addr, sizeof(server_addr), ST_UTIME_NO_TIMEOUT) != 0) {
+        perror("[STRESS_CLIENT] connect failed");
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+    printf("[STRESS_CLIENT] Connected to server successfully\n");
+
+    // Send data using state-threads
+    const char* msg = "Hello from stress test client!";
+    printf("[STRESS_CLIENT] Sending message...\n");
+    ssize_t n = st_write(client_nfd, msg, strlen(msg), ST_UTIME_NO_TIMEOUT);
+    if (n != static_cast<ssize_t>(strlen(msg))) {
+        perror("[STRESS_CLIENT] send failed");
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+    printf("[STRESS_CLIENT] Message sent successfully\n");
+
+    // Receive response using state-threads
+    char buf[1024];
+    printf("[STRESS_CLIENT] Waiting for server response...\n");
+    n = st_read(client_nfd, buf, sizeof(buf) - 1, ST_UTIME_NO_TIMEOUT);
+    if (n <= 0) {
+        perror("[STRESS_CLIENT] recv failed");
+        st_netfd_close(client_nfd);
+        return nullptr;
+    }
+    buf[n] = '\0';
+    printf("[STRESS_CLIENT] Received response: %s\n", buf);
+
+    st_netfd_close(client_nfd);
+    printf("[STRESS_CLIENT] Client thread finished\n");
+    return nullptr;
+}
+
+// Function to get cumulative CPU time, with details.
+// Note: getrusage() reports CPU seconds consumed so far, not a percentage.
+static std::pair<double, std::string> get_detailed_cpu_usage() {
+    struct rusage usage;
+    if (getrusage(RUSAGE_SELF, &usage) != 0) {
+        return std::make_pair(-1.0, std::string("Error getting CPU usage"));
+    }
+
+    double user_time = usage.ru_utime.tv_sec + usage.ru_utime.tv_usec / 1000000.0;
+    double sys_time = usage.ru_stime.tv_sec + usage.ru_stime.tv_usec / 1000000.0;
+    double total_time = user_time + sys_time;
+
+    char details[256];
+    snprintf(details, sizeof(details),
+             "User: %.2fs, System: %.2fs, Total: %.2fs, MaxRSS: %ld KB",
+             user_time, sys_time, total_time, usage.ru_maxrss);
+
+    return std::make_pair(total_time, std::string(details));
+}
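As the comment above notes, `getrusage()` reports cumulative CPU seconds, not a utilization percentage; a percentage only makes sense as the ratio of a CPU-time delta to a wall-clock delta over an interval. An illustrative helper (not part of the patch):

```c
#include <sys/resource.h>
#include <sys/time.h>

/* CPU utilization over an interval: delta CPU-seconds / delta wall-seconds. */
static double cpu_percent(const struct rusage *a, const struct timeval *wall_a,
                          const struct rusage *b, const struct timeval *wall_b)
{
    double cpu_a = a->ru_utime.tv_sec + a->ru_stime.tv_sec
                 + (a->ru_utime.tv_usec + a->ru_stime.tv_usec) / 1e6;
    double cpu_b = b->ru_utime.tv_sec + b->ru_stime.tv_sec
                 + (b->ru_utime.tv_usec + b->ru_stime.tv_usec) / 1e6;
    double wall = (wall_b->tv_sec - wall_a->tv_sec)
                + (wall_b->tv_usec - wall_a->tv_usec) / 1e6;
    return wall > 0 ? 100.0 * (cpu_b - cpu_a) / wall : 0.0;
}
```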
+
+// Function to get cumulative CPU time (simplified accessor for backward compatibility)
+static double get_cpu_usage() {
+    return get_detailed_cpu_usage().first;
+}
+
+// Server thread function to handle client connections
+static void* server_handler_thread(void* arg) {
+    st_netfd_t client_nfd = (st_netfd_t)arg;
+    if (!client_nfd) {
+        printf("[SERVER] Error: Invalid file descriptor passed to handler\n");
+        return nullptr;
+    }
+
+    // Handle the connection
+    char buf[1024];
+    ssize_t n = st_read(client_nfd, buf, sizeof(buf) - 1, ST_UTIME_NO_TIMEOUT);
+    if (n > 0) {
+        buf[n] = '\0';
+        // Send response
+        const char* msg = "Hello from server!";
+        st_write(client_nfd, msg, strlen(msg), ST_UTIME_NO_TIMEOUT);
+    }
+
+    // Don't close client_nfd here; the main thread owns it
+    return nullptr;
+}
+
+// Test io_uring TCP server under stress (scaled down from the 10000-connection target)
+VOID TEST(IoUringTest, TcpServer10KConnections)
+{
+    printf("\n[TEST] Starting TcpServer10KConnections test...\n");
+
+    // Try to set event system to io_uring; this must happen before st_init().
+    printf("[TEST] Attempting to set event system to io_uring...\n");
+    int rv = st_set_eventsys(ST_EVENTSYS_IO_URING);
+    if (rv != 0) {
+        printf("[TEST] io_uring event system not available, using default event system\n");
+        // st_init() will select the default event system itself.
+    }
+
+    // Initialize state-threads
+    rv = st_init();
+    EXPECT_EQ(rv, 0);
+
+    printf("[TEST] State-threads initialized with %s event system\n", st_get_eventsys_name());
+
+    // Create server socket
+    int server_fd = socket(AF_INET, SOCK_STREAM, 0);
+    EXPECT_NE(server_fd, -1);
+
+    // Enable address reuse
+    int opt = 1;
+    setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));
+
+    // Create state-threads file descriptor
+    st_netfd_t server_nfd = st_netfd_open_socket(server_fd);
+    EXPECT_NE(server_nfd, nullptr);
+
+    struct sockaddr_in server_addr;
+    memset(&server_addr, 0, sizeof(server_addr));
+    server_addr.sin_family = AF_INET;
+    server_addr.sin_addr.s_addr = INADDR_ANY;
+    server_addr.sin_port = htons(0); // Let the system choose a port
+
+    EXPECT_EQ(bind(st_netfd_fileno(server_nfd), (struct sockaddr*)&server_addr, sizeof(server_addr)), 0);
+    EXPECT_EQ(listen(st_netfd_fileno(server_nfd), 128), 0); // Use a modest backlog
+    printf("[TEST] Server bound and listening\n");
+
+    // Get the port number
+    socklen_t addr_len = sizeof(server_addr);
+    EXPECT_EQ(getsockname(st_netfd_fileno(server_nfd), (struct sockaddr*)&server_addr, &addr_len), 0);
+    uint16_t port = ntohs(server_addr.sin_port);
+    printf("[TEST] Server listening on port %d\n", port);
+
+    // Get system limits
+    struct rlimit rlim;
+    EXPECT_EQ(getrlimit(RLIMIT_NOFILE, &rlim), 0);
+    printf("[TEST] System file descriptor limit: %lu\n", (unsigned long)rlim.rlim_cur);
+
+    // Calculate the maximum number of connections - use a much smaller number than 10000 for safety
+    const int target_connections = 100;
+    const int max_connections = std::min(target_connections, static_cast<int>(rlim.rlim_cur - 200));
+    printf("[TEST] Will test up to %d connections\n", max_connections);
+
+    // Create vectors to store client threads and server-side file descriptors
+    std::vector<st_thread_t> client_threads;
+    std::vector<st_thread_t> server_threads;
+    std::vector<st_netfd_t> client_nfds;
+
+    client_threads.reserve(max_connections);
+    server_threads.reserve(max_connections);
+    client_nfds.reserve(max_connections);
+
+    // Start stress test
+    int successful_connections = 0;
+    int batch_size = 500; // Number of client threads attempted per batch (capped by the remaining target)
+    int current_batch = 0;
+    const double CPU_THRESHOLD = 90.0; // Cumulative CPU seconds before the test stops itself
+
+    // Record start time
+    time_t start_time = time(NULL);
+
+    // Print table header
+    printf("\n%-10s %-15s %-25s %-20s\n", "Batch", "Connections", "CPU Time (s)", "Time Elapsed");
+    printf("------------------------------------------------------------------\n");
+
+    while (successful_connections < max_connections) {
+        current_batch++;
+
+        // Check cumulative CPU time before starting a new batch
+        std::pair<double, std::string> cpu_info = get_detailed_cpu_usage();
+        double cpu_seconds = cpu_info.first;
+
+        if (cpu_seconds > CPU_THRESHOLD) {
+            printf("[TEST] CPU time too high (%.2fs), stopping test\n", cpu_seconds);
+            break;
+        }
+
+        // Create a batch of joinable client threads
+        int threads_created = 0;
+        int batch_target = std::min(batch_size, max_connections - successful_connections);
+
+        for (int i = 0; i < batch_target; i++) {
+            uint16_t* port_arg = new uint16_t(port);
+            st_thread_t thread = st_thread_create(stress_client_thread_func, port_arg, 1, 0);
+            if (thread) {
+                client_threads.push_back(thread);
+                threads_created++;
+            } else {
+                delete port_arg; // Clean up if thread creation failed
+                printf("[TEST] Failed to create client thread\n");
+            }
+        }
+
+        printf("[TEST] Created %d client threads in this batch\n", threads_created);
+
+        // Wait a bit to ensure the threads have time to start
+        st_sleep(1);
+
+        // Accept connections and hand each to a dedicated handler thread
+        int connections_accepted = 0;
+        for (int i = 0; i < threads_created; i++) {
+            st_netfd_t client_nfd = st_accept(server_nfd, nullptr, nullptr, ST_UTIME_NO_TIMEOUT);
+            if (client_nfd) {
+                client_nfds.push_back(client_nfd);
+
+                // Create a dedicated thread to handle this connection
+                st_thread_t server_thread = st_thread_create(server_handler_thread,
+                                                             (void*)client_nfd, 1, 0);
+                if (server_thread) {
+                    server_threads.push_back(server_thread);
+                    successful_connections++;
+                    connections_accepted++;
+                } else {
+                    printf("[TEST] Failed to create server handler thread\n");
+                    // Don't increment counters if thread creation failed
+                }
+            } else {
+                printf("[TEST] Failed to accept client connection\n");
+                break;
+            }
+        }
+
+        printf("[TEST] Accepted %d connections in this batch\n", connections_accepted);
+
+        // Print progress and CPU time
+        std::pair<double, std::string> current_cpu_info = get_detailed_cpu_usage();
+        double current_cpu = current_cpu_info.first;
+        std::string current_details = current_cpu_info.second;
+
+        time_t elapsed = time(NULL) - start_time;
+        printf("%-10d %-15d %-25.2f %-20ld\n",
+               current_batch, successful_connections, current_cpu, elapsed);
+
+        // Print detailed CPU information
+        printf("[DETAIL] %s\n", current_details.c_str());
+
+        // Wait for the client threads in this batch to finish
+        printf("[TEST] Waiting for client threads in this batch to finish...\n");
+        for (int i = client_threads.size() - threads_created; i < (int)client_threads.size(); i++) {
+            if (i >= 0 && i < (int)client_threads.size()) {
+                st_thread_join(client_threads[i], nullptr);
+            }
+        }
+
+        // If we couldn't accept all connections in this batch, we've hit a limit
+        if (connections_accepted < threads_created) {
+            printf("[TEST] Reached connection limit - accepting stopped at %d connections\n",
+                   successful_connections);
+            break;
+        }
+
+        // Sleep between batches to allow the system to stabilize
+        st_sleep(1);
+    }
+
+    // Wait for all server threads to finish
+    printf("[TEST] Waiting for server threads to finish...\n");
+    for (auto& thread : server_threads) {
+        if (thread) {
+            st_thread_join(thread, nullptr);
+        }
+    }
+
+    // Final statistics
+    std::pair<double, std::string> final_stats = get_detailed_cpu_usage();
+    std::string cpu_details = final_stats.second;
+    time_t total_time = time(NULL) - start_time;
+
+    printf("\n[TEST] Stress test completed\n");
+    printf("[TEST] Total successful connections: %d\n", successful_connections);
+    printf("[TEST] Total time: %ld seconds\n", total_time);
+    printf("[TEST] Connections per second: %.2f\n",
+           static_cast<double>(successful_connections) / (total_time > 0 ? total_time : 1));
+    printf("[TEST] Final CPU time: %s\n", cpu_details.c_str());
+
+    // Cleanup
+    printf("[TEST] Cleaning up connections...\n");
+    for (st_netfd_t client_nfd : client_nfds) {
+        if (client_nfd) {
+            st_netfd_close(client_nfd);
+        }
+    }
+
+    if (server_nfd) {
+        st_netfd_close(server_nfd);
+    }
+
+    printf("[TEST] Test completed successfully\n");
+}