From e846b1a93ecf096164ff2c4faaf4a89c24a0e76b Mon Sep 17 00:00:00 2001 From: Rusty Russell Date: Fri, 16 Jun 2017 13:17:32 +0930 Subject: [PATCH] io: allow overriding poll function. Signed-off-by: Rusty Russell --- ccan/io/backend.h | 1 - ccan/io/io.h | 11 +++++ ccan/io/poll.c | 10 ++++- ccan/io/test/run-41-io_poll_override.c | 57 ++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 ccan/io/test/run-41-io_poll_override.c diff --git a/ccan/io/backend.h b/ccan/io/backend.h index 3e158a36..c8ceb4e8 100644 --- a/ccan/io/backend.h +++ b/ccan/io/backend.h @@ -2,7 +2,6 @@ #ifndef CCAN_IO_BACKEND_H #define CCAN_IO_BACKEND_H #include -#include #include "io_plan.h" #include diff --git a/ccan/io/io.h b/ccan/io/io.h index fe42b537..11eeade6 100644 --- a/ccan/io/io.h +++ b/ccan/io/io.h @@ -4,6 +4,7 @@ #include #include #include +#include #include struct timers; @@ -701,4 +702,14 @@ bool io_flush_sync(struct io_conn *conn); */ struct timemono (*io_time_override(struct timemono (*now)(void)))(void); +/** + * io_poll_override - override the normal call for poll. + * @pollfn: the function to call. + * + * io usually uses poll() internally, but this forces it to use your + * function (eg. for debugging, suppressing fds, or polling on others unknown + * to ccan/io). Returns the old one. + */ +int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int); + #endif /* CCAN_IO_H */ diff --git a/ccan/io/poll.c b/ccan/io/poll.c index 043feff7..a4e83ed7 100644 --- a/ccan/io/poll.c +++ b/ccan/io/poll.c @@ -17,6 +17,7 @@ static struct fd **fds = NULL; static LIST_HEAD(closing); static LIST_HEAD(always); static struct timemono (*nowfn)(void) = time_mono; +static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll; struct timemono (*io_time_override(struct timemono (*now)(void)))(void) { @@ -25,6 +26,13 @@ struct timemono (*io_time_override(struct timemono (*now)(void)))(void) return old; } +int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int) +{ + int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn; + pollfn = poll; + return old; +} + static bool add_fd(struct fd *fd, short events) { if (!max_fds) { @@ -269,7 +277,7 @@ void *io_loop(struct timers *timers, struct timer **expired) } } - r = poll(pollfds, num_fds, ms_timeout); + r = pollfn(pollfds, num_fds, ms_timeout); if (r < 0) break; diff --git a/ccan/io/test/run-41-io_poll_override.c b/ccan/io/test/run-41-io_poll_override.c new file mode 100644 index 00000000..0a62e2d3 --- /dev/null +++ b/ccan/io/test/run-41-io_poll_override.c @@ -0,0 +1,57 @@ +#include +/* Include the C files directly. */ +#include +#include +#include +#include +#include + +#define PORT "65020" + +/* Should be looking to read from one fd. */ +static int mypoll(struct pollfd *fds, nfds_t nfds, int timeout) +{ + ok1(nfds == 1); + ok1(fds[0].fd >= 0); + ok1(fds[0].events & POLLIN); + ok1(!(fds[0].events & POLLOUT)); + + /* Pretend it's readable. */ + fds[0].revents = POLLIN; + return 1; +} + +static int check_cant_read(int fd, struct io_plan_arg *arg) +{ + char c; + ssize_t ret = read(fd, &c, 1); + + ok1(errno == EAGAIN || errno == EWOULDBLOCK); + ok1(ret == -1); + + return 1; +} + +static struct io_plan *setup_conn(struct io_conn *conn, void *unused) +{ + /* Need to get this to mark it IO_POLLING */ + io_plan_arg(conn, IO_IN); + return io_set_plan(conn, IO_IN, check_cant_read, io_close_cb, NULL); +} + +int main(void) +{ + int fds[2]; + + plan_tests(8); + + pipe(fds); + ok1(io_poll_override(mypoll) == poll); + + io_new_conn(NULL, fds[0], setup_conn, NULL); + ok1(io_loop(NULL, NULL) == NULL); + close(fds[1]); + + /* This exits depending on whether all tests passed */ + return exit_status(); +} -- 2.39.2