b005a97e4b1d07879a4e281a8d48276baaac2fa1
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static LIST_HEAD(closing);
18 static LIST_HEAD(always);
19 static struct timemono (*nowfn)(void) = time_mono;
20 static int (*pollfn)(struct pollfd *fds, nfds_t nfds, int timeout) = poll;
21
22 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
23 {
24         struct timemono (*old)(void) = nowfn;
25         nowfn = now;
26         return old;
27 }
28
29 int (*io_poll_override(int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout)))(struct pollfd *, nfds_t, int)
30 {
31         int (*old)(struct pollfd *fds, nfds_t nfds, int timeout) = pollfn;
32         pollfn = poll;
33         return old;
34 }
35
36 static bool add_fd(struct fd *fd, short events)
37 {
38         if (!max_fds) {
39                 assert(num_fds == 0);
40                 pollfds = tal_arr(NULL, struct pollfd, 8);
41                 if (!pollfds)
42                         return false;
43                 fds = tal_arr(pollfds, struct fd *, 8);
44                 if (!fds)
45                         return false;
46                 max_fds = 8;
47         }
48
49         if (num_fds + 1 > max_fds) {
50                 size_t num = max_fds * 2;
51
52                 if (!tal_resize(&pollfds, num))
53                         return false;
54                 if (!tal_resize(&fds, num))
55                         return false;
56                 max_fds = num;
57         }
58
59         pollfds[num_fds].events = events;
60         /* In case it's idle. */
61         if (!events)
62                 pollfds[num_fds].fd = -fd->fd;
63         else
64                 pollfds[num_fds].fd = fd->fd;
65         pollfds[num_fds].revents = 0; /* In case we're iterating now */
66         fds[num_fds] = fd;
67         fd->backend_info = num_fds;
68         num_fds++;
69         if (events)
70                 num_waiting++;
71
72         return true;
73 }
74
75 static void del_fd(struct fd *fd)
76 {
77         size_t n = fd->backend_info;
78
79         assert(n != -1);
80         assert(n < num_fds);
81         if (pollfds[n].events)
82                 num_waiting--;
83         if (n != num_fds - 1) {
84                 /* Move last one over us. */
85                 pollfds[n] = pollfds[num_fds-1];
86                 fds[n] = fds[num_fds-1];
87                 assert(fds[n]->backend_info == num_fds-1);
88                 fds[n]->backend_info = n;
89         } else if (num_fds == 1) {
90                 /* Free everything when no more fds. */
91                 pollfds = tal_free(pollfds);
92                 fds = NULL;
93                 max_fds = 0;
94         }
95         num_fds--;
96         fd->backend_info = -1;
97 }
98
99 static void destroy_listener(struct io_listener *l)
100 {
101         close(l->fd.fd);
102         del_fd(&l->fd);
103 }
104
105 bool add_listener(struct io_listener *l)
106 {
107         if (!add_fd(&l->fd, POLLIN))
108                 return false;
109         tal_add_destructor(l, destroy_listener);
110         return true;
111 }
112
113 void remove_from_always(struct io_conn *conn)
114 {
115         list_del_init(&conn->always);
116 }
117
118 void backend_new_always(struct io_conn *conn)
119 {
120         /* In case it's already in always list. */
121         list_del(&conn->always);
122         list_add_tail(&always, &conn->always);
123 }
124
125 void backend_new_plan(struct io_conn *conn)
126 {
127         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
128
129         if (pfd->events)
130                 num_waiting--;
131
132         pfd->events = 0;
133         if (conn->plan[IO_IN].status == IO_POLLING_NOTSTARTED
134             || conn->plan[IO_IN].status == IO_POLLING_STARTED)
135                 pfd->events |= POLLIN;
136         if (conn->plan[IO_OUT].status == IO_POLLING_NOTSTARTED
137             || conn->plan[IO_OUT].status == IO_POLLING_STARTED)
138                 pfd->events |= POLLOUT;
139
140         if (pfd->events) {
141                 num_waiting++;
142                 pfd->fd = conn->fd.fd;
143         } else {
144                 pfd->fd = -conn->fd.fd;
145         }
146 }
147
148 void backend_wake(const void *wait)
149 {
150         unsigned int i;
151
152         for (i = 0; i < num_fds; i++) {
153                 struct io_conn *c;
154
155                 /* Ignore listeners */
156                 if (fds[i]->listener)
157                         continue;
158
159                 c = (void *)fds[i];
160                 if (c->plan[IO_IN].status == IO_WAITING
161                     && c->plan[IO_IN].arg.u1.const_vp == wait)
162                         io_do_wakeup(c, IO_IN);
163
164                 if (c->plan[IO_OUT].status == IO_WAITING
165                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
166                         io_do_wakeup(c, IO_OUT);
167         }
168 }
169
170 static void destroy_conn(struct io_conn *conn, bool close_fd)
171 {
172         int saved_errno = errno;
173
174         if (close_fd)
175                 close(conn->fd.fd);
176         del_fd(&conn->fd);
177         /* In case it's on always list, remove it. */
178         list_del_init(&conn->always);
179
180         /* errno saved/restored by tal_free itself. */
181         if (conn->finish) {
182                 errno = saved_errno;
183                 conn->finish(conn, conn->finish_arg);
184         }
185 }
186
187 static void destroy_conn_close_fd(struct io_conn *conn)
188 {
189         destroy_conn(conn, true);
190 }
191
192 bool add_conn(struct io_conn *c)
193 {
194         if (!add_fd(&c->fd, 0))
195                 return false;
196         tal_add_destructor(c, destroy_conn_close_fd);
197         return true;
198 }
199
200 void cleanup_conn_without_close(struct io_conn *conn)
201 {
202         tal_del_destructor(conn, destroy_conn_close_fd);
203         destroy_conn(conn, false);
204 }
205
206 static void accept_conn(struct io_listener *l)
207 {
208         int fd = accept(l->fd.fd, NULL, NULL);
209
210         /* FIXME: What to do here? */
211         if (fd < 0)
212                 return;
213
214         io_new_conn(l->ctx, fd, l->init, l->arg);
215 }
216
217 static bool handle_always(void)
218 {
219         bool ret = false;
220         struct io_conn *conn;
221
222         while ((conn = list_pop(&always, struct io_conn, always)) != NULL) {
223                 assert(conn->plan[IO_IN].status == IO_ALWAYS
224                        || conn->plan[IO_OUT].status == IO_ALWAYS);
225
226                 /* Re-initialize, for next time. */
227                 list_node_init(&conn->always);
228                 io_do_always(conn);
229                 ret = true;
230         }
231         return ret;
232 }
233
234 /* This is the main loop. */
235 void *io_loop(struct timers *timers, struct timer **expired)
236 {
237         void *ret;
238
239         /* if timers is NULL, expired must be.  If not, not. */
240         assert(!timers == !expired);
241
242         /* Make sure this is NULL if we exit for some other reason. */
243         if (expired)
244                 *expired = NULL;
245
246         while (!io_loop_return) {
247                 int i, r, ms_timeout = -1;
248
249                 if (handle_always()) {
250                         /* Could have started/finished more. */
251                         continue;
252                 }
253
254                 /* Everything closed? */
255                 if (num_fds == 0)
256                         break;
257
258                 /* You can't tell them all to go to sleep! */
259                 assert(num_waiting);
260
261                 if (timers) {
262                         struct timemono now, first;
263
264                         now = nowfn();
265
266                         /* Call functions for expired timers. */
267                         *expired = timers_expire(timers, now);
268                         if (*expired)
269                                 break;
270
271                         /* Now figure out how long to wait for the next one. */
272                         if (timer_earliest(timers, &first)) {
273                                 uint64_t next;
274                                 next = time_to_msec(timemono_between(first, now));
275                                 if (next < INT_MAX)
276                                         ms_timeout = next;
277                                 else
278                                         ms_timeout = INT_MAX;
279                         }
280                 }
281
282                 r = pollfn(pollfds, num_fds, ms_timeout);
283                 if (r < 0) {
284                         /* Signals shouldn't break us, unless they set
285                          * io_loop_return. */
286                         if (errno == EINTR)
287                                 continue;
288                         break;
289                 }
290
291                 for (i = 0; i < num_fds && !io_loop_return; i++) {
292                         struct io_conn *c = (void *)fds[i];
293                         int events = pollfds[i].revents;
294
295                         if (r == 0)
296                                 break;
297
298                         if (fds[i]->listener) {
299                                 struct io_listener *l = (void *)fds[i];
300                                 if (events & POLLIN) {
301                                         accept_conn(l);
302                                         r--;
303                                 } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
304                                         r--;
305                                         errno = EBADF;
306                                         io_close_listener(l);
307                                 }
308                         } else if (events & (POLLIN|POLLOUT)) {
309                                 r--;
310                                 io_ready(c, events);
311                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
312                                 r--;
313                                 errno = EBADF;
314                                 io_close(c);
315                         }
316                 }
317         }
318
319         ret = io_loop_return;
320         io_loop_return = NULL;
321
322         return ret;
323 }