io: allow overriding poll function.
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static LIST_HEAD(closing);
18 static LIST_HEAD(always);
19 static struct timemono (*nowfn)(void) = time_mono;
20 static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll;
21
22 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
23 {
24         struct timemono (*old)(void) = nowfn;
25         nowfn = now;
26         return old;
27 }
28
29 int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int)
30 {
31         int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn;
32         pollfn = poll;
33         return old;
34 }
35
36 static bool add_fd(struct fd *fd, short events)
37 {
38         if (!max_fds) {
39                 assert(num_fds == 0);
40                 pollfds = tal_arr(NULL, struct pollfd, 8);
41                 if (!pollfds)
42                         return false;
43                 fds = tal_arr(pollfds, struct fd *, 8);
44                 if (!fds)
45                         return false;
46                 max_fds = 8;
47         }
48
49         if (num_fds + 1 > max_fds) {
50                 size_t num = max_fds * 2;
51
52                 if (!tal_resize(&pollfds, num))
53                         return false;
54                 if (!tal_resize(&fds, num))
55                         return false;
56                 max_fds = num;
57         }
58
59         pollfds[num_fds].events = events;
60         /* In case it's idle. */
61         if (!events)
62                 pollfds[num_fds].fd = -fd->fd;
63         else
64                 pollfds[num_fds].fd = fd->fd;
65         pollfds[num_fds].revents = 0; /* In case we're iterating now */
66         fds[num_fds] = fd;
67         fd->backend_info = num_fds;
68         num_fds++;
69         if (events)
70                 num_waiting++;
71
72         return true;
73 }
74
75 static void del_fd(struct fd *fd)
76 {
77         size_t n = fd->backend_info;
78
79         assert(n != -1);
80         assert(n < num_fds);
81         if (pollfds[n].events)
82                 num_waiting--;
83         if (n != num_fds - 1) {
84                 /* Move last one over us. */
85                 pollfds[n] = pollfds[num_fds-1];
86                 fds[n] = fds[num_fds-1];
87                 assert(fds[n]->backend_info == num_fds-1);
88                 fds[n]->backend_info = n;
89         } else if (num_fds == 1) {
90                 /* Free everything when no more fds. */
91                 pollfds = tal_free(pollfds);
92                 fds = NULL;
93                 max_fds = 0;
94         }
95         num_fds--;
96         fd->backend_info = -1;
97 }
98
99 static void destroy_listener(struct io_listener *l)
100 {
101         close(l->fd.fd);
102         del_fd(&l->fd);
103 }
104
105 bool add_listener(struct io_listener *l)
106 {
107         if (!add_fd(&l->fd, POLLIN))
108                 return false;
109         tal_add_destructor(l, destroy_listener);
110         return true;
111 }
112
113 void remove_from_always(struct io_conn *conn)
114 {
115         list_del_init(&conn->always);
116 }
117
118 void backend_new_always(struct io_conn *conn)
119 {
120         /* In case it's already in always list. */
121         list_del(&conn->always);
122         list_add_tail(&always, &conn->always);
123 }
124
125 void backend_new_plan(struct io_conn *conn)
126 {
127         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
128
129         if (pfd->events)
130                 num_waiting--;
131
132         pfd->events = 0;
133         if (conn->plan[IO_IN].status == IO_POLLING)
134                 pfd->events |= POLLIN;
135         if (conn->plan[IO_OUT].status == IO_POLLING)
136                 pfd->events |= POLLOUT;
137
138         if (pfd->events) {
139                 num_waiting++;
140                 pfd->fd = conn->fd.fd;
141         } else {
142                 pfd->fd = -conn->fd.fd;
143         }
144 }
145
146 void backend_wake(const void *wait)
147 {
148         unsigned int i;
149
150         for (i = 0; i < num_fds; i++) {
151                 struct io_conn *c;
152
153                 /* Ignore listeners */
154                 if (fds[i]->listener)
155                         continue;
156
157                 c = (void *)fds[i];
158                 if (c->plan[IO_IN].status == IO_WAITING
159                     && c->plan[IO_IN].arg.u1.const_vp == wait)
160                         io_do_wakeup(c, IO_IN);
161
162                 if (c->plan[IO_OUT].status == IO_WAITING
163                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
164                         io_do_wakeup(c, IO_OUT);
165         }
166 }
167
168 static void destroy_conn(struct io_conn *conn, bool close_fd)
169 {
170         int saved_errno = errno;
171
172         if (close_fd)
173                 close(conn->fd.fd);
174         del_fd(&conn->fd);
175         /* In case it's on always list, remove it. */
176         list_del_init(&conn->always);
177
178         /* errno saved/restored by tal_free itself. */
179         if (conn->finish) {
180                 errno = saved_errno;
181                 conn->finish(conn, conn->finish_arg);
182         }
183 }
184
185 static void destroy_conn_close_fd(struct io_conn *conn)
186 {
187         destroy_conn(conn, true);
188 }
189
190 bool add_conn(struct io_conn *c)
191 {
192         if (!add_fd(&c->fd, 0))
193                 return false;
194         tal_add_destructor(c, destroy_conn_close_fd);
195         return true;
196 }
197
198 void cleanup_conn_without_close(struct io_conn *conn)
199 {
200         tal_del_destructor(conn, destroy_conn_close_fd);
201         destroy_conn(conn, false);
202 }
203
204 static void accept_conn(struct io_listener *l)
205 {
206         int fd = accept(l->fd.fd, NULL, NULL);
207
208         /* FIXME: What to do here? */
209         if (fd < 0)
210                 return;
211
212         io_new_conn(l->ctx, fd, l->init, l->arg);
213 }
214
215 static bool handle_always(void)
216 {
217         bool ret = false;
218         struct io_conn *conn;
219
220         while ((conn = list_pop(&always, struct io_conn, always)) != NULL) {
221                 assert(conn->plan[IO_IN].status == IO_ALWAYS
222                        || conn->plan[IO_OUT].status == IO_ALWAYS);
223
224                 /* Re-initialize, for next time. */
225                 list_node_init(&conn->always);
226                 io_do_always(conn);
227                 ret = true;
228         }
229         return ret;
230 }
231
232 /* This is the main loop. */
233 void *io_loop(struct timers *timers, struct timer **expired)
234 {
235         void *ret;
236
237         /* if timers is NULL, expired must be.  If not, not. */
238         assert(!timers == !expired);
239
240         /* Make sure this is NULL if we exit for some other reason. */
241         if (expired)
242                 *expired = NULL;
243
244         while (!io_loop_return) {
245                 int i, r, ms_timeout = -1;
246
247                 if (handle_always()) {
248                         /* Could have started/finished more. */
249                         continue;
250                 }
251
252                 /* Everything closed? */
253                 if (num_fds == 0)
254                         break;
255
256                 /* You can't tell them all to go to sleep! */
257                 assert(num_waiting);
258
259                 if (timers) {
260                         struct timemono now, first;
261
262                         now = nowfn();
263
264                         /* Call functions for expired timers. */
265                         *expired = timers_expire(timers, now);
266                         if (*expired)
267                                 break;
268
269                         /* Now figure out how long to wait for the next one. */
270                         if (timer_earliest(timers, &first)) {
271                                 uint64_t next;
272                                 next = time_to_msec(timemono_between(first, now));
273                                 if (next < INT_MAX)
274                                         ms_timeout = next;
275                                 else
276                                         ms_timeout = INT_MAX;
277                         }
278                 }
279
280                 r = pollfn(pollfds, num_fds, ms_timeout);
281                 if (r < 0)
282                         break;
283
284                 for (i = 0; i < num_fds && !io_loop_return; i++) {
285                         struct io_conn *c = (void *)fds[i];
286                         int events = pollfds[i].revents;
287
288                         if (r == 0)
289                                 break;
290
291                         if (fds[i]->listener) {
292                                 struct io_listener *l = (void *)fds[i];
293                                 if (events & POLLIN) {
294                                         accept_conn(l);
295                                         r--;
296                                 } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
297                                         r--;
298                                         errno = EBADF;
299                                         io_close_listener(l);
300                                 }
301                         } else if (events & (POLLIN|POLLOUT)) {
302                                 r--;
303                                 io_ready(c, events);
304                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
305                                 r--;
306                                 errno = EBADF;
307                                 io_close(c);
308                         }
309                 }
310         }
311
312         ret = io_loop_return;
313         io_loop_return = NULL;
314
315         return ret;
316 }