]> git.ozlabs.org Git - ccan/blobdiff - ccan/io/poll.c
io: don't fail if we get a signal.
[ccan] / ccan / io / poll.c
index 229f7ce9ab4e6fb9bf69ccfd7ae5b650a4463613..b005a97e4b1d07879a4e281a8d48276baaac2fa1 100644 (file)
@@ -17,6 +17,7 @@ static struct fd **fds = NULL;
 static LIST_HEAD(closing);
 static LIST_HEAD(always);
 static struct timemono (*nowfn)(void) = time_mono;
+static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll;
 
 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
 {
@@ -25,6 +26,13 @@ struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
        return old;
 }
 
+int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int)
+{
+       int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn;
+       pollfn = poll;
+       return old;
+}
+
 static bool add_fd(struct fd *fd, short events)
 {
        if (!max_fds) {
@@ -122,9 +130,11 @@ void backend_new_plan(struct io_conn *conn)
                num_waiting--;
 
        pfd->events = 0;
-       if (conn->plan[IO_IN].status == IO_POLLING)
+       if (conn->plan[IO_IN].status == IO_POLLING_NOTSTARTED
+           || conn->plan[IO_IN].status == IO_POLLING_STARTED)
                pfd->events |= POLLIN;
-       if (conn->plan[IO_OUT].status == IO_POLLING)
+       if (conn->plan[IO_OUT].status == IO_POLLING_NOTSTARTED
+           || conn->plan[IO_OUT].status == IO_POLLING_STARTED)
                pfd->events |= POLLOUT;
 
        if (pfd->events) {
@@ -187,11 +197,10 @@ bool add_conn(struct io_conn *c)
        return true;
 }
 
-struct io_plan *io_close_taken_fd(struct io_conn *conn)
+void cleanup_conn_without_close(struct io_conn *conn)
 {
        tal_del_destructor(conn, destroy_conn_close_fd);
        destroy_conn(conn, false);
-       return io_close(conn);
 }
 
 static void accept_conn(struct io_listener *l)
@@ -270,9 +279,14 @@ void *io_loop(struct timers *timers, struct timer **expired)
                        }
                }
 
-               r = poll(pollfds, num_fds, ms_timeout);
-               if (r < 0)
+               r = pollfn(pollfds, num_fds, ms_timeout);
+               if (r < 0) {
+                       /* Signals shouldn't break us, unless they set
+                        * io_loop_return. */
+                       if (errno == EINTR)
+                               continue;
                        break;
+               }
 
                for (i = 0; i < num_fds && !io_loop_return; i++) {
                        struct io_conn *c = (void *)fds[i];