net: fix ipv4 immediate connect
[ccan] / ccan / net / net.c
1 /* Licensed under BSD-MIT - see LICENSE file for details */
2 #include <ccan/net/net.h>
3 #include <ccan/noerr/noerr.h>
4 #include <sys/types.h>
5 #include <sys/socket.h>
6 #include <poll.h>
7 #include <netdb.h>
8 #include <string.h>
9 #include <stdlib.h>
10 #include <unistd.h>
11 #include <fcntl.h>
12 #include <errno.h>
13 #include <stdbool.h>
14 #include <netinet/in.h>
15 #include <assert.h>
16
/*
 * Resolve hostname/service for an outgoing connection.
 *
 * family and socktype are handed to getaddrinfo() as hints; no flags and no
 * particular protocol are requested.  Returns the getaddrinfo() result list
 * (to be released with freeaddrinfo()), or NULL if the lookup failed.
 */
struct addrinfo *net_client_lookup(const char *hostname,
				   const char *service,
				   int family,
				   int socktype)
{
	struct addrinfo *result = NULL;
	const struct addrinfo hints = {
		.ai_flags = 0,
		.ai_family = family,
		.ai_socktype = socktype,
		.ai_protocol = 0,
	};

	return getaddrinfo(hostname, service, &hints, &result) == 0 ? result
								      : NULL;
}
36
/*
 * Set (nonblock == true) or clear O_NONBLOCK on fd, leaving its other file
 * status flags untouched.
 *
 * Returns true on success, false (errno set by fcntl()) on failure.
 */
static bool set_nonblock(int fd, bool nonblock)
{
	/* F_GETFL returns, and F_SETFL expects, an int: passing a long through
	 * fcntl()'s varargs does not match the type fcntl() reads back. */
	int flags = fcntl(fd, F_GETFL);

	if (flags == -1)
		return false;

	if (nonblock)
		flags |= O_NONBLOCK;
	else
		flags &= ~O_NONBLOCK;

	/* POSIX only promises "a value other than -1" for F_SETFL success. */
	return fcntl(fd, F_SETFL, flags) != -1;
}
52
53 static int start_connect(const struct addrinfo *addr, bool *immediate)
54 {
55         int fd;
56
57         *immediate = false;
58
59         fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
60         if (fd == -1)
61                 return fd;
62
63         if (!set_nonblock(fd, true))
64                 goto close;
65
66         if (connect(fd, addr->ai_addr, addr->ai_addrlen) == 0) {
67                 /* Immediate connect. */
68                 *immediate = true;
69                 return fd;
70         }
71
72         if (errno == EINPROGRESS)
73                 return fd;
74
75 close:
76         close_noerr(fd);
77         return -1;
78 }
79
80
81 int net_connect_async(const struct addrinfo *addrinfo, struct pollfd pfds[2])
82 {
83         const struct addrinfo *addr[2] = { NULL, NULL };
84         unsigned int i;
85
86         pfds[0].fd = pfds[1].fd = -1;
87         pfds[0].events = pfds[1].events = POLLOUT;
88
89         /* Give IPv6 a slight advantage, by trying it first. */
90         for (; addrinfo; addrinfo = addrinfo->ai_next) {
91                 switch (addrinfo->ai_family) {
92                 case AF_INET:
93                         addr[1] = addrinfo;
94                         break;
95                 case AF_INET6:
96                         addr[0] = addrinfo;
97                         break;
98                 default:
99                         continue;
100                 }
101         }
102
103         /* In case we found nothing. */
104         errno = ENOENT;
105         for (i = 0; i < 2; i++) {
106                 bool immediate;
107
108                 if (!addr[i])
109                         continue;
110
111                 pfds[i].fd = start_connect(addr[i], &immediate);
112                 if (immediate) {
113                         if (pfds[!i].fd != -1)
114                                 close(pfds[!i].fd);
115                         if (!set_nonblock(pfds[i].fd, false)) {
116                                 close_noerr(pfds[i].fd);
117                                 return -1;
118                         }
119                         return pfds[i].fd;
120                 }
121         }
122
123         if (pfds[0].fd != -1 || pfds[1].fd != -1)
124                 errno = EINPROGRESS;
125         return -1;
126 }
127
/*
 * Abandon a connection started with net_connect_async(): close every socket
 * still held in pfds (via close_noerr()) and mark both slots unused (-1).
 */
void net_connect_abort(struct pollfd pfds[2])
{
	struct pollfd *slot;

	for (slot = pfds; slot < pfds + 2; slot++) {
		if (slot->fd != -1)
			close_noerr(slot->fd);
		slot->fd = -1;
	}
}
138
139 int net_connect_complete(struct pollfd pfds[2])
140 {
141         unsigned int i;
142
143         assert(pfds[0].fd != -1 || pfds[1].fd != -1);
144
145         for (i = 0; i < 2; i++) {
146                 int err;
147                 socklen_t errlen = sizeof(err);
148
149                 if (pfds[i].fd == -1)
150                         continue;
151                 if (pfds[i].revents & POLLHUP) {
152                         /* Linux gives this if connecting to local
153                          * non-listening port */
154                         close(pfds[i].fd);
155                         pfds[i].fd = -1;
156                         if (pfds[!i].fd == -1) {
157                                 errno = ECONNREFUSED;
158                                 return -1;
159                         }
160                         continue;
161                 }
162                 if (getsockopt(pfds[i].fd, SOL_SOCKET, SO_ERROR, &err,
163                                &errlen) != 0) {
164                         net_connect_abort(pfds);
165                         return -1;
166                 }
167                 if (err == 0) {
168                         /* Don't hand them non-blocking fd! */
169                         if (!set_nonblock(pfds[i].fd, false)) {
170                                 net_connect_abort(pfds);
171                                 return -1;
172                         }
173                         /* Close other one. */
174                         if (pfds[!i].fd != -1)
175                                 close(pfds[!i].fd);
176                         return pfds[i].fd;
177                 }
178         }
179
180         /* Still going... */
181         errno = EINPROGRESS;
182         return -1;
183 }
184
/*
 * Blocking connect to the addresses in addrinfo.
 *
 * Starts the attempts with net_connect_async(), then waits in poll() and
 * calls net_connect_complete() until one attempt succeeds or all fail.
 * Returns a connected, blocking socket, or -1 with errno set.
 */
int net_connect(const struct addrinfo *addrinfo)
{
	struct pollfd pfds[2];
	int sockfd;

	sockfd = net_connect_async(addrinfo, pfds);
	/* Immediate connect or error is easy. */
	if (sockfd >= 0 || errno != EINPROGRESS)
		return sockfd;

	for (;;) {
		if (poll(pfds, 2, -1) == -1) {
			/* A signal interrupting the wait is not a connect
			 * failure: keep waiting. */
			if (errno == EINTR)
				continue;
			break;
		}
		sockfd = net_connect_complete(pfds);
		if (sockfd >= 0 || errno != EINPROGRESS)
			return sockfd;
	}

	net_connect_abort(pfds);
	return -1;
}
204
/*
 * Resolve service for a listening socket: getaddrinfo() is called with a
 * NULL host and AI_PASSIVE, plus family and socktype as hints.
 *
 * Returns the result list (to be released with freeaddrinfo()), or NULL if
 * the lookup failed.
 */
struct addrinfo *net_server_lookup(const char *service,
				   int family,
				   int socktype)
{
	struct addrinfo *list;
	struct addrinfo hints = { 0 };
	int rc;

	hints.ai_flags = AI_PASSIVE;
	hints.ai_family = family;
	hints.ai_socktype = socktype;
	hints.ai_protocol = 0;

	rc = getaddrinfo(NULL, service, &hints, &list);
	if (rc != 0)
		return NULL;
	return list;
}
222
/*
 * Report whether a socket of addrinfo's type needs listen() after bind():
 * true for SOCK_STREAM and, where the platform defines it, SOCK_SEQPACKET;
 * false for every other type.
 */
static bool should_listen(const struct addrinfo *addrinfo)
{
	switch (addrinfo->ai_socktype) {
	case SOCK_STREAM:
#ifdef SOCK_SEQPACKET
	case SOCK_SEQPACKET:
#endif
		return true;
	default:
		return false;
	}
}
231
232 static int make_listen_fd(const struct addrinfo *addrinfo)
233 {
234         int saved_errno, fd, on = 1;
235
236         fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
237                     addrinfo->ai_protocol);
238         if (fd < 0)
239                 return -1;
240
241         setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
242         if (bind(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0)
243                 goto fail;
244
245         if (should_listen(addrinfo) && listen(fd, 5) != 0)
246                 goto fail;
247         return fd;
248
249 fail:
250         saved_errno = errno;
251         close(fd);
252         errno = saved_errno;
253         return -1;
254 }
255
256 int net_bind(const struct addrinfo *addrinfo, int fds[2])
257 {
258         const struct addrinfo *ipv6 = NULL;
259         const struct addrinfo *ipv4 = NULL;
260         unsigned int num;
261
262         if (addrinfo->ai_family == AF_INET)
263                 ipv4 = addrinfo;
264         else if (addrinfo->ai_family == AF_INET6)
265                 ipv6 = addrinfo;
266
267         if (addrinfo->ai_next) {
268                 if (addrinfo->ai_next->ai_family == AF_INET)
269                         ipv4 = addrinfo->ai_next;
270                 else if (addrinfo->ai_next->ai_family == AF_INET6)
271                         ipv6 = addrinfo->ai_next;
272         }
273
274         num = 0;
275         /* Take IPv6 first, since it might bind to IPv4 port too. */
276         if (ipv6) {
277                 if ((fds[num] = make_listen_fd(ipv6)) >= 0)
278                         num++;
279                 else
280                         ipv6 = NULL;
281         }
282         if (ipv4) {
283                 if ((fds[num] = make_listen_fd(ipv4)) >= 0)
284                         num++;
285                 else
286                         ipv4 = NULL;
287         }
288         if (num == 0)
289                 return -1;
290
291         return num;
292 }