]> git.ozlabs.org Git - ccan/blob - ccan/io/poll.c
95b6103287d6335b0b84e4475a7b8469a2d33b54
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0, num_always = 0, max_always = 0, num_exclusive = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static struct io_plan **always = NULL;
18 static struct timemono (*nowfn)(void) = time_mono;
19 static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll;
20
21 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
22 {
23         struct timemono (*old)(void) = nowfn;
24         nowfn = now;
25         return old;
26 }
27
28 int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int)
29 {
30         int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn;
31         pollfn = poll;
32         return old;
33 }
34
35 static bool add_fd(struct fd *fd, short events)
36 {
37         if (!max_fds) {
38                 assert(num_fds == 0);
39                 pollfds = tal_arr(NULL, struct pollfd, 8);
40                 if (!pollfds)
41                         return false;
42                 fds = tal_arr(pollfds, struct fd *, 8);
43                 if (!fds)
44                         return false;
45                 max_fds = 8;
46         }
47
48         if (num_fds + 1 > max_fds) {
49                 size_t num = max_fds * 2;
50
51                 if (!tal_resize(&pollfds, num))
52                         return false;
53                 if (!tal_resize(&fds, num))
54                         return false;
55                 max_fds = num;
56         }
57
58         pollfds[num_fds].events = events;
59         /* In case it's idle. */
60         if (!events)
61                 pollfds[num_fds].fd = -fd->fd - 1;
62         else
63                 pollfds[num_fds].fd = fd->fd;
64         pollfds[num_fds].revents = 0; /* In case we're iterating now */
65         fds[num_fds] = fd;
66         fd->backend_info = num_fds;
67         fd->exclusive[0] = fd->exclusive[1] = false;
68         num_fds++;
69         if (events)
70                 num_waiting++;
71
72         return true;
73 }
74
75 static void del_fd(struct fd *fd)
76 {
77         size_t n = fd->backend_info;
78
79         assert(n != -1);
80         assert(n < num_fds);
81         if (pollfds[n].events)
82                 num_waiting--;
83         if (n != num_fds - 1) {
84                 /* Move last one over us. */
85                 pollfds[n] = pollfds[num_fds-1];
86                 fds[n] = fds[num_fds-1];
87                 assert(fds[n]->backend_info == num_fds-1);
88                 fds[n]->backend_info = n;
89         } else if (num_fds == 1) {
90                 /* Free everything when no more fds. */
91                 pollfds = tal_free(pollfds);
92                 fds = NULL;
93                 max_fds = 0;
94         }
95         num_fds--;
96         fd->backend_info = -1;
97
98         if (fd->exclusive[IO_IN])
99                 num_exclusive--;
100         if (fd->exclusive[IO_OUT])
101                 num_exclusive--;
102 }
103
104 static void destroy_listener(struct io_listener *l)
105 {
106         close(l->fd.fd);
107         del_fd(&l->fd);
108 }
109
110 bool add_listener(struct io_listener *l)
111 {
112         if (!add_fd(&l->fd, POLLIN))
113                 return false;
114         tal_add_destructor(l, destroy_listener);
115         return true;
116 }
117
118 static int find_always(const struct io_plan *plan)
119 {
120         for (size_t i = 0; i < num_always; i++)
121                 if (always[i] == plan)
122                         return i;
123         return -1;
124 }
125
126 static void remove_from_always(const struct io_plan *plan)
127 {
128         int pos;
129
130         if (plan->status != IO_ALWAYS)
131                 return;
132
133         pos = find_always(plan);
134         assert(pos >= 0);
135
136         /* Move last one down if we made a hole */
137         if (pos != num_always-1)
138                 always[pos] = always[num_always-1];
139         num_always--;
140 }
141
142 bool backend_new_always(struct io_plan *plan)
143 {
144         assert(find_always(plan) == -1);
145
146         if (!max_always) {
147                 assert(num_always == 0);
148                 always = tal_arr(NULL, struct io_plan *, 8);
149                 if (!always)
150                         return false;
151                 max_always = 8;
152         }
153
154         if (num_always + 1 > max_always) {
155                 size_t num = max_always * 2;
156
157                 if (!tal_resize(&always, num))
158                         return false;
159                 max_always = num;
160         }
161
162         always[num_always++] = plan;
163         return true;
164 }
165
166 static void setup_pfd(struct io_conn *conn, struct pollfd *pfd)
167 {
168         assert(pfd == &pollfds[conn->fd.backend_info]);
169
170         pfd->events = 0;
171         if (conn->plan[IO_IN].status == IO_POLLING_NOTSTARTED
172             || conn->plan[IO_IN].status == IO_POLLING_STARTED)
173                 pfd->events |= POLLIN;
174         if (conn->plan[IO_OUT].status == IO_POLLING_NOTSTARTED
175             || conn->plan[IO_OUT].status == IO_POLLING_STARTED)
176                 pfd->events |= POLLOUT;
177
178         if (pfd->events) {
179                 pfd->fd = conn->fd.fd;
180         } else {
181                 pfd->fd = -conn->fd.fd - 1;
182         }
183 }
184
185 void backend_new_plan(struct io_conn *conn)
186 {
187         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
188
189         if (pfd->events)
190                 num_waiting--;
191
192         setup_pfd(conn, pfd);
193
194         if (pfd->events)
195                 num_waiting++;
196 }
197
198 void backend_wake(const void *wait)
199 {
200         unsigned int i;
201
202         for (i = 0; i < num_fds; i++) {
203                 struct io_conn *c;
204
205                 /* Ignore listeners */
206                 if (fds[i]->listener)
207                         continue;
208
209                 c = (void *)fds[i];
210                 if (c->plan[IO_IN].status == IO_WAITING
211                     && c->plan[IO_IN].arg.u1.const_vp == wait)
212                         io_do_wakeup(c, IO_IN);
213
214                 if (c->plan[IO_OUT].status == IO_WAITING
215                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
216                         io_do_wakeup(c, IO_OUT);
217         }
218 }
219
220 static void destroy_conn(struct io_conn *conn, bool close_fd)
221 {
222         int saved_errno = errno;
223
224         if (close_fd)
225                 close(conn->fd.fd);
226         del_fd(&conn->fd);
227
228         remove_from_always(&conn->plan[IO_IN]);
229         remove_from_always(&conn->plan[IO_OUT]);
230
231         /* errno saved/restored by tal_free itself. */
232         if (conn->finish) {
233                 errno = saved_errno;
234                 conn->finish(conn, conn->finish_arg);
235         }
236 }
237
238 static void destroy_conn_close_fd(struct io_conn *conn)
239 {
240         destroy_conn(conn, true);
241 }
242
243 bool add_conn(struct io_conn *c)
244 {
245         if (!add_fd(&c->fd, 0))
246                 return false;
247         tal_add_destructor(c, destroy_conn_close_fd);
248         return true;
249 }
250
251 void cleanup_conn_without_close(struct io_conn *conn)
252 {
253         tal_del_destructor(conn, destroy_conn_close_fd);
254         destroy_conn(conn, false);
255 }
256
257 static void accept_conn(struct io_listener *l)
258 {
259         int fd = accept(l->fd.fd, NULL, NULL);
260
261         /* FIXME: What to do here? */
262         if (fd < 0)
263                 return;
264
265         io_new_conn(l->ctx, fd, l->init, l->arg);
266 }
267
268 /* Return pointer to exclusive flag for this plan. */
269 static bool *exclusive(struct io_plan *plan)
270 {
271         struct io_conn *conn;
272
273         conn = container_of(plan, struct io_conn, plan[plan->dir]);
274         return &conn->fd.exclusive[plan->dir];
275 }
276
277 /* For simplicity, we do one always at a time */
278 static bool handle_always(void)
279 {
280         /* Backwards is simple easier to remove entries */
281         for (int i = num_always - 1; i >= 0; i--) {
282                 struct io_plan *plan = always[i];
283
284                 if (num_exclusive && !*exclusive(plan))
285                         continue;
286                 /* Remove first: it might re-add */
287                 if (i != num_always-1)
288                         always[i] = always[num_always-1];
289                 num_always--;
290                 io_do_always(plan);
291                 return true;
292         }
293
294         return false;
295 }
296
297 bool backend_set_exclusive(struct io_plan *plan, bool excl)
298 {
299         bool *excl_ptr = exclusive(plan);
300
301         if (excl != *excl_ptr) {
302                 *excl_ptr = excl;
303                 if (!excl)
304                         num_exclusive--;
305                 else
306                         num_exclusive++;
307         }
308
309         return num_exclusive != 0;
310 }
311
312 /* FIXME: We could do this once at set_exclusive time, and catch everywhere
313  * else that we manipulate events. */
314 static void exclude_pollfds(void)
315 {
316         if (num_exclusive == 0)
317                 return;
318
319         for (size_t i = 0; i < num_fds; i++) {
320                 struct pollfd *pfd = &pollfds[fds[i]->backend_info];
321
322                 if (!fds[i]->exclusive[IO_IN])
323                         pfd->events &= ~POLLIN;
324                 if (!fds[i]->exclusive[IO_OUT])
325                         pfd->events &= ~POLLOUT;
326
327                 /* If we're not listening, we don't want error events
328                  * either. */
329                 if (!pfd->events)
330                         pfd->fd = -fds[i]->fd - 1;
331         }
332 }
333
334 static void restore_pollfds(void)
335 {
336         if (num_exclusive == 0)
337                 return;
338
339         for (size_t i = 0; i < num_fds; i++) {
340                 struct pollfd *pfd = &pollfds[fds[i]->backend_info];
341
342                 if (fds[i]->listener) {
343                         pfd->events = POLLIN;
344                         pfd->fd = fds[i]->fd;
345                 } else {
346                         struct io_conn *conn = (void *)fds[i];
347                         setup_pfd(conn, pfd);
348                 }
349         }
350 }
351
352 /* This is the main loop. */
353 void *io_loop(struct timers *timers, struct timer **expired)
354 {
355         void *ret;
356
357         /* if timers is NULL, expired must be.  If not, not. */
358         assert(!timers == !expired);
359
360         /* Make sure this is NULL if we exit for some other reason. */
361         if (expired)
362                 *expired = NULL;
363
364         while (!io_loop_return) {
365                 int i, r, ms_timeout = -1;
366
367                 if (handle_always()) {
368                         /* Could have started/finished more. */
369                         continue;
370                 }
371
372                 /* Everything closed? */
373                 if (num_fds == 0)
374                         break;
375
376                 /* You can't tell them all to go to sleep! */
377                 assert(num_waiting);
378
379                 if (timers) {
380                         struct timemono now, first;
381
382                         now = nowfn();
383
384                         /* Call functions for expired timers. */
385                         *expired = timers_expire(timers, now);
386                         if (*expired)
387                                 break;
388
389                         /* Now figure out how long to wait for the next one. */
390                         if (timer_earliest(timers, &first)) {
391                                 uint64_t next;
392                                 next = time_to_msec(timemono_between(first, now));
393                                 if (next < INT_MAX)
394                                         ms_timeout = next;
395                                 else
396                                         ms_timeout = INT_MAX;
397                         }
398                 }
399
400                 /* We do this temporarily, assuming exclusive is unusual */
401                 exclude_pollfds();
402                 r = pollfn(pollfds, num_fds, ms_timeout);
403                 restore_pollfds();
404
405                 if (r < 0) {
406                         /* Signals shouldn't break us, unless they set
407                          * io_loop_return. */
408                         if (errno == EINTR)
409                                 continue;
410                         break;
411                 }
412
413                 for (i = 0; i < num_fds && !io_loop_return; i++) {
414                         struct io_conn *c = (void *)fds[i];
415                         int events = pollfds[i].revents;
416
417                         /* Clear so we don't get confused if exclusive next time */
418                         pollfds[i].revents = 0;
419
420                         if (r == 0)
421                                 break;
422
423                         if (fds[i]->listener) {
424                                 struct io_listener *l = (void *)fds[i];
425                                 if (events & POLLIN) {
426                                         accept_conn(l);
427                                         r--;
428                                 } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
429                                         r--;
430                                         errno = EBADF;
431                                         io_close_listener(l);
432                                 }
433                         } else if (events & (POLLIN|POLLOUT)) {
434                                 r--;
435                                 io_ready(c, events);
436                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
437                                 r--;
438                                 errno = EBADF;
439                                 io_close(c);
440                         }
441                 }
442         }
443
444         ret = io_loop_return;
445         io_loop_return = NULL;
446
447         return ret;
448 }