]> git.ozlabs.org Git - ccan/blob - ccan/io/poll.c
io: make io_close_taken_fd() unset nonblocking on the fd.
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static LIST_HEAD(closing);
18 static LIST_HEAD(always);
19 static struct timemono (*nowfn)(void) = time_mono;
20
21 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
22 {
23         struct timemono (*old)(void) = nowfn;
24         nowfn = now;
25         return old;
26 }
27
28 static bool add_fd(struct fd *fd, short events)
29 {
30         if (!max_fds) {
31                 assert(num_fds == 0);
32                 pollfds = tal_arr(NULL, struct pollfd, 8);
33                 if (!pollfds)
34                         return false;
35                 fds = tal_arr(pollfds, struct fd *, 8);
36                 if (!fds)
37                         return false;
38                 max_fds = 8;
39         }
40
41         if (num_fds + 1 > max_fds) {
42                 size_t num = max_fds * 2;
43
44                 if (!tal_resize(&pollfds, num))
45                         return false;
46                 if (!tal_resize(&fds, num))
47                         return false;
48                 max_fds = num;
49         }
50
51         pollfds[num_fds].events = events;
52         /* In case it's idle. */
53         if (!events)
54                 pollfds[num_fds].fd = -fd->fd;
55         else
56                 pollfds[num_fds].fd = fd->fd;
57         pollfds[num_fds].revents = 0; /* In case we're iterating now */
58         fds[num_fds] = fd;
59         fd->backend_info = num_fds;
60         num_fds++;
61         if (events)
62                 num_waiting++;
63
64         return true;
65 }
66
67 static void del_fd(struct fd *fd)
68 {
69         size_t n = fd->backend_info;
70
71         assert(n != -1);
72         assert(n < num_fds);
73         if (pollfds[n].events)
74                 num_waiting--;
75         if (n != num_fds - 1) {
76                 /* Move last one over us. */
77                 pollfds[n] = pollfds[num_fds-1];
78                 fds[n] = fds[num_fds-1];
79                 assert(fds[n]->backend_info == num_fds-1);
80                 fds[n]->backend_info = n;
81         } else if (num_fds == 1) {
82                 /* Free everything when no more fds. */
83                 pollfds = tal_free(pollfds);
84                 fds = NULL;
85                 max_fds = 0;
86         }
87         num_fds--;
88         fd->backend_info = -1;
89 }
90
91 static void destroy_listener(struct io_listener *l)
92 {
93         close(l->fd.fd);
94         del_fd(&l->fd);
95 }
96
97 bool add_listener(struct io_listener *l)
98 {
99         if (!add_fd(&l->fd, POLLIN))
100                 return false;
101         tal_add_destructor(l, destroy_listener);
102         return true;
103 }
104
105 void remove_from_always(struct io_conn *conn)
106 {
107         list_del_init(&conn->always);
108 }
109
110 void backend_new_always(struct io_conn *conn)
111 {
112         /* In case it's already in always list. */
113         list_del(&conn->always);
114         list_add_tail(&always, &conn->always);
115 }
116
117 void backend_new_plan(struct io_conn *conn)
118 {
119         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
120
121         if (pfd->events)
122                 num_waiting--;
123
124         pfd->events = 0;
125         if (conn->plan[IO_IN].status == IO_POLLING)
126                 pfd->events |= POLLIN;
127         if (conn->plan[IO_OUT].status == IO_POLLING)
128                 pfd->events |= POLLOUT;
129
130         if (pfd->events) {
131                 num_waiting++;
132                 pfd->fd = conn->fd.fd;
133         } else {
134                 pfd->fd = -conn->fd.fd;
135         }
136 }
137
138 void backend_wake(const void *wait)
139 {
140         unsigned int i;
141
142         for (i = 0; i < num_fds; i++) {
143                 struct io_conn *c;
144
145                 /* Ignore listeners */
146                 if (fds[i]->listener)
147                         continue;
148
149                 c = (void *)fds[i];
150                 if (c->plan[IO_IN].status == IO_WAITING
151                     && c->plan[IO_IN].arg.u1.const_vp == wait)
152                         io_do_wakeup(c, IO_IN);
153
154                 if (c->plan[IO_OUT].status == IO_WAITING
155                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
156                         io_do_wakeup(c, IO_OUT);
157         }
158 }
159
160 static void destroy_conn(struct io_conn *conn, bool close_fd)
161 {
162         int saved_errno = errno;
163
164         if (close_fd)
165                 close(conn->fd.fd);
166         del_fd(&conn->fd);
167         /* In case it's on always list, remove it. */
168         list_del_init(&conn->always);
169
170         /* errno saved/restored by tal_free itself. */
171         if (conn->finish) {
172                 errno = saved_errno;
173                 conn->finish(conn, conn->finish_arg);
174         }
175 }
176
177 static void destroy_conn_close_fd(struct io_conn *conn)
178 {
179         destroy_conn(conn, true);
180 }
181
182 bool add_conn(struct io_conn *c)
183 {
184         if (!add_fd(&c->fd, 0))
185                 return false;
186         tal_add_destructor(c, destroy_conn_close_fd);
187         return true;
188 }
189
190 void cleanup_conn_without_close(struct io_conn *conn)
191 {
192         tal_del_destructor(conn, destroy_conn_close_fd);
193         destroy_conn(conn, false);
194 }
195
196 static void accept_conn(struct io_listener *l)
197 {
198         int fd = accept(l->fd.fd, NULL, NULL);
199
200         /* FIXME: What to do here? */
201         if (fd < 0)
202                 return;
203
204         io_new_conn(l->ctx, fd, l->init, l->arg);
205 }
206
207 static bool handle_always(void)
208 {
209         bool ret = false;
210         struct io_conn *conn;
211
212         while ((conn = list_pop(&always, struct io_conn, always)) != NULL) {
213                 assert(conn->plan[IO_IN].status == IO_ALWAYS
214                        || conn->plan[IO_OUT].status == IO_ALWAYS);
215
216                 /* Re-initialize, for next time. */
217                 list_node_init(&conn->always);
218                 io_do_always(conn);
219                 ret = true;
220         }
221         return ret;
222 }
223
224 /* This is the main loop. */
225 void *io_loop(struct timers *timers, struct timer **expired)
226 {
227         void *ret;
228
229         /* if timers is NULL, expired must be.  If not, not. */
230         assert(!timers == !expired);
231
232         /* Make sure this is NULL if we exit for some other reason. */
233         if (expired)
234                 *expired = NULL;
235
236         while (!io_loop_return) {
237                 int i, r, ms_timeout = -1;
238
239                 if (handle_always()) {
240                         /* Could have started/finished more. */
241                         continue;
242                 }
243
244                 /* Everything closed? */
245                 if (num_fds == 0)
246                         break;
247
248                 /* You can't tell them all to go to sleep! */
249                 assert(num_waiting);
250
251                 if (timers) {
252                         struct timemono now, first;
253
254                         now = nowfn();
255
256                         /* Call functions for expired timers. */
257                         *expired = timers_expire(timers, now);
258                         if (*expired)
259                                 break;
260
261                         /* Now figure out how long to wait for the next one. */
262                         if (timer_earliest(timers, &first)) {
263                                 uint64_t next;
264                                 next = time_to_msec(timemono_between(first, now));
265                                 if (next < INT_MAX)
266                                         ms_timeout = next;
267                                 else
268                                         ms_timeout = INT_MAX;
269                         }
270                 }
271
272                 r = poll(pollfds, num_fds, ms_timeout);
273                 if (r < 0)
274                         break;
275
276                 for (i = 0; i < num_fds && !io_loop_return; i++) {
277                         struct io_conn *c = (void *)fds[i];
278                         int events = pollfds[i].revents;
279
280                         if (r == 0)
281                                 break;
282
283                         if (fds[i]->listener) {
284                                 struct io_listener *l = (void *)fds[i];
285                                 if (events & POLLIN) {
286                                         accept_conn(l);
287                                         r--;
288                                 } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
289                                         r--;
290                                         errno = EBADF;
291                                         io_close_listener(l);
292                                 }
293                         } else if (events & (POLLIN|POLLOUT)) {
294                                 r--;
295                                 io_ready(c, events);
296                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
297                                 r--;
298                                 errno = EBADF;
299                                 io_close(c);
300                         }
301                 }
302         }
303
304         ret = io_loop_return;
305         io_loop_return = NULL;
306
307         return ret;
308 }