]> git.ozlabs.org Git - ccan/blob - ccan/io/poll.c
io: don't try to close() connection twice, remove shutdown logic.
[ccan] / ccan / io / poll.c
1 /* Licensed under LGPLv2.1+ - see LICENSE file for details */
2 #include "io.h"
3 #include "backend.h"
4 #include <assert.h>
5 #include <poll.h>
6 #include <stdlib.h>
7 #include <sys/types.h>
8 #include <sys/socket.h>
9 #include <limits.h>
10 #include <errno.h>
11 #include <ccan/time/time.h>
12 #include <ccan/timer/timer.h>
13
14 static size_t num_fds = 0, max_fds = 0, num_waiting = 0;
15 static struct pollfd *pollfds = NULL;
16 static struct fd **fds = NULL;
17 static LIST_HEAD(closing);
18 static LIST_HEAD(always);
19 static struct timemono (*nowfn)(void) = time_mono;
20
21 struct timemono (*io_time_override(struct timemono (*now)(void)))(void)
22 {
23         struct timemono (*old)(void) = nowfn;
24         nowfn = now;
25         return old;
26 }
27
28 static bool add_fd(struct fd *fd, short events)
29 {
30         if (!max_fds) {
31                 assert(num_fds == 0);
32                 pollfds = tal_arr(NULL, struct pollfd, 8);
33                 if (!pollfds)
34                         return false;
35                 fds = tal_arr(pollfds, struct fd *, 8);
36                 if (!fds)
37                         return false;
38                 max_fds = 8;
39         }
40
41         if (num_fds + 1 > max_fds) {
42                 size_t num = max_fds * 2;
43
44                 if (!tal_resize(&pollfds, num))
45                         return false;
46                 if (!tal_resize(&fds, num))
47                         return false;
48                 max_fds = num;
49         }
50
51         pollfds[num_fds].events = events;
52         /* In case it's idle. */
53         if (!events)
54                 pollfds[num_fds].fd = -fd->fd;
55         else
56                 pollfds[num_fds].fd = fd->fd;
57         pollfds[num_fds].revents = 0; /* In case we're iterating now */
58         fds[num_fds] = fd;
59         fd->backend_info = num_fds;
60         num_fds++;
61         if (events)
62                 num_waiting++;
63
64         return true;
65 }
66
67 static void del_fd(struct fd *fd)
68 {
69         size_t n = fd->backend_info;
70
71         assert(n != -1);
72         assert(n < num_fds);
73         if (pollfds[n].events)
74                 num_waiting--;
75         if (n != num_fds - 1) {
76                 /* Move last one over us. */
77                 pollfds[n] = pollfds[num_fds-1];
78                 fds[n] = fds[num_fds-1];
79                 assert(fds[n]->backend_info == num_fds-1);
80                 fds[n]->backend_info = n;
81         } else if (num_fds == 1) {
82                 /* Free everything when no more fds. */
83                 pollfds = tal_free(pollfds);
84                 fds = NULL;
85                 max_fds = 0;
86         }
87         num_fds--;
88         fd->backend_info = -1;
89 }
90
91 static void destroy_listener(struct io_listener *l)
92 {
93         close(l->fd.fd);
94         del_fd(&l->fd);
95 }
96
97 bool add_listener(struct io_listener *l)
98 {
99         if (!add_fd(&l->fd, POLLIN))
100                 return false;
101         tal_add_destructor(l, destroy_listener);
102         return true;
103 }
104
105 void remove_from_always(struct io_conn *conn)
106 {
107         list_del_init(&conn->always);
108 }
109
110 void backend_new_always(struct io_conn *conn)
111 {
112         /* In case it's already in always list. */
113         list_del(&conn->always);
114         list_add_tail(&always, &conn->always);
115 }
116
117 void backend_new_plan(struct io_conn *conn)
118 {
119         struct pollfd *pfd = &pollfds[conn->fd.backend_info];
120
121         if (pfd->events)
122                 num_waiting--;
123
124         pfd->events = 0;
125         if (conn->plan[IO_IN].status == IO_POLLING)
126                 pfd->events |= POLLIN;
127         if (conn->plan[IO_OUT].status == IO_POLLING)
128                 pfd->events |= POLLOUT;
129
130         if (pfd->events) {
131                 num_waiting++;
132                 pfd->fd = conn->fd.fd;
133         } else {
134                 pfd->fd = -conn->fd.fd;
135         }
136 }
137
138 void backend_wake(const void *wait)
139 {
140         unsigned int i;
141
142         for (i = 0; i < num_fds; i++) {
143                 struct io_conn *c;
144
145                 /* Ignore listeners */
146                 if (fds[i]->listener)
147                         continue;
148
149                 c = (void *)fds[i];
150                 if (c->plan[IO_IN].status == IO_WAITING
151                     && c->plan[IO_IN].arg.u1.const_vp == wait)
152                         io_do_wakeup(c, IO_IN);
153
154                 if (c->plan[IO_OUT].status == IO_WAITING
155                     && c->plan[IO_OUT].arg.u1.const_vp == wait)
156                         io_do_wakeup(c, IO_OUT);
157         }
158 }
159
160 static void destroy_conn(struct io_conn *conn)
161 {
162         int saved_errno = errno;
163
164         close(conn->fd.fd);
165         del_fd(&conn->fd);
166         /* In case it's on always list, remove it. */
167         list_del_init(&conn->always);
168
169         /* errno saved/restored by tal_free itself. */
170         if (conn->finish) {
171                 errno = saved_errno;
172                 conn->finish(conn, conn->finish_arg);
173         }
174 }
175
176 bool add_conn(struct io_conn *c)
177 {
178         if (!add_fd(&c->fd, 0))
179                 return false;
180         tal_add_destructor(c, destroy_conn);
181         return true;
182 }
183
184 static void accept_conn(struct io_listener *l)
185 {
186         int fd = accept(l->fd.fd, NULL, NULL);
187
188         /* FIXME: What to do here? */
189         if (fd < 0)
190                 return;
191
192         io_new_conn(l->ctx, fd, l->init, l->arg);
193 }
194
195 static bool handle_always(void)
196 {
197         bool ret = false;
198         struct io_conn *conn;
199
200         while ((conn = list_pop(&always, struct io_conn, always)) != NULL) {
201                 assert(conn->plan[IO_IN].status == IO_ALWAYS
202                        || conn->plan[IO_OUT].status == IO_ALWAYS);
203
204                 /* Re-initialize, for next time. */
205                 list_node_init(&conn->always);
206                 io_do_always(conn);
207                 ret = true;
208         }
209         return ret;
210 }
211
212 /* This is the main loop. */
213 void *io_loop(struct timers *timers, struct timer **expired)
214 {
215         void *ret;
216
217         /* if timers is NULL, expired must be.  If not, not. */
218         assert(!timers == !expired);
219
220         /* Make sure this is NULL if we exit for some other reason. */
221         if (expired)
222                 *expired = NULL;
223
224         while (!io_loop_return) {
225                 int i, r, ms_timeout = -1;
226
227                 if (handle_always()) {
228                         /* Could have started/finished more. */
229                         continue;
230                 }
231
232                 /* Everything closed? */
233                 if (num_fds == 0)
234                         break;
235
236                 /* You can't tell them all to go to sleep! */
237                 assert(num_waiting);
238
239                 if (timers) {
240                         struct timemono now, first;
241
242                         now = nowfn();
243
244                         /* Call functions for expired timers. */
245                         *expired = timers_expire(timers, now);
246                         if (*expired)
247                                 break;
248
249                         /* Now figure out how long to wait for the next one. */
250                         if (timer_earliest(timers, &first)) {
251                                 uint64_t next;
252                                 next = time_to_msec(timemono_between(first, now));
253                                 if (next < INT_MAX)
254                                         ms_timeout = next;
255                                 else
256                                         ms_timeout = INT_MAX;
257                         }
258                 }
259
260                 r = poll(pollfds, num_fds, ms_timeout);
261                 if (r < 0)
262                         break;
263
264                 for (i = 0; i < num_fds && !io_loop_return; i++) {
265                         struct io_conn *c = (void *)fds[i];
266                         int events = pollfds[i].revents;
267
268                         if (r == 0)
269                                 break;
270
271                         if (fds[i]->listener) {
272                                 if (events & POLLIN) {
273                                         accept_conn((void *)c);
274                                         r--;
275                                 }
276                         } else if (events & (POLLIN|POLLOUT)) {
277                                 r--;
278                                 io_ready(c, events);
279                         } else if (events & (POLLHUP|POLLNVAL|POLLERR)) {
280                                 r--;
281                                 errno = EBADF;
282                                 io_close(c);
283                         }
284                 }
285         }
286
287         ret = io_loop_return;
288         io_loop_return = NULL;
289
290         return ret;
291 }