io: don't leak memory on clean shutdown.
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0, num_always = 0, max_always = 0, num_exclusive = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static struct io_plan **always = NULL;
18 static struct timemono (*nowfn)(void) = time_mono;
19 static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll;
20
21 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
22 {
23         struct timemono (*old)(void) = nowfn;
24         nowfn = now;
25         return old;
26 }
27
28 int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int)
29 {
30         int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn;
31         pollfn = poll;
32         return old;
33 }
34
35 static bool add_fd(struct fd *fd, short events)
36 {
37         if (!max_fds) {
38                 assert(num_fds == 0);
39                 pollfds = tal_arr(NULL, struct pollfd, 8);
40                 if (!pollfds)
41                         return false;
42                 fds = tal_arr(pollfds, struct fd *, 8);
43                 if (!fds)
44                         return false;
45                 max_fds = 8;
46         }
47
48         if (num_fds + 1 > max_fds) {
49                 size_t num = max_fds * 2;
50
51                 if (!tal_resize(&pollfds, num))
52                         return false;
53                 if (!tal_resize(&fds, num))
54                         return false;
55                 max_fds = num;
56         }
57
58         pollfds[num_fds].events = events;
59         /* In case it's idle. */
60         if (!events)
61                 pollfds[num_fds].fd = -fd->fd - 1;
62         else
63                 pollfds[num_fds].fd = fd->fd;
64         pollfds[num_fds].revents = 0; /* In case we're iterating now */
65         fds[num_fds] = fd;
66         fd->backend_info = num_fds;
67         fd->exclusive[0] = fd->exclusive[1] = false;
68         num_fds++;
69         if (events)
70                 num_waiting++;
71
72         return true;
73 }
74
75 static void del_fd(struct fd *fd)
76 {
77         size_t n = fd->backend_info;
78
79         assert(n != -1);
80         assert(n < num_fds);
81         if (pollfds[n].events)
82                 num_waiting--;
83         if (n != num_fds - 1) {
84                 /* Move last one over us. */
85                 pollfds[n] = pollfds[num_fds-1];
86                 fds[n] = fds[num_fds-1];
87                 assert(fds[n]->backend_info == num_fds-1);
88                 fds[n]->backend_info = n;
89         } else if (num_fds == 1) {
90                 /* Free everything when no more fds. */
91                 pollfds = tal_free(pollfds);
92                 fds = NULL;
93                 max_fds = 0;
94         }
95         num_fds--;
96         fd->backend_info = -1;
97
98         if (fd->exclusive[IO_IN])
99                 num_exclusive--;
100         if (fd->exclusive[IO_OUT])
101                 num_exclusive--;
102 }
103
104 static void destroy_listener(struct io_listener *l)
105 {
106         close(l->fd.fd);
107         del_fd(&l->fd);
108 }
109
110 bool add_listener(struct io_listener *l)
111 {
112         if (!add_fd(&l->fd, POLLIN))
113                 return false;
114         tal_add_destructor(l, destroy_listener);
115         return true;
116 }
117
118 static int find_always(const struct io_plan *plan)
119 {
120         for (size_t i = 0; i < num_always; i++)
121                 if (always[i] == plan)
122                         return i;
123         return -1;
124 }
125
126 static void remove_from_always(const struct io_plan *plan)
127 {
128         int pos;
129
130         if (plan->status != IO_ALWAYS)
131                 return;
132
133         pos = find_always(plan);
134         assert(pos >= 0);
135
136         /* Move last one down if we made a hole */
137         if (pos != num_always-1)
138                 always[pos] = always[num_always-1];
139         num_always--;
140
141         /* Only free if no fds left either. */
142         if (num_always == 0 && max_fds == 0) {
143                 tal_free(always);
144                 max_always = 0;
145         }
146 }
147
148 bool backend_new_always(struct io_plan *plan)
149 {
150         assert(find_always(plan) == -1);
151
152         if (!max_always) {
153                 assert(num_always == 0);
154                 always = tal_arr(NULL, struct io_plan *, 8);
155                 if (!always)
156                         return false;
157                 max_always = 8;
158         }
159
160         if (num_always + 1 > max_always) {
161                 size_t num = max_always * 2;
162
163                 if (!tal_resize(&always, num))
164                         return false;
165                 max_always = num;
166         }
167
168         always[num_always++] = plan;
169         return true;
170 }
171
172 static void setup_pfd(struct io_conn *conn, struct pollfd *pfd)
173 {
174         assert(pfd == &pollfds[conn->fd.backend_info]);
175
176         pfd->events = 0;
177         if (conn->plan[IO_IN].status == IO_POLLING_NOTSTARTED
178             || conn->plan[IO_IN].status == IO_POLLING_STARTED)
179                 pfd->events |= POLLIN;
180         if (conn->plan[IO_OUT].status == IO_POLLING_NOTSTARTED
181             || conn->plan[IO_OUT].status == IO_POLLING_STARTED)
182                 pfd->events |= POLLOUT;
183
184         if (pfd->events) {
185                 pfd->fd = conn->fd.fd;
186         } else {
187                 pfd->fd = -conn->fd.fd - 1;
188         }
189 }
190
191 void backend_new_plan(struct io_conn *conn)
192 {
193         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
194
195         if (pfd->events)
196                 num_waiting--;
197
198         setup_pfd(conn, pfd);
199
200         if (pfd->events)
201                 num_waiting++;
202 }
203
204 void backend_wake(const void *wait)
205 {
206         unsigned int i;
207
208         for (i = 0; i < num_fds; i++) {
209                 struct io_conn *c;
210
211                 /* Ignore listeners */
212                 if (fds[i]->listener)
213                         continue;
214
215                 c = (void *)fds[i];
216                 if (c->plan[IO_IN].status == IO_WAITING
217                     && c->plan[IO_IN].arg.u1.const_vp == wait)
218                         io_do_wakeup(c, IO_IN);
219
220                 if (c->plan[IO_OUT].status == IO_WAITING
221                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
222                         io_do_wakeup(c, IO_OUT);
223         }
224 }
225
226 static void destroy_conn(struct io_conn *conn, bool close_fd)
227 {
228         int saved_errno = errno;
229
230         if (close_fd)
231                 close(conn->fd.fd);
232         del_fd(&conn->fd);
233
234         remove_from_always(&conn->plan[IO_IN]);
235         remove_from_always(&conn->plan[IO_OUT]);
236
237         /* errno saved/restored by tal_free itself. */
238         if (conn->finish) {
239                 errno = saved_errno;
240                 conn->finish(conn, conn->finish_arg);
241         }
242 }
243
244 static void destroy_conn_close_fd(struct io_conn *conn)
245 {
246         destroy_conn(conn, true);
247 }
248
249 bool add_conn(struct io_conn *c)
250 {
251         if (!add_fd(&c->fd, 0))
252                 return false;
253         tal_add_destructor(c, destroy_conn_close_fd);
254         return true;
255 }
256
257 void cleanup_conn_without_close(struct io_conn *conn)
258 {
259         tal_del_destructor(conn, destroy_conn_close_fd);
260         destroy_conn(conn, false);
261 }
262
263 static void accept_conn(struct io_listener *l)
264 {
265         int fd = accept(l->fd.fd, NULL, NULL);
266
267         /* FIXME: What to do here? */
268         if (fd < 0)
269                 return;
270
271         io_new_conn(l->ctx, fd, l->init, l->arg);
272 }
273
274 /* Return pointer to exclusive flag for this plan. */
275 static bool *exclusive(struct io_plan *plan)
276 {
277         struct io_conn *conn;
278
279         conn = container_of(plan, struct io_conn, plan[plan->dir]);
280         return &conn->fd.exclusive[plan->dir];
281 }
282
283 /* For simplicity, we do one always at a time */
284 static bool handle_always(void)
285 {
286         /* Backwards is simple easier to remove entries */
287         for (int i = num_always - 1; i >= 0; i--) {
288                 struct io_plan *plan = always[i];
289
290                 if (num_exclusive && !*exclusive(plan))
291                         continue;
292                 /* Remove first: it might re-add */
293                 if (i != num_always-1)
294                         always[i] = always[num_always-1];
295                 num_always--;
296                 io_do_always(plan);
297                 return true;
298         }
299
300         return false;
301 }
302
303 bool backend_set_exclusive(struct io_plan *plan, bool excl)
304 {
305         bool *excl_ptr = exclusive(plan);
306
307         if (excl != *excl_ptr) {
308                 *excl_ptr = excl;
309                 if (!excl)
310                         num_exclusive--;
311                 else
312                         num_exclusive++;
313         }
314
315         return num_exclusive != 0;
316 }
317
318 /* FIXME: We could do this once at set_exclusive time, and catch everywhere
319  * else that we manipulate events. */
320 static void exclude_pollfds(void)
321 {
322         if (num_exclusive == 0)
323                 return;
324
325         for (size_t i = 0; i < num_fds; i++) {
326                 struct pollfd *pfd = &pollfds[fds[i]->backend_info];
327
328                 if (!fds[i]->exclusive[IO_IN])
329                         pfd->events &= ~POLLIN;
330                 if (!fds[i]->exclusive[IO_OUT])
331                         pfd->events &= ~POLLOUT;
332
333                 /* If we're not listening, we don't want error events
334                  * either. */
335                 if (!pfd->events)
336                         pfd->fd = -fds[i]->fd - 1;
337         }
338 }
339
340 static void restore_pollfds(void)
341 {
342         if (num_exclusive == 0)
343                 return;
344
345         for (size_t i = 0; i < num_fds; i++) {
346                 struct pollfd *pfd = &pollfds[fds[i]->backend_info];
347
348                 if (fds[i]->listener) {
349                         pfd->events = POLLIN;
350                         pfd->fd = fds[i]->fd;
351                 } else {
352                         struct io_conn *conn = (void *)fds[i];
353                         setup_pfd(conn, pfd);
354                 }
355         }
356 }
357
358 /* This is the main loop. */
359 void *io_loop(struct timers *timers, struct timer **expired)
360 {
361         void *ret;
362
363         /* if timers is NULL, expired must be.  If not, not. */
364         assert(!timers == !expired);
365
366         /* Make sure this is NULL if we exit for some other reason. */
367         if (expired)
368                 *expired = NULL;
369
370         while (!io_loop_return) {
371                 int i, r, ms_timeout = -1;
372
373                 if (handle_always()) {
374                         /* Could have started/finished more. */
375                         continue;
376                 }
377
378                 /* Everything closed? */
379                 if (num_fds == 0)
380                         break;
381
382                 /* You can't tell them all to go to sleep! */
383                 assert(num_waiting);
384
385                 if (timers) {
386                         struct timemono now, first;
387
388                         now = nowfn();
389
390                         /* Call functions for expired timers. */
391                         *expired = timers_expire(timers, now);
392                         if (*expired)
393                                 break;
394
395                         /* Now figure out how long to wait for the next one. */
396                         if (timer_earliest(timers, &first)) {
397                                 uint64_t next;
398                                 next = time_to_msec(timemono_between(first, now));
399                                 if (next < INT_MAX)
400                                         ms_timeout = next;
401                                 else
402                                         ms_timeout = INT_MAX;
403                         }
404                 }
405
406                 /* We do this temporarily, assuming exclusive is unusual */
407                 exclude_pollfds();
408                 r = pollfn(pollfds, num_fds, ms_timeout);
409                 restore_pollfds();
410
411                 if (r < 0) {
412                         /* Signals shouldn't break us, unless they set
413                          * io_loop_return. */
414                         if (errno == EINTR)
415                                 continue;
416                         break;
417                 }
418
419                 for (i = 0; i < num_fds && !io_loop_return; i++) {
420                         struct io_conn *c = (void *)fds[i];
421                         int events = pollfds[i].revents;
422
423                         /* Clear so we don't get confused if exclusive next time */
424                         pollfds[i].revents = 0;
425
426                         if (r == 0)
427                                 break;
428
429                         if (fds[i]->listener) {
430                                 struct io_listener *l = (void *)fds[i];
431                                 if (events & POLLIN) {
432                                         accept_conn(l);
433                                         r--;
434                                 } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
435                                         r--;
436                                         errno = EBADF;
437                                         io_close_listener(l);
438                                 }
439                         } else if (events & (POLLIN|POLLOUT)) {
440                                 r--;
441                                 io_ready(c, events);
442                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
443                                 r--;
444                                 errno = EBADF;
445                                 io_close(c);
446                         }
447                 }
448         }
449
450         ret = io_loop_return;
451         io_loop_return = NULL;
452
453         return ret;
454 }