a853c8705f1d5b68971327ec9c38d22ef4222e2b
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0, num_always = 0, max_always = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static struct io_plan **always = NULL;
18 static struct timemono (*nowfn)(void) = time_mono;
19 static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll;
20
21 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
22 {
23         struct timemono (*old)(void) = nowfn;
24         nowfn = now;
25         return old;
26 }
27
28 int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int)
29 {
30         int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn;
31         pollfn = poll;
32         return old;
33 }
34
35 static bool add_fd(struct fd *fd, short events)
36 {
37         if (!max_fds) {
38                 assert(num_fds == 0);
39                 pollfds = tal_arr(NULL, struct pollfd, 8);
40                 if (!pollfds)
41                         return false;
42                 fds = tal_arr(pollfds, struct fd *, 8);
43                 if (!fds)
44                         return false;
45                 max_fds = 8;
46         }
47
48         if (num_fds + 1 > max_fds) {
49                 size_t num = max_fds * 2;
50
51                 if (!tal_resize(&pollfds, num))
52                         return false;
53                 if (!tal_resize(&fds, num))
54                         return false;
55                 max_fds = num;
56         }
57
58         pollfds[num_fds].events = events;
59         /* In case it's idle. */
60         if (!events)
61                 pollfds[num_fds].fd = -fd->fd;
62         else
63                 pollfds[num_fds].fd = fd->fd;
64         pollfds[num_fds].revents = 0; /* In case we're iterating now */
65         fds[num_fds] = fd;
66         fd->backend_info = num_fds;
67         num_fds++;
68         if (events)
69                 num_waiting++;
70
71         return true;
72 }
73
74 static void del_fd(struct fd *fd)
75 {
76         size_t n = fd->backend_info;
77
78         assert(n != -1);
79         assert(n < num_fds);
80         if (pollfds[n].events)
81                 num_waiting--;
82         if (n != num_fds - 1) {
83                 /* Move last one over us. */
84                 pollfds[n] = pollfds[num_fds-1];
85                 fds[n] = fds[num_fds-1];
86                 assert(fds[n]->backend_info == num_fds-1);
87                 fds[n]->backend_info = n;
88         } else if (num_fds == 1) {
89                 /* Free everything when no more fds. */
90                 pollfds = tal_free(pollfds);
91                 fds = NULL;
92                 max_fds = 0;
93         }
94         num_fds--;
95         fd->backend_info = -1;
96 }
97
98 static void destroy_listener(struct io_listener *l)
99 {
100         close(l->fd.fd);
101         del_fd(&l->fd);
102 }
103
104 bool add_listener(struct io_listener *l)
105 {
106         if (!add_fd(&l->fd, POLLIN))
107                 return false;
108         tal_add_destructor(l, destroy_listener);
109         return true;
110 }
111
112 static int find_always(const struct io_plan *plan)
113 {
114         for (size_t i = 0; i < num_always; i++)
115                 if (always[i] == plan)
116                         return i;
117         return -1;
118 }
119
120 static void remove_from_always(const struct io_plan *plan)
121 {
122         int pos;
123
124         if (plan->status != IO_ALWAYS)
125                 return;
126
127         pos = find_always(plan);
128         assert(pos >= 0);
129
130         /* Move last one down if we made a hole */
131         if (pos != num_always-1)
132                 always[pos] = always[num_always-1];
133         num_always--;
134 }
135
136 bool backend_new_always(struct io_plan *plan)
137 {
138         assert(find_always(plan) == -1);
139
140         if (!max_always) {
141                 assert(num_always == 0);
142                 always = tal_arr(NULL, struct io_plan *, 8);
143                 if (!always)
144                         return false;
145                 max_always = 8;
146         }
147
148         if (num_always + 1 > max_always) {
149                 size_t num = max_always * 2;
150
151                 if (!tal_resize(&always, num))
152                         return false;
153                 max_always = num;
154         }
155
156         always[num_always++] = plan;
157         return true;
158 }
159
160 void backend_new_plan(struct io_conn *conn)
161 {
162         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
163
164         if (pfd->events)
165                 num_waiting--;
166
167         pfd->events = 0;
168         if (conn->plan[IO_IN].status == IO_POLLING_NOTSTARTED
169             || conn->plan[IO_IN].status == IO_POLLING_STARTED)
170                 pfd->events |= POLLIN;
171         if (conn->plan[IO_OUT].status == IO_POLLING_NOTSTARTED
172             || conn->plan[IO_OUT].status == IO_POLLING_STARTED)
173                 pfd->events |= POLLOUT;
174
175         if (pfd->events) {
176                 num_waiting++;
177                 pfd->fd = conn->fd.fd;
178         } else {
179                 pfd->fd = -conn->fd.fd;
180         }
181 }
182
183 void backend_wake(const void *wait)
184 {
185         unsigned int i;
186
187         for (i = 0; i < num_fds; i++) {
188                 struct io_conn *c;
189
190                 /* Ignore listeners */
191                 if (fds[i]->listener)
192                         continue;
193
194                 c = (void *)fds[i];
195                 if (c->plan[IO_IN].status == IO_WAITING
196                     && c->plan[IO_IN].arg.u1.const_vp == wait)
197                         io_do_wakeup(c, IO_IN);
198
199                 if (c->plan[IO_OUT].status == IO_WAITING
200                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
201                         io_do_wakeup(c, IO_OUT);
202         }
203 }
204
205 static void destroy_conn(struct io_conn *conn, bool close_fd)
206 {
207         int saved_errno = errno;
208
209         if (close_fd)
210                 close(conn->fd.fd);
211         del_fd(&conn->fd);
212
213         remove_from_always(&conn->plan[IO_IN]);
214         remove_from_always(&conn->plan[IO_OUT]);
215
216         /* errno saved/restored by tal_free itself. */
217         if (conn->finish) {
218                 errno = saved_errno;
219                 conn->finish(conn, conn->finish_arg);
220         }
221 }
222
223 static void destroy_conn_close_fd(struct io_conn *conn)
224 {
225         destroy_conn(conn, true);
226 }
227
228 bool add_conn(struct io_conn *c)
229 {
230         if (!add_fd(&c->fd, 0))
231                 return false;
232         tal_add_destructor(c, destroy_conn_close_fd);
233         return true;
234 }
235
236 void cleanup_conn_without_close(struct io_conn *conn)
237 {
238         tal_del_destructor(conn, destroy_conn_close_fd);
239         destroy_conn(conn, false);
240 }
241
242 static void accept_conn(struct io_listener *l)
243 {
244         int fd = accept(l->fd.fd, NULL, NULL);
245
246         /* FIXME: What to do here? */
247         if (fd < 0)
248                 return;
249
250         io_new_conn(l->ctx, fd, l->init, l->arg);
251 }
252
253 static bool handle_always(void)
254 {
255         bool ret = false;
256
257         while (num_always > 0) {
258                 /* Remove first: it might re-add */
259                 struct io_plan *plan = always[num_always-1];
260                 num_always--;
261                 io_do_always(plan);
262                 ret = true;
263         }
264         return ret;
265 }
266
267 /* This is the main loop. */
268 void *io_loop(struct timers *timers, struct timer **expired)
269 {
270         void *ret;
271
272         /* if timers is NULL, expired must be.  If not, not. */
273         assert(!timers == !expired);
274
275         /* Make sure this is NULL if we exit for some other reason. */
276         if (expired)
277                 *expired = NULL;
278
279         while (!io_loop_return) {
280                 int i, r, ms_timeout = -1;
281
282                 if (handle_always()) {
283                         /* Could have started/finished more. */
284                         continue;
285                 }
286
287                 /* Everything closed? */
288                 if (num_fds == 0)
289                         break;
290
291                 /* You can't tell them all to go to sleep! */
292                 assert(num_waiting);
293
294                 if (timers) {
295                         struct timemono now, first;
296
297                         now = nowfn();
298
299                         /* Call functions for expired timers. */
300                         *expired = timers_expire(timers, now);
301                         if (*expired)
302                                 break;
303
304                         /* Now figure out how long to wait for the next one. */
305                         if (timer_earliest(timers, &first)) {
306                                 uint64_t next;
307                                 next = time_to_msec(timemono_between(first, now));
308                                 if (next < INT_MAX)
309                                         ms_timeout = next;
310                                 else
311                                         ms_timeout = INT_MAX;
312                         }
313                 }
314
315                 r = pollfn(pollfds, num_fds, ms_timeout);
316                 if (r < 0) {
317                         /* Signals shouldn't break us, unless they set
318                          * io_loop_return. */
319                         if (errno == EINTR)
320                                 continue;
321                         break;
322                 }
323
324                 for (i = 0; i < num_fds && !io_loop_return; i++) {
325                         struct io_conn *c = (void *)fds[i];
326                         int events = pollfds[i].revents;
327
328                         if (r == 0)
329                                 break;
330
331                         if (fds[i]->listener) {
332                                 struct io_listener *l = (void *)fds[i];
333                                 if (events & POLLIN) {
334                                         accept_conn(l);
335                                         r--;
336                                 } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
337                                         r--;
338                                         errno = EBADF;
339                                         io_close_listener(l);
340                                 }
341                         } else if (events & (POLLIN|POLLOUT)) {
342                                 r--;
343                                 io_ready(c, events);
344                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
345                                 r--;
346                                 errno = EBADF;
347                                 io_close(c);
348                         }
349                 }
350         }
351
352         ret = io_loop_return;
353         io_loop_return = NULL;
354
355         return ret;
356 }