]> git.ozlabs.org Git - ccan/blob - ccan/io/poll.c
io: fix another leak path for always array.
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0, num_always = 0, max_always = 0, num_exclusive = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static struct io_plan **always = NULL;
18 static struct timemono (*nowfn)(void) = time_mono;
19 static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll;
20
21 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
22 {
23         struct timemono (*old)(void) = nowfn;
24         nowfn = now;
25         return old;
26 }
27
28 int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int)
29 {
30         int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn;
31         pollfn = poll;
32         return old;
33 }
34
35 static bool add_fd(struct fd *fd, short events)
36 {
37         if (!max_fds) {
38                 assert(num_fds == 0);
39                 pollfds = tal_arr(NULL, struct pollfd, 8);
40                 if (!pollfds)
41                         return false;
42                 fds = tal_arr(pollfds, struct fd *, 8);
43                 if (!fds)
44                         return false;
45                 max_fds = 8;
46         }
47
48         if (num_fds + 1 > max_fds) {
49                 size_t num = max_fds * 2;
50
51                 if (!tal_resize(&pollfds, num))
52                         return false;
53                 if (!tal_resize(&fds, num))
54                         return false;
55                 max_fds = num;
56         }
57
58         pollfds[num_fds].events = events;
59         /* In case it's idle. */
60         if (!events)
61                 pollfds[num_fds].fd = -fd->fd - 1;
62         else
63                 pollfds[num_fds].fd = fd->fd;
64         pollfds[num_fds].revents = 0; /* In case we're iterating now */
65         fds[num_fds] = fd;
66         fd->backend_info = num_fds;
67         fd->exclusive[0] = fd->exclusive[1] = false;
68         num_fds++;
69         if (events)
70                 num_waiting++;
71
72         return true;
73 }
74
75 static void del_fd(struct fd *fd)
76 {
77         size_t n = fd->backend_info;
78
79         assert(n != -1);
80         assert(n < num_fds);
81         if (pollfds[n].events)
82                 num_waiting--;
83         if (n != num_fds - 1) {
84                 /* Move last one over us. */
85                 pollfds[n] = pollfds[num_fds-1];
86                 fds[n] = fds[num_fds-1];
87                 assert(fds[n]->backend_info == num_fds-1);
88                 fds[n]->backend_info = n;
89         } else if (num_fds == 1) {
90                 /* Free everything when no more fds. */
91                 pollfds = tal_free(pollfds);
92                 fds = NULL;
93                 max_fds = 0;
94                 if (num_always == 0) {
95                         always = tal_free(always);
96                         max_always = 0;
97                 }
98         }
99         num_fds--;
100         fd->backend_info = -1;
101
102         if (fd->exclusive[IO_IN])
103                 num_exclusive--;
104         if (fd->exclusive[IO_OUT])
105                 num_exclusive--;
106 }
107
108 static void destroy_listener(struct io_listener *l)
109 {
110         close(l->fd.fd);
111         del_fd(&l->fd);
112 }
113
114 bool add_listener(struct io_listener *l)
115 {
116         if (!add_fd(&l->fd, POLLIN))
117                 return false;
118         tal_add_destructor(l, destroy_listener);
119         return true;
120 }
121
122 static int find_always(const struct io_plan *plan)
123 {
124         for (size_t i = 0; i < num_always; i++)
125                 if (always[i] == plan)
126                         return i;
127         return -1;
128 }
129
130 static void remove_from_always(const struct io_plan *plan)
131 {
132         int pos;
133
134         if (plan->status != IO_ALWAYS)
135                 return;
136
137         pos = find_always(plan);
138         assert(pos >= 0);
139
140         /* Move last one down if we made a hole */
141         if (pos != num_always-1)
142                 always[pos] = always[num_always-1];
143         num_always--;
144
145         /* Only free if no fds left either. */
146         if (num_always == 0 && max_fds == 0) {
147                 always = tal_free(always);
148                 max_always = 0;
149         }
150 }
151
152 bool backend_new_always(struct io_plan *plan)
153 {
154         assert(find_always(plan) == -1);
155
156         if (!max_always) {
157                 assert(num_always == 0);
158                 always = tal_arr(NULL, struct io_plan *, 8);
159                 if (!always)
160                         return false;
161                 max_always = 8;
162         }
163
164         if (num_always + 1 > max_always) {
165                 size_t num = max_always * 2;
166
167                 if (!tal_resize(&always, num))
168                         return false;
169                 max_always = num;
170         }
171
172         always[num_always++] = plan;
173         return true;
174 }
175
176 static void setup_pfd(struct io_conn *conn, struct pollfd *pfd)
177 {
178         assert(pfd == &pollfds[conn->fd.backend_info]);
179
180         pfd->events = 0;
181         if (conn->plan[IO_IN].status == IO_POLLING_NOTSTARTED
182             || conn->plan[IO_IN].status == IO_POLLING_STARTED)
183                 pfd->events |= POLLIN;
184         if (conn->plan[IO_OUT].status == IO_POLLING_NOTSTARTED
185             || conn->plan[IO_OUT].status == IO_POLLING_STARTED)
186                 pfd->events |= POLLOUT;
187
188         if (pfd->events) {
189                 pfd->fd = conn->fd.fd;
190         } else {
191                 pfd->fd = -conn->fd.fd - 1;
192         }
193 }
194
195 void backend_new_plan(struct io_conn *conn)
196 {
197         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
198
199         if (pfd->events)
200                 num_waiting--;
201
202         setup_pfd(conn, pfd);
203
204         if (pfd->events)
205                 num_waiting++;
206 }
207
208 void backend_wake(const void *wait)
209 {
210         unsigned int i;
211
212         for (i = 0; i < num_fds; i++) {
213                 struct io_conn *c;
214
215                 /* Ignore listeners */
216                 if (fds[i]->listener)
217                         continue;
218
219                 c = (void *)fds[i];
220                 if (c->plan[IO_IN].status == IO_WAITING
221                     && c->plan[IO_IN].arg.u1.const_vp == wait)
222                         io_do_wakeup(c, IO_IN);
223
224                 if (c->plan[IO_OUT].status == IO_WAITING
225                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
226                         io_do_wakeup(c, IO_OUT);
227         }
228 }
229
230 static void destroy_conn(struct io_conn *conn, bool close_fd)
231 {
232         int saved_errno = errno;
233
234         if (close_fd)
235                 close(conn->fd.fd);
236         del_fd(&conn->fd);
237
238         remove_from_always(&conn->plan[IO_IN]);
239         remove_from_always(&conn->plan[IO_OUT]);
240
241         /* errno saved/restored by tal_free itself. */
242         if (conn->finish) {
243                 errno = saved_errno;
244                 conn->finish(conn, conn->finish_arg);
245         }
246 }
247
248 static void destroy_conn_close_fd(struct io_conn *conn)
249 {
250         destroy_conn(conn, true);
251 }
252
253 bool add_conn(struct io_conn *c)
254 {
255         if (!add_fd(&c->fd, 0))
256                 return false;
257         tal_add_destructor(c, destroy_conn_close_fd);
258         return true;
259 }
260
261 void cleanup_conn_without_close(struct io_conn *conn)
262 {
263         tal_del_destructor(conn, destroy_conn_close_fd);
264         destroy_conn(conn, false);
265 }
266
267 static void accept_conn(struct io_listener *l)
268 {
269         int fd = accept(l->fd.fd, NULL, NULL);
270
271         /* FIXME: What to do here? */
272         if (fd < 0)
273                 return;
274
275         io_new_conn(l->ctx, fd, l->init, l->arg);
276 }
277
278 /* Return pointer to exclusive flag for this plan. */
279 static bool *exclusive(struct io_plan *plan)
280 {
281         struct io_conn *conn;
282
283         conn = container_of(plan, struct io_conn, plan[plan->dir]);
284         return &conn->fd.exclusive[plan->dir];
285 }
286
287 /* For simplicity, we do one always at a time */
288 static bool handle_always(void)
289 {
290         /* Backwards is simple easier to remove entries */
291         for (int i = num_always - 1; i >= 0; i--) {
292                 struct io_plan *plan = always[i];
293
294                 if (num_exclusive && !*exclusive(plan))
295                         continue;
296                 /* Remove first: it might re-add */
297                 if (i != num_always-1)
298                         always[i] = always[num_always-1];
299                 num_always--;
300                 io_do_always(plan);
301                 return true;
302         }
303
304         return false;
305 }
306
307 bool backend_set_exclusive(struct io_plan *plan, bool excl)
308 {
309         bool *excl_ptr = exclusive(plan);
310
311         if (excl != *excl_ptr) {
312                 *excl_ptr = excl;
313                 if (!excl)
314                         num_exclusive--;
315                 else
316                         num_exclusive++;
317         }
318
319         return num_exclusive != 0;
320 }
321
322 /* FIXME: We could do this once at set_exclusive time, and catch everywhere
323  * else that we manipulate events. */
324 static void exclude_pollfds(void)
325 {
326         if (num_exclusive == 0)
327                 return;
328
329         for (size_t i = 0; i < num_fds; i++) {
330                 struct pollfd *pfd = &pollfds[fds[i]->backend_info];
331
332                 if (!fds[i]->exclusive[IO_IN])
333                         pfd->events &= ~POLLIN;
334                 if (!fds[i]->exclusive[IO_OUT])
335                         pfd->events &= ~POLLOUT;
336
337                 /* If we're not listening, we don't want error events
338                  * either. */
339                 if (!pfd->events)
340                         pfd->fd = -fds[i]->fd - 1;
341         }
342 }
343
344 static void restore_pollfds(void)
345 {
346         if (num_exclusive == 0)
347                 return;
348
349         for (size_t i = 0; i < num_fds; i++) {
350                 struct pollfd *pfd = &pollfds[fds[i]->backend_info];
351
352                 if (fds[i]->listener) {
353                         pfd->events = POLLIN;
354                         pfd->fd = fds[i]->fd;
355                 } else {
356                         struct io_conn *conn = (void *)fds[i];
357                         setup_pfd(conn, pfd);
358                 }
359         }
360 }
361
362 /* This is the main loop. */
363 void *io_loop(struct timers *timers, struct timer **expired)
364 {
365         void *ret;
366
367         /* if timers is NULL, expired must be.  If not, not. */
368         assert(!timers == !expired);
369
370         /* Make sure this is NULL if we exit for some other reason. */
371         if (expired)
372                 *expired = NULL;
373
374         while (!io_loop_return) {
375                 int i, r, ms_timeout = -1;
376
377                 if (handle_always()) {
378                         /* Could have started/finished more. */
379                         continue;
380                 }
381
382                 /* Everything closed? */
383                 if (num_fds == 0)
384                         break;
385
386                 /* You can't tell them all to go to sleep! */
387                 assert(num_waiting);
388
389                 if (timers) {
390                         struct timemono now, first;
391
392                         now = nowfn();
393
394                         /* Call functions for expired timers. */
395                         *expired = timers_expire(timers, now);
396                         if (*expired)
397                                 break;
398
399                         /* Now figure out how long to wait for the next one. */
400                         if (timer_earliest(timers, &first)) {
401                                 uint64_t next;
402                                 next = time_to_msec(timemono_between(first, now));
403                                 if (next < INT_MAX)
404                                         ms_timeout = next;
405                                 else
406                                         ms_timeout = INT_MAX;
407                         }
408                 }
409
410                 /* We do this temporarily, assuming exclusive is unusual */
411                 exclude_pollfds();
412                 r = pollfn(pollfds, num_fds, ms_timeout);
413                 restore_pollfds();
414
415                 if (r < 0) {
416                         /* Signals shouldn't break us, unless they set
417                          * io_loop_return. */
418                         if (errno == EINTR)
419                                 continue;
420                         break;
421                 }
422
423                 for (i = 0; i < num_fds && !io_loop_return; i++) {
424                         struct io_conn *c = (void *)fds[i];
425                         int events = pollfds[i].revents;
426
427                         /* Clear so we don't get confused if exclusive next time */
428                         pollfds[i].revents = 0;
429
430                         if (r == 0)
431                                 break;
432
433                         if (fds[i]->listener) {
434                                 struct io_listener *l = (void *)fds[i];
435                                 if (events & POLLIN) {
436                                         accept_conn(l);
437                                         r--;
438                                 } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
439                                         r--;
440                                         errno = EBADF;
441                                         io_close_listener(l);
442                                 }
443                         } else if (events & (POLLIN|POLLOUT)) {
444                                 r--;
445                                 io_ready(c, events);
446                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
447                                 r--;
448                                 errno = EBADF;
449                                 io_close(c);
450                         }
451                 }
452         }
453
454         ret = io_loop_return;
455         io_loop_return = NULL;
456
457         return ret;
458 }