]> git.ozlabs.org Git - ccan/blob - ccan/io/poll.c
fdpass: fix complilation on FreeBSD.
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0, num_always = 0, max_always = 0, num_exclusive = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static struct io_plan **always = NULL;
18 static struct timemono (*nowfn)(void) = time_mono;
19 static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll;
20
21 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
22 {
23         struct timemono (*old)(void) = nowfn;
24         nowfn = now;
25         return old;
26 }
27
28 int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int)
29 {
30         int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn;
31         pollfn = poll;
32         return old;
33 }
34
35 static bool add_fd(struct fd *fd, short events)
36 {
37         if (!max_fds) {
38                 assert(num_fds == 0);
39                 pollfds = tal_arr(NULL, struct pollfd, 8);
40                 if (!pollfds)
41                         return false;
42                 fds = tal_arr(pollfds, struct fd *, 8);
43                 if (!fds)
44                         return false;
45                 max_fds = 8;
46         }
47
48         if (num_fds + 1 > max_fds) {
49                 size_t num = max_fds * 2;
50
51                 if (!tal_resize(&pollfds, num))
52                         return false;
53                 if (!tal_resize(&fds, num))
54                         return false;
55                 max_fds = num;
56         }
57
58         pollfds[num_fds].events = events;
59         /* In case it's idle. */
60         if (!events)
61                 pollfds[num_fds].fd = -fd->fd - 1;
62         else
63                 pollfds[num_fds].fd = fd->fd;
64         pollfds[num_fds].revents = 0; /* In case we're iterating now */
65         fds[num_fds] = fd;
66         fd->backend_info = num_fds;
67         fd->exclusive[0] = fd->exclusive[1] = false;
68         num_fds++;
69         if (events)
70                 num_waiting++;
71
72         return true;
73 }
74
75 static void del_fd(struct fd *fd)
76 {
77         size_t n = fd->backend_info;
78
79         assert(n != -1);
80         assert(n < num_fds);
81         if (pollfds[n].events)
82                 num_waiting--;
83         if (n != num_fds - 1) {
84                 /* Move last one over us. */
85                 pollfds[n] = pollfds[num_fds-1];
86                 fds[n] = fds[num_fds-1];
87                 assert(fds[n]->backend_info == num_fds-1);
88                 fds[n]->backend_info = n;
89         } else if (num_fds == 1) {
90                 /* Free everything when no more fds. */
91                 pollfds = tal_free(pollfds);
92                 fds = NULL;
93                 max_fds = 0;
94                 if (num_always == 0) {
95                         always = tal_free(always);
96                         max_always = 0;
97                 }
98         }
99         num_fds--;
100         fd->backend_info = -1;
101
102         if (fd->exclusive[IO_IN])
103                 num_exclusive--;
104         if (fd->exclusive[IO_OUT])
105                 num_exclusive--;
106 }
107
108 static void destroy_listener(struct io_listener *l)
109 {
110         close(l->fd.fd);
111         del_fd(&l->fd);
112 }
113
114 bool add_listener(struct io_listener *l)
115 {
116         if (!add_fd(&l->fd, POLLIN))
117                 return false;
118         tal_add_destructor(l, destroy_listener);
119         return true;
120 }
121
122 static int find_always(const struct io_plan *plan)
123 {
124         size_t i = 0;
125
126         for (i = 0; i < num_always; i++)
127                 if (always[i] == plan)
128                         return i;
129         return -1;
130 }
131
132 static void remove_from_always(const struct io_plan *plan)
133 {
134         int pos;
135
136         if (plan->status != IO_ALWAYS)
137                 return;
138
139         pos = find_always(plan);
140         assert(pos >= 0);
141
142         /* Move last one down if we made a hole */
143         if (pos != num_always-1)
144                 always[pos] = always[num_always-1];
145         num_always--;
146
147         /* Only free if no fds left either. */
148         if (num_always == 0 && max_fds == 0) {
149                 always = tal_free(always);
150                 max_always = 0;
151         }
152 }
153
154 bool backend_new_always(struct io_plan *plan)
155 {
156         assert(find_always(plan) == -1);
157
158         if (!max_always) {
159                 assert(num_always == 0);
160                 always = tal_arr(NULL, struct io_plan *, 8);
161                 if (!always)
162                         return false;
163                 max_always = 8;
164         }
165
166         if (num_always + 1 > max_always) {
167                 size_t num = max_always * 2;
168
169                 if (!tal_resize(&always, num))
170                         return false;
171                 max_always = num;
172         }
173
174         always[num_always++] = plan;
175         return true;
176 }
177
178 static void setup_pfd(struct io_conn *conn, struct pollfd *pfd)
179 {
180         assert(pfd == &pollfds[conn->fd.backend_info]);
181
182         pfd->events = 0;
183         if (conn->plan[IO_IN].status == IO_POLLING_NOTSTARTED
184             || conn->plan[IO_IN].status == IO_POLLING_STARTED)
185                 pfd->events |= POLLIN;
186         if (conn->plan[IO_OUT].status == IO_POLLING_NOTSTARTED
187             || conn->plan[IO_OUT].status == IO_POLLING_STARTED)
188                 pfd->events |= POLLOUT;
189
190         if (pfd->events) {
191                 pfd->fd = conn->fd.fd;
192         } else {
193                 pfd->fd = -conn->fd.fd - 1;
194         }
195 }
196
197 void backend_new_plan(struct io_conn *conn)
198 {
199         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
200
201         if (pfd->events)
202                 num_waiting--;
203
204         setup_pfd(conn, pfd);
205
206         if (pfd->events)
207                 num_waiting++;
208 }
209
210 void backend_wake(const void *wait)
211 {
212         unsigned int i;
213
214         for (i = 0; i < num_fds; i++) {
215                 struct io_conn *c;
216
217                 /* Ignore listeners */
218                 if (fds[i]->listener)
219                         continue;
220
221                 c = (void *)fds[i];
222                 if (c->plan[IO_IN].status == IO_WAITING
223                     && c->plan[IO_IN].arg.u1.const_vp == wait)
224                         io_do_wakeup(c, IO_IN);
225
226                 if (c->plan[IO_OUT].status == IO_WAITING
227                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
228                         io_do_wakeup(c, IO_OUT);
229         }
230 }
231
232 static void destroy_conn(struct io_conn *conn, bool close_fd)
233 {
234         int saved_errno = errno;
235
236         if (close_fd)
237                 close(conn->fd.fd);
238         del_fd(&conn->fd);
239
240         remove_from_always(&conn->plan[IO_IN]);
241         remove_from_always(&conn->plan[IO_OUT]);
242
243         /* errno saved/restored by tal_free itself. */
244         if (conn->finish) {
245                 errno = saved_errno;
246                 conn->finish(conn, conn->finish_arg);
247         }
248 }
249
250 static void destroy_conn_close_fd(struct io_conn *conn)
251 {
252         destroy_conn(conn, true);
253 }
254
255 bool add_conn(struct io_conn *c)
256 {
257         if (!add_fd(&c->fd, 0))
258                 return false;
259         tal_add_destructor(c, destroy_conn_close_fd);
260         return true;
261 }
262
263 void cleanup_conn_without_close(struct io_conn *conn)
264 {
265         tal_del_destructor(conn, destroy_conn_close_fd);
266         destroy_conn(conn, false);
267 }
268
269 static void accept_conn(struct io_listener *l)
270 {
271         int fd = accept(l->fd.fd, NULL, NULL);
272
273         if (fd < 0) {
274                 /* If they've enabled it, this is how we tell them */
275                 if (io_get_extended_errors())
276                         l->init(NULL, l->arg);
277                 return;
278         }
279         io_new_conn(l->ctx, fd, l->init, l->arg);
280 }
281
282 /* Return pointer to exclusive flag for this plan. */
283 static bool *exclusive(struct io_plan *plan)
284 {
285         struct io_conn *conn;
286
287         conn = container_of(plan, struct io_conn, plan[plan->dir]);
288         return &conn->fd.exclusive[plan->dir];
289 }
290
291 /* For simplicity, we do one always at a time */
292 static bool handle_always(void)
293 {
294         int i;
295
296         /* Backwards is simple easier to remove entries */
297         for (i = num_always - 1; i >= 0; i--) {
298                 struct io_plan *plan = always[i];
299
300                 if (num_exclusive && !*exclusive(plan))
301                         continue;
302                 /* Remove first: it might re-add */
303                 if (i != num_always-1)
304                         always[i] = always[num_always-1];
305                 num_always--;
306                 io_do_always(plan);
307                 return true;
308         }
309
310         return false;
311 }
312
313 bool backend_set_exclusive(struct io_plan *plan, bool excl)
314 {
315         bool *excl_ptr = exclusive(plan);
316
317         if (excl != *excl_ptr) {
318                 *excl_ptr = excl;
319                 if (!excl)
320                         num_exclusive--;
321                 else
322                         num_exclusive++;
323         }
324
325         return num_exclusive != 0;
326 }
327
328 /* FIXME: We could do this once at set_exclusive time, and catch everywhere
329  * else that we manipulate events. */
330 static void exclude_pollfds(void)
331 {
332         size_t i;
333
334         if (num_exclusive == 0)
335                 return;
336
337         for (i = 0; i < num_fds; i++) {
338                 struct pollfd *pfd = &pollfds[fds[i]->backend_info];
339
340                 if (!fds[i]->exclusive[IO_IN])
341                         pfd->events &= ~POLLIN;
342                 if (!fds[i]->exclusive[IO_OUT])
343                         pfd->events &= ~POLLOUT;
344
345                 /* If we're not listening, we don't want error events
346                  * either. */
347                 if (!pfd->events)
348                         pfd->fd = -fds[i]->fd - 1;
349         }
350 }
351
352 static void restore_pollfds(void)
353 {
354         size_t i;
355
356         if (num_exclusive == 0)
357                 return;
358
359         for (i = 0; i < num_fds; i++) {
360                 struct pollfd *pfd = &pollfds[fds[i]->backend_info];
361
362                 if (fds[i]->listener) {
363                         pfd->events = POLLIN;
364                         pfd->fd = fds[i]->fd;
365                 } else {
366                         struct io_conn *conn = (void *)fds[i];
367                         setup_pfd(conn, pfd);
368                 }
369         }
370 }
371
372 /* This is the main loop. */
373 void *io_loop(struct timers *timers, struct timer **expired)
374 {
375         void *ret;
376
377         /* if timers is NULL, expired must be.  If not, not. */
378         assert(!timers == !expired);
379
380         /* Make sure this is NULL if we exit for some other reason. */
381         if (expired)
382                 *expired = NULL;
383
384         while (!io_loop_return) {
385                 int i, r, ms_timeout = -1;
386
387                 if (handle_always()) {
388                         /* Could have started/finished more. */
389                         continue;
390                 }
391
392                 /* Everything closed? */
393                 if (num_fds == 0)
394                         break;
395
396                 /* You can't tell them all to go to sleep! */
397                 assert(num_waiting);
398
399                 if (timers) {
400                         struct timemono now, first;
401
402                         now = nowfn();
403
404                         /* Call functions for expired timers. */
405                         *expired = timers_expire(timers, now);
406                         if (*expired)
407                                 break;
408
409                         /* Now figure out how long to wait for the next one. */
410                         if (timer_earliest(timers, &first)) {
411                                 uint64_t next;
412                                 next = time_to_msec(timemono_between(first, now));
413                                 if (next < INT_MAX)
414                                         ms_timeout = next;
415                                 else
416                                         ms_timeout = INT_MAX;
417                         }
418                 }
419
420                 /* We do this temporarily, assuming exclusive is unusual */
421                 exclude_pollfds();
422                 r = pollfn(pollfds, num_fds, ms_timeout);
423                 restore_pollfds();
424
425                 if (r < 0) {
426                         /* Signals shouldn't break us, unless they set
427                          * io_loop_return. */
428                         if (errno == EINTR)
429                                 continue;
430                         break;
431                 }
432
433                 for (i = 0; i < num_fds && !io_loop_return; i++) {
434                         struct io_conn *c = (void *)fds[i];
435                         int events = pollfds[i].revents;
436
437                         /* Clear so we don't get confused if exclusive next time */
438                         pollfds[i].revents = 0;
439
440                         if (r == 0)
441                                 break;
442
443                         if (fds[i]->listener) {
444                                 struct io_listener *l = (void *)fds[i];
445                                 if (events & POLLIN) {
446                                         accept_conn(l);
447                                         r--;
448                                 } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
449                                         r--;
450                                         errno = EBADF;
451                                         io_close_listener(l);
452                                 }
453                         } else if (events & (POLLIN|POLLOUT)) {
454                                 r--;
455                                 io_ready(c, events);
456                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
457                                 r--;
458                                 errno = EBADF;
459                                 io_close(c);
460                         }
461                 }
462         }
463
464         ret = io_loop_return;
465         io_loop_return = NULL;
466
467         return ret;
468 }
469
470 const void *io_have_fd(int fd, bool *listener)
471 {
472         for (size_t i = 0; i < num_fds; i++) {
473                 if (fds[i]->fd != fd)
474                         continue;
475                 if (listener)
476                         *listener = fds[i]->listener;
477                 return fds[i];
478         }
479         return NULL;
480 }