]> git.ozlabs.org Git - ccan/blobdiff - ccan/io/poll.c
io: allow overriding poll function.
[ccan] / ccan / io / poll.c
index 98a64f42fbe4336283be8f5371688da439943ac6..a4e83ed761e77251767b01b3e3f5c5dd3492bfcd 100644 (file)
@@ -17,6 +17,7 @@ static struct fd **fds = NULL;
 static LIST_HEAD(closing);
 static LIST_HEAD(always);
 static struct timemono (*nowfn)(void) = time_mono;
 static LIST_HEAD(closing);
 static LIST_HEAD(always);
 static struct timemono (*nowfn)(void) = time_mono;
+static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll;
 
 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
 {
 
 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
 {
@@ -25,6 +26,13 @@ struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
        return old;
 }
 
        return old;
 }
 
+int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int)
+{
+       int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn;
+       pollfn = poll;
+       return old;
+}
+
 static bool add_fd(struct fd *fd, short events)
 {
        if (!max_fds) {
 static bool add_fd(struct fd *fd, short events)
 {
        if (!max_fds) {
@@ -187,11 +195,10 @@ bool add_conn(struct io_conn *c)
        return true;
 }
 
        return true;
 }
 
-struct io_plan *io_close_taken_fd(struct io_conn *conn)
+void cleanup_conn_without_close(struct io_conn *conn)
 {
        tal_del_destructor(conn, destroy_conn_close_fd);
        destroy_conn(conn, false);
 {
        tal_del_destructor(conn, destroy_conn_close_fd);
        destroy_conn(conn, false);
-       return io_close(conn);
 }
 
 static void accept_conn(struct io_listener *l)
 }
 
 static void accept_conn(struct io_listener *l)
@@ -270,7 +277,7 @@ void *io_loop(struct timers *timers, struct timer **expired)
                        }
                }
 
                        }
                }
 
-               r = poll(pollfds, num_fds, ms_timeout);
+               r = pollfn(pollfds, num_fds, ms_timeout);
                if (r < 0)
                        break;
 
                if (r < 0)
                        break;
 
@@ -282,9 +289,14 @@ void *io_loop(struct timers *timers, struct timer **expired)
                                break;
 
                        if (fds[i]->listener) {
                                break;
 
                        if (fds[i]->listener) {
+                               struct io_listener *l = (void *)fds[i];
                                if (events & POLLIN) {
                                if (events & POLLIN) {
-                                       accept_conn((void *)c);
+                                       accept_conn(l);
+                                       r--;
+                               } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
                                        r--;
                                        r--;
+                                       errno = EBADF;
+                                       io_close_listener(l);
                                }
                        } else if (events & (POLLIN|POLLOUT)) {
                                r--;
                                }
                        } else if (events & (POLLIN|POLLOUT)) {
                                r--;