]> git.ozlabs.org Git - ccan/blob - ccan/io/poll.c
base64: fix for unsigned chars (e.g. ARM).
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0, num_always = 0, max_always = 0, num_exclusive = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static struct io_plan **always = NULL;
18 static struct timemono (*nowfn)(void) = time_mono;
19 static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll;
20
21 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
22 {
23         struct timemono (*old)(void) = nowfn;
24         nowfn = now;
25         return old;
26 }
27
28 int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int)
29 {
30         int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn;
31         pollfn = poll;
32         return old;
33 }
34
35 static bool add_fd(struct fd *fd, short events)
36 {
37         if (!max_fds) {
38                 assert(num_fds == 0);
39                 pollfds = tal_arr(NULL, struct pollfd, 8);
40                 if (!pollfds)
41                         return false;
42                 fds = tal_arr(pollfds, struct fd *, 8);
43                 if (!fds)
44                         return false;
45                 max_fds = 8;
46         }
47
48         if (num_fds + 1 > max_fds) {
49                 size_t num = max_fds * 2;
50
51                 if (!tal_resize(&pollfds, num))
52                         return false;
53                 if (!tal_resize(&fds, num))
54                         return false;
55                 max_fds = num;
56         }
57
58         pollfds[num_fds].events = events;
59         /* In case it's idle. */
60         if (!events)
61                 pollfds[num_fds].fd = -fd->fd - 1;
62         else
63                 pollfds[num_fds].fd = fd->fd;
64         pollfds[num_fds].revents = 0; /* In case we're iterating now */
65         fds[num_fds] = fd;
66         fd->backend_info = num_fds;
67         fd->exclusive[0] = fd->exclusive[1] = false;
68         num_fds++;
69         if (events)
70                 num_waiting++;
71
72         return true;
73 }
74
75 static void del_fd(struct fd *fd)
76 {
77         size_t n = fd->backend_info;
78
79         assert(n != -1);
80         assert(n < num_fds);
81         if (pollfds[n].events)
82                 num_waiting--;
83         if (n != num_fds - 1) {
84                 /* Move last one over us. */
85                 pollfds[n] = pollfds[num_fds-1];
86                 fds[n] = fds[num_fds-1];
87                 assert(fds[n]->backend_info == num_fds-1);
88                 fds[n]->backend_info = n;
89         } else if (num_fds == 1) {
90                 /* Free everything when no more fds. */
91                 pollfds = tal_free(pollfds);
92                 fds = NULL;
93                 max_fds = 0;
94                 if (num_always == 0) {
95                         always = tal_free(always);
96                         max_always = 0;
97                 }
98         }
99         num_fds--;
100         fd->backend_info = -1;
101
102         if (fd->exclusive[IO_IN])
103                 num_exclusive--;
104         if (fd->exclusive[IO_OUT])
105                 num_exclusive--;
106 }
107
108 static void destroy_listener(struct io_listener *l)
109 {
110         close(l->fd.fd);
111         del_fd(&l->fd);
112 }
113
114 bool add_listener(struct io_listener *l)
115 {
116         if (!add_fd(&l->fd, POLLIN))
117                 return false;
118         tal_add_destructor(l, destroy_listener);
119         return true;
120 }
121
122 static int find_always(const struct io_plan *plan)
123 {
124         size_t i = 0;
125
126         for (i = 0; i < num_always; i++)
127                 if (always[i] == plan)
128                         return i;
129         return -1;
130 }
131
132 static void remove_from_always(const struct io_plan *plan)
133 {
134         int pos;
135
136         if (plan->status != IO_ALWAYS)
137                 return;
138
139         pos = find_always(plan);
140         assert(pos >= 0);
141
142         /* Move last one down if we made a hole */
143         if (pos != num_always-1)
144                 always[pos] = always[num_always-1];
145         num_always--;
146
147         /* Only free if no fds left either. */
148         if (num_always == 0 && max_fds == 0) {
149                 always = tal_free(always);
150                 max_always = 0;
151         }
152 }
153
154 bool backend_new_always(struct io_plan *plan)
155 {
156         assert(find_always(plan) == -1);
157
158         if (!max_always) {
159                 assert(num_always == 0);
160                 always = tal_arr(NULL, struct io_plan *, 8);
161                 if (!always)
162                         return false;
163                 max_always = 8;
164         }
165
166         if (num_always + 1 > max_always) {
167                 size_t num = max_always * 2;
168
169                 if (!tal_resize(&always, num))
170                         return false;
171                 max_always = num;
172         }
173
174         always[num_always++] = plan;
175         return true;
176 }
177
178 static void setup_pfd(struct io_conn *conn, struct pollfd *pfd)
179 {
180         assert(pfd == &pollfds[conn->fd.backend_info]);
181
182         pfd->events = 0;
183         if (conn->plan[IO_IN].status == IO_POLLING_NOTSTARTED
184             || conn->plan[IO_IN].status == IO_POLLING_STARTED)
185                 pfd->events |= POLLIN;
186         if (conn->plan[IO_OUT].status == IO_POLLING_NOTSTARTED
187             || conn->plan[IO_OUT].status == IO_POLLING_STARTED)
188                 pfd->events |= POLLOUT;
189
190         if (pfd->events) {
191                 pfd->fd = conn->fd.fd;
192         } else {
193                 pfd->fd = -conn->fd.fd - 1;
194         }
195 }
196
197 void backend_new_plan(struct io_conn *conn)
198 {
199         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
200
201         if (pfd->events)
202                 num_waiting--;
203
204         setup_pfd(conn, pfd);
205
206         if (pfd->events)
207                 num_waiting++;
208 }
209
210 void backend_wake(const void *wait)
211 {
212         unsigned int i;
213
214         for (i = 0; i < num_fds; i++) {
215                 struct io_conn *c;
216
217                 /* Ignore listeners */
218                 if (fds[i]->listener)
219                         continue;
220
221                 c = (void *)fds[i];
222                 if (c->plan[IO_IN].status == IO_WAITING
223                     && c->plan[IO_IN].arg.u1.const_vp == wait)
224                         io_do_wakeup(c, IO_IN);
225
226                 if (c->plan[IO_OUT].status == IO_WAITING
227                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
228                         io_do_wakeup(c, IO_OUT);
229         }
230 }
231
232 static void destroy_conn(struct io_conn *conn, bool close_fd)
233 {
234         int saved_errno = errno;
235
236         if (close_fd)
237                 close(conn->fd.fd);
238         del_fd(&conn->fd);
239
240         remove_from_always(&conn->plan[IO_IN]);
241         remove_from_always(&conn->plan[IO_OUT]);
242
243         /* errno saved/restored by tal_free itself. */
244         if (conn->finish) {
245                 errno = saved_errno;
246                 conn->finish(conn, conn->finish_arg);
247         }
248 }
249
250 static void destroy_conn_close_fd(struct io_conn *conn)
251 {
252         destroy_conn(conn, true);
253 }
254
255 bool add_conn(struct io_conn *c)
256 {
257         if (!add_fd(&c->fd, 0))
258                 return false;
259         tal_add_destructor(c, destroy_conn_close_fd);
260         return true;
261 }
262
263 void cleanup_conn_without_close(struct io_conn *conn)
264 {
265         tal_del_destructor(conn, destroy_conn_close_fd);
266         destroy_conn(conn, false);
267 }
268
269 static void accept_conn(struct io_listener *l)
270 {
271         int fd = accept(l->fd.fd, NULL, NULL);
272
273         /* FIXME: What to do here? */
274         if (fd < 0)
275                 return;
276
277         io_new_conn(l->ctx, fd, l->init, l->arg);
278 }
279
280 /* Return pointer to exclusive flag for this plan. */
281 static bool *exclusive(struct io_plan *plan)
282 {
283         struct io_conn *conn;
284
285         conn = container_of(plan, struct io_conn, plan[plan->dir]);
286         return &conn->fd.exclusive[plan->dir];
287 }
288
289 /* For simplicity, we do one always at a time */
290 static bool handle_always(void)
291 {
292         int i;
293
294         /* Backwards is simple easier to remove entries */
295         for (i = num_always - 1; i >= 0; i--) {
296                 struct io_plan *plan = always[i];
297
298                 if (num_exclusive && !*exclusive(plan))
299                         continue;
300                 /* Remove first: it might re-add */
301                 if (i != num_always-1)
302                         always[i] = always[num_always-1];
303                 num_always--;
304                 io_do_always(plan);
305                 return true;
306         }
307
308         return false;
309 }
310
311 bool backend_set_exclusive(struct io_plan *plan, bool excl)
312 {
313         bool *excl_ptr = exclusive(plan);
314
315         if (excl != *excl_ptr) {
316                 *excl_ptr = excl;
317                 if (!excl)
318                         num_exclusive--;
319                 else
320                         num_exclusive++;
321         }
322
323         return num_exclusive != 0;
324 }
325
326 /* FIXME: We could do this once at set_exclusive time, and catch everywhere
327  * else that we manipulate events. */
328 static void exclude_pollfds(void)
329 {
330         size_t i;
331
332         if (num_exclusive == 0)
333                 return;
334
335         for (i = 0; i < num_fds; i++) {
336                 struct pollfd *pfd = &pollfds[fds[i]->backend_info];
337
338                 if (!fds[i]->exclusive[IO_IN])
339                         pfd->events &= ~POLLIN;
340                 if (!fds[i]->exclusive[IO_OUT])
341                         pfd->events &= ~POLLOUT;
342
343                 /* If we're not listening, we don't want error events
344                  * either. */
345                 if (!pfd->events)
346                         pfd->fd = -fds[i]->fd - 1;
347         }
348 }
349
350 static void restore_pollfds(void)
351 {
352         size_t i;
353
354         if (num_exclusive == 0)
355                 return;
356
357         for (i = 0; i < num_fds; i++) {
358                 struct pollfd *pfd = &pollfds[fds[i]->backend_info];
359
360                 if (fds[i]->listener) {
361                         pfd->events = POLLIN;
362                         pfd->fd = fds[i]->fd;
363                 } else {
364                         struct io_conn *conn = (void *)fds[i];
365                         setup_pfd(conn, pfd);
366                 }
367         }
368 }
369
370 /* This is the main loop. */
371 void *io_loop(struct timers *timers, struct timer **expired)
372 {
373         void *ret;
374
375         /* if timers is NULL, expired must be.  If not, not. */
376         assert(!timers == !expired);
377
378         /* Make sure this is NULL if we exit for some other reason. */
379         if (expired)
380                 *expired = NULL;
381
382         while (!io_loop_return) {
383                 int i, r, ms_timeout = -1;
384
385                 if (handle_always()) {
386                         /* Could have started/finished more. */
387                         continue;
388                 }
389
390                 /* Everything closed? */
391                 if (num_fds == 0)
392                         break;
393
394                 /* You can't tell them all to go to sleep! */
395                 assert(num_waiting);
396
397                 if (timers) {
398                         struct timemono now, first;
399
400                         now = nowfn();
401
402                         /* Call functions for expired timers. */
403                         *expired = timers_expire(timers, now);
404                         if (*expired)
405                                 break;
406
407                         /* Now figure out how long to wait for the next one. */
408                         if (timer_earliest(timers, &first)) {
409                                 uint64_t next;
410                                 next = time_to_msec(timemono_between(first, now));
411                                 if (next < INT_MAX)
412                                         ms_timeout = next;
413                                 else
414                                         ms_timeout = INT_MAX;
415                         }
416                 }
417
418                 /* We do this temporarily, assuming exclusive is unusual */
419                 exclude_pollfds();
420                 r = pollfn(pollfds, num_fds, ms_timeout);
421                 restore_pollfds();
422
423                 if (r < 0) {
424                         /* Signals shouldn't break us, unless they set
425                          * io_loop_return. */
426                         if (errno == EINTR)
427                                 continue;
428                         break;
429                 }
430
431                 for (i = 0; i < num_fds && !io_loop_return; i++) {
432                         struct io_conn *c = (void *)fds[i];
433                         int events = pollfds[i].revents;
434
435                         /* Clear so we don't get confused if exclusive next time */
436                         pollfds[i].revents = 0;
437
438                         if (r == 0)
439                                 break;
440
441                         if (fds[i]->listener) {
442                                 struct io_listener *l = (void *)fds[i];
443                                 if (events & POLLIN) {
444                                         accept_conn(l);
445                                         r--;
446                                 } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
447                                         r--;
448                                         errno = EBADF;
449                                         io_close_listener(l);
450                                 }
451                         } else if (events & (POLLIN|POLLOUT)) {
452                                 r--;
453                                 io_ready(c, events);
454                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
455                                 r--;
456                                 errno = EBADF;
457                                 io_close(c);
458                         }
459                 }
460         }
461
462         ret = io_loop_return;
463         io_loop_return = NULL;
464
465         return ret;
466 }
467
468 const void *io_have_fd(int fd, bool *listener)
469 {
470         for (size_t i = 0; i < num_fds; i++) {
471                 if (fds[i]->fd != fd)
472                         continue;
473                 if (listener)
474                         *listener = fds[i]->listener;
475                 return fds[i];
476         }
477         return NULL;
478 }