]> git.ozlabs.org Git - ccan/blob - ccan/io/poll.c
io: handle errors on listening file descriptors.
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static LIST_HEAD(closing);
18 static LIST_HEAD(always);
19 static struct timemono (*nowfn)(void) = time_mono;
20
21 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
22 {
23         struct timemono (*old)(void) = nowfn;
24         nowfn = now;
25         return old;
26 }
27
28 static bool add_fd(struct fd *fd, short events)
29 {
30         if (!max_fds) {
31                 assert(num_fds == 0);
32                 pollfds = tal_arr(NULL, struct pollfd, 8);
33                 if (!pollfds)
34                         return false;
35                 fds = tal_arr(pollfds, struct fd *, 8);
36                 if (!fds)
37                         return false;
38                 max_fds = 8;
39         }
40
41         if (num_fds + 1 > max_fds) {
42                 size_t num = max_fds * 2;
43
44                 if (!tal_resize(&pollfds, num))
45                         return false;
46                 if (!tal_resize(&fds, num))
47                         return false;
48                 max_fds = num;
49         }
50
51         pollfds[num_fds].events = events;
52         /* In case it's idle. */
53         if (!events)
54                 pollfds[num_fds].fd = -fd->fd;
55         else
56                 pollfds[num_fds].fd = fd->fd;
57         pollfds[num_fds].revents = 0; /* In case we're iterating now */
58         fds[num_fds] = fd;
59         fd->backend_info = num_fds;
60         num_fds++;
61         if (events)
62                 num_waiting++;
63
64         return true;
65 }
66
67 static void del_fd(struct fd *fd)
68 {
69         size_t n = fd->backend_info;
70
71         assert(n != -1);
72         assert(n < num_fds);
73         if (pollfds[n].events)
74                 num_waiting--;
75         if (n != num_fds - 1) {
76                 /* Move last one over us. */
77                 pollfds[n] = pollfds[num_fds-1];
78                 fds[n] = fds[num_fds-1];
79                 assert(fds[n]->backend_info == num_fds-1);
80                 fds[n]->backend_info = n;
81         } else if (num_fds == 1) {
82                 /* Free everything when no more fds. */
83                 pollfds = tal_free(pollfds);
84                 fds = NULL;
85                 max_fds = 0;
86         }
87         num_fds--;
88         fd->backend_info = -1;
89 }
90
91 static void destroy_listener(struct io_listener *l)
92 {
93         close(l->fd.fd);
94         del_fd(&l->fd);
95 }
96
97 bool add_listener(struct io_listener *l)
98 {
99         if (!add_fd(&l->fd, POLLIN))
100                 return false;
101         tal_add_destructor(l, destroy_listener);
102         return true;
103 }
104
105 void remove_from_always(struct io_conn *conn)
106 {
107         list_del_init(&conn->always);
108 }
109
110 void backend_new_always(struct io_conn *conn)
111 {
112         /* In case it's already in always list. */
113         list_del(&conn->always);
114         list_add_tail(&always, &conn->always);
115 }
116
117 void backend_new_plan(struct io_conn *conn)
118 {
119         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
120
121         if (pfd->events)
122                 num_waiting--;
123
124         pfd->events = 0;
125         if (conn->plan[IO_IN].status == IO_POLLING)
126                 pfd->events |= POLLIN;
127         if (conn->plan[IO_OUT].status == IO_POLLING)
128                 pfd->events |= POLLOUT;
129
130         if (pfd->events) {
131                 num_waiting++;
132                 pfd->fd = conn->fd.fd;
133         } else {
134                 pfd->fd = -conn->fd.fd;
135         }
136 }
137
138 void backend_wake(const void *wait)
139 {
140         unsigned int i;
141
142         for (i = 0; i < num_fds; i++) {
143                 struct io_conn *c;
144
145                 /* Ignore listeners */
146                 if (fds[i]->listener)
147                         continue;
148
149                 c = (void *)fds[i];
150                 if (c->plan[IO_IN].status == IO_WAITING
151                     && c->plan[IO_IN].arg.u1.const_vp == wait)
152                         io_do_wakeup(c, IO_IN);
153
154                 if (c->plan[IO_OUT].status == IO_WAITING
155                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
156                         io_do_wakeup(c, IO_OUT);
157         }
158 }
159
160 static void destroy_conn(struct io_conn *conn, bool close_fd)
161 {
162         int saved_errno = errno;
163
164         if (close_fd)
165                 close(conn->fd.fd);
166         del_fd(&conn->fd);
167         /* In case it's on always list, remove it. */
168         list_del_init(&conn->always);
169
170         /* errno saved/restored by tal_free itself. */
171         if (conn->finish) {
172                 errno = saved_errno;
173                 conn->finish(conn, conn->finish_arg);
174         }
175 }
176
177 static void destroy_conn_close_fd(struct io_conn *conn)
178 {
179         destroy_conn(conn, true);
180 }
181
182 bool add_conn(struct io_conn *c)
183 {
184         if (!add_fd(&c->fd, 0))
185                 return false;
186         tal_add_destructor(c, destroy_conn_close_fd);
187         return true;
188 }
189
190 struct io_plan *io_close_taken_fd(struct io_conn *conn)
191 {
192         tal_del_destructor(conn, destroy_conn_close_fd);
193         destroy_conn(conn, false);
194         return io_close(conn);
195 }
196
197 static void accept_conn(struct io_listener *l)
198 {
199         int fd = accept(l->fd.fd, NULL, NULL);
200
201         /* FIXME: What to do here? */
202         if (fd < 0)
203                 return;
204
205         io_new_conn(l->ctx, fd, l->init, l->arg);
206 }
207
208 static bool handle_always(void)
209 {
210         bool ret = false;
211         struct io_conn *conn;
212
213         while ((conn = list_pop(&always, struct io_conn, always)) != NULL) {
214                 assert(conn->plan[IO_IN].status == IO_ALWAYS
215                        || conn->plan[IO_OUT].status == IO_ALWAYS);
216
217                 /* Re-initialize, for next time. */
218                 list_node_init(&conn->always);
219                 io_do_always(conn);
220                 ret = true;
221         }
222         return ret;
223 }
224
225 /* This is the main loop. */
226 void *io_loop(struct timers *timers, struct timer **expired)
227 {
228         void *ret;
229
230         /* if timers is NULL, expired must be.  If not, not. */
231         assert(!timers == !expired);
232
233         /* Make sure this is NULL if we exit for some other reason. */
234         if (expired)
235                 *expired = NULL;
236
237         while (!io_loop_return) {
238                 int i, r, ms_timeout = -1;
239
240                 if (handle_always()) {
241                         /* Could have started/finished more. */
242                         continue;
243                 }
244
245                 /* Everything closed? */
246                 if (num_fds == 0)
247                         break;
248
249                 /* You can't tell them all to go to sleep! */
250                 assert(num_waiting);
251
252                 if (timers) {
253                         struct timemono now, first;
254
255                         now = nowfn();
256
257                         /* Call functions for expired timers. */
258                         *expired = timers_expire(timers, now);
259                         if (*expired)
260                                 break;
261
262                         /* Now figure out how long to wait for the next one. */
263                         if (timer_earliest(timers, &first)) {
264                                 uint64_t next;
265                                 next = time_to_msec(timemono_between(first, now));
266                                 if (next < INT_MAX)
267                                         ms_timeout = next;
268                                 else
269                                         ms_timeout = INT_MAX;
270                         }
271                 }
272
273                 r = poll(pollfds, num_fds, ms_timeout);
274                 if (r < 0)
275                         break;
276
277                 for (i = 0; i < num_fds && !io_loop_return; i++) {
278                         struct io_conn *c = (void *)fds[i];
279                         int events = pollfds[i].revents;
280
281                         if (r == 0)
282                                 break;
283
284                         if (fds[i]->listener) {
285                                 struct io_listener *l = (void *)fds[i];
286                                 if (events & POLLIN) {
287                                         accept_conn(l);
288                                         r--;
289                                 } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
290                                         r--;
291                                         errno = EBADF;
292                                         io_close_listener(l);
293                                 }
294                         } else if (events & (POLLIN|POLLOUT)) {
295                                 r--;
296                                 io_ready(c, events);
297                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
298                                 r--;
299                                 errno = EBADF;
300                                 io_close(c);
301                         }
302                 }
303         }
304
305         ret = io_loop_return;
306         io_loop_return = NULL;
307
308         return ret;
309 }