]> git.ozlabs.org Git - ccan/blob - ccan/net/net.c
Merge branch 'io'
[ccan] / ccan / net / net.c
1 /* Licensed under BSD-MIT - see LICENSE file for details */
2 #include <ccan/net/net.h>
3 #include <ccan/noerr/noerr.h>
4 #include <sys/types.h>
5 #include <sys/socket.h>
6 #include <poll.h>
7 #include <netdb.h>
8 #include <string.h>
9 #include <stdlib.h>
10 #include <unistd.h>
11 #include <fcntl.h>
12 #include <errno.h>
13 #include <stdbool.h>
14 #include <netinet/in.h>
15 #include <assert.h>
16
/*
 * Resolve @hostname/@service for use by a client.
 *
 * @family and @socktype are passed straight through as getaddrinfo()
 * hints; every other hint member is zero.
 *
 * Returns the getaddrinfo() result list (release with freeaddrinfo()),
 * or NULL if the lookup failed.
 */
struct addrinfo *net_client_lookup(const char *hostname,
				   const char *service,
				   int family,
				   int socktype)
{
	struct addrinfo *found = NULL;
	const struct addrinfo hints = {
		.ai_family = family,
		.ai_socktype = socktype,
		.ai_flags = 0,
		.ai_protocol = 0,
	};

	return getaddrinfo(hostname, service, &hints, &found) == 0
		? found : NULL;
}
36
/*
 * Turn O_NONBLOCK on (@nonblock true) or off for @fd, leaving the fd's
 * other file status flags untouched.
 *
 * Returns true on success; false with errno set by fcntl() on failure.
 */
static bool set_nonblock(int fd, bool nonblock)
{
	/* F_GETFL returns, and F_SETFL takes, an int. */
	int flags = fcntl(fd, F_GETFL);

	if (flags == -1)
		return false;

	if (nonblock)
		flags |= O_NONBLOCK;
	else
		flags &= ~O_NONBLOCK;

	/* POSIX only promises "a value other than -1" on success. */
	return fcntl(fd, F_SETFL, flags) != -1;
}
52
53 static int start_connect(const struct addrinfo *addr, bool *immediate)
54 {
55         int fd;
56
57         *immediate = false;
58
59         fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
60         if (fd == -1)
61                 return fd;
62
63         if (!set_nonblock(fd, true))
64                 goto close;
65
66         if (connect(fd, addr->ai_addr, addr->ai_addrlen) == 0) {
67                 /* Immediate connect. */
68                 *immediate = true;
69                 return fd;
70         }
71
72         if (errno == EINPROGRESS)
73                 return fd;
74
75 close:
76         close_noerr(fd);
77         return -1;
78 }
79
80
81 int net_connect_async(const struct addrinfo *addrinfo, struct pollfd pfds[2])
82 {
83         const struct addrinfo *addr[2] = { NULL, NULL };
84         unsigned int i;
85
86         pfds[0].fd = pfds[1].fd = -1;
87         pfds[0].events = pfds[1].events = POLLOUT;
88
89         /* Give IPv6 a slight advantage, by trying it first. */
90         for (; addrinfo; addrinfo = addrinfo->ai_next) {
91                 switch (addrinfo->ai_family) {
92                 case AF_INET:
93                         addr[1] = addrinfo;
94                         break;
95                 case AF_INET6:
96                         addr[0] = addrinfo;
97                         break;
98                 default:
99                         continue;
100                 }
101         }
102
103         /* In case we found nothing. */
104         errno = ENOENT;
105         for (i = 0; i < 2; i++) {
106                 bool immediate;
107
108                 if (!addr[i])
109                         continue;
110
111                 pfds[i].fd = start_connect(addr[i], &immediate);
112                 if (immediate) {
113                         if (pfds[!i].fd != -1)
114                                 close(pfds[!i].fd);
115                         if (!set_nonblock(pfds[i].fd, false)) {
116                                 close_noerr(pfds[i].fd);
117                                 return -1;
118                         }
119                         return pfds[0].fd;
120                 }
121         }
122
123         if (pfds[0].fd != -1 || pfds[1].fd != -1)
124                 errno = EINPROGRESS;
125         return -1;
126 }
127
/*
 * Abandon the connection attempts held in @pfds: close every fd that is
 * not -1, and leave both entries set to -1.
 */
void net_connect_abort(struct pollfd pfds[2])
{
	struct pollfd *p;
	struct pollfd *const end = pfds + 2;

	for (p = pfds; p != end; p++) {
		if (p->fd == -1)
			continue;
		close_noerr(p->fd);
		p->fd = -1;
	}
}
138
139 int net_connect_complete(struct pollfd pfds[2])
140 {
141         unsigned int i;
142
143         assert(pfds[0].fd != -1 || pfds[1].fd != -1);
144
145         for (i = 0; i < 2; i++) {
146                 int err;
147                 socklen_t errlen = sizeof(err);
148
149                 if (pfds[i].fd == -1)
150                         continue;
151                 if (getsockopt(pfds[i].fd, SOL_SOCKET, SO_ERROR, &err,
152                                &errlen) != 0) {
153                         net_connect_abort(pfds);
154                         return -1;
155                 }
156                 if (err == 0) {
157                         /* Don't hand them non-blocking fd! */
158                         if (!set_nonblock(pfds[i].fd, false)) {
159                                 net_connect_abort(pfds);
160                                 return -1;
161                         }
162                         /* Close other one. */
163                         if (pfds[!i].fd != -1)
164                                 close(pfds[!i].fd);
165                         return pfds[i].fd;
166                 }
167         }
168
169         /* Still going... */
170         errno = EINPROGRESS;
171         return -1;
172 }
173
/*
 * Blocking connect to an address from @addrinfo.
 *
 * Starts the attempts with net_connect_async(), then waits in poll()
 * until net_connect_complete() yields a connected fd or an error other
 * than EINPROGRESS.
 *
 * Returns the connected fd, or -1 with errno set.  If poll() itself
 * fails, every pending attempt is abandoned via net_connect_abort().
 */
int net_connect(const struct addrinfo *addrinfo)
{
	struct pollfd pfds[2];
	int fd = net_connect_async(addrinfo, pfds);

	/* Connected immediately, or failed outright. */
	if (fd >= 0 || errno != EINPROGRESS)
		return fd;

	for (;;) {
		if (poll(pfds, 2, -1) == -1) {
			net_connect_abort(pfds);
			return -1;
		}

		fd = net_connect_complete(pfds);
		if (fd >= 0 || errno != EINPROGRESS)
			return fd;
	}
}
193
/*
 * Look up local (wildcard) addresses on which to serve @service.
 *
 * The node name is NULL and the AI_PASSIVE flag is set.  @family and
 * @socktype are passed through as hints.
 *
 * Returns the getaddrinfo() result list (release with freeaddrinfo()),
 * or NULL if the lookup failed.
 */
struct addrinfo *net_server_lookup(const char *service,
				   int family,
				   int socktype)
{
	struct addrinfo *list;
	const struct addrinfo hints = {
		.ai_family = family,
		.ai_socktype = socktype,
		.ai_flags = AI_PASSIVE,
		.ai_protocol = 0,
	};
	int rc = getaddrinfo(NULL, service, &hints, &list);

	if (rc != 0)
		return NULL;
	return list;
}
211
212 static bool should_listen(const struct addrinfo *addrinfo)
213 {
214 #ifdef SOCK_SEQPACKET
215         if (addrinfo->ai_socktype == SOCK_SEQPACKET)
216                 return true;
217 #endif
218         return (addrinfo->ai_socktype == SOCK_STREAM);
219 }
220
221 static int make_listen_fd(const struct addrinfo *addrinfo)
222 {
223         int saved_errno, fd, on = 1;
224
225         fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
226                     addrinfo->ai_protocol);
227         if (fd < 0)
228                 return -1;
229
230         setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
231         if (bind(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0)
232                 goto fail;
233
234         if (should_listen(addrinfo) && listen(fd, 5) != 0)
235                 goto fail;
236         return fd;
237
238 fail:
239         saved_errno = errno;
240         close(fd);
241         errno = saved_errno;
242         return -1;
243 }
244
245 int net_bind(const struct addrinfo *addrinfo, int fds[2])
246 {
247         const struct addrinfo *ipv6 = NULL;
248         const struct addrinfo *ipv4 = NULL;
249         unsigned int num;
250
251         if (addrinfo->ai_family == AF_INET)
252                 ipv4 = addrinfo;
253         else if (addrinfo->ai_family == AF_INET6)
254                 ipv6 = addrinfo;
255
256         if (addrinfo->ai_next) {
257                 if (addrinfo->ai_next->ai_family == AF_INET)
258                         ipv4 = addrinfo->ai_next;
259                 else if (addrinfo->ai_next->ai_family == AF_INET6)
260                         ipv6 = addrinfo->ai_next;
261         }
262
263         num = 0;
264         /* Take IPv6 first, since it might bind to IPv4 port too. */
265         if (ipv6) {
266                 if ((fds[num] = make_listen_fd(ipv6)) >= 0)
267                         num++;
268                 else
269                         ipv6 = NULL;
270         }
271         if (ipv4) {
272                 if ((fds[num] = make_listen_fd(ipv4)) >= 0)
273                         num++;
274                 else
275                         ipv4 = NULL;
276         }
277         if (num == 0)
278                 return -1;
279
280         return num;
281 }