]> git.ozlabs.org Git - ccan/blob - ccan/net/net.c
net: make tests more robust.
[ccan] / ccan / net / net.c
1 /* Licensed under BSD-MIT - see LICENSE file for details */
2 #include <ccan/net/net.h>
3 #include <sys/types.h>
4 #include <sys/socket.h>
5 #include <poll.h>
6 #include <netdb.h>
7 #include <string.h>
8 #include <stdlib.h>
9 #include <unistd.h>
10 #include <fcntl.h>
11 #include <errno.h>
12 #include <stdbool.h>
13 #include <netinet/in.h>
14
15 struct addrinfo *net_client_lookup(const char *hostname,
16                                    const char *service,
17                                    int family,
18                                    int socktype)
19 {
20         struct addrinfo hints;
21         struct addrinfo *res;
22
23         memset(&hints, 0, sizeof(hints));
24         hints.ai_family = family;
25         hints.ai_socktype = socktype;
26         hints.ai_flags = 0;
27         hints.ai_protocol = 0;
28
29         if (getaddrinfo(hostname, service, &hints, &res) != 0)
30                 return NULL;
31
32         return res;
33 }
34
35 static bool set_nonblock(int fd, bool nonblock)
36 {
37         long flags;
38
39         flags = fcntl(fd, F_GETFL);
40         if (flags == -1)
41                 return false;
42
43         if (nonblock)
44                 flags |= O_NONBLOCK;
45         else
46                 flags &= ~(long)O_NONBLOCK;
47
48         return (fcntl(fd, F_SETFL, flags) == 0);
49 }
50
51 /* We only handle IPv4 and IPv6 */
52 #define MAX_PROTOS 2
53
54 static void remove_fd(struct pollfd pfd[],
55                       const struct addrinfo *addr[],
56                       socklen_t slen[],
57                       unsigned int *num,
58                       unsigned int i)
59 {
60         memmove(pfd + i, pfd + i + 1, (*num - i - 1) * sizeof(pfd[0]));
61         memmove(addr + i, addr + i + 1, (*num - i - 1) * sizeof(addr[0]));
62         memmove(slen + i, slen + i + 1, (*num - i - 1) * sizeof(slen[0]));
63         (*num)--;
64 }
65
66 int net_connect(const struct addrinfo *addrinfo)
67 {
68         int sockfd = -1, saved_errno;
69         unsigned int i, num;
70         const struct addrinfo *ipv4 = NULL, *ipv6 = NULL;
71         const struct addrinfo *addr[MAX_PROTOS];
72         socklen_t slen[MAX_PROTOS];
73         struct pollfd pfd[MAX_PROTOS];
74
75         for (; addrinfo; addrinfo = addrinfo->ai_next) {
76                 switch (addrinfo->ai_family) {
77                 case AF_INET:
78                         if (!ipv4)
79                                 ipv4 = addrinfo;
80                         break;
81                 case AF_INET6:
82                         if (!ipv6)
83                                 ipv6 = addrinfo;
84                         break;
85                 }
86         }
87
88         num = 0;
89         /* We give IPv6 a slight edge by connecting it first. */
90         if (ipv6) {
91                 addr[num] = ipv6;
92                 slen[num] = sizeof(struct sockaddr_in6);
93                 pfd[num].fd = socket(AF_INET6, ipv6->ai_socktype,
94                                      ipv6->ai_protocol);
95                 if (pfd[num].fd != -1)
96                         num++;
97         }
98         if (ipv4) {
99                 addr[num] = ipv4;
100                 slen[num] = sizeof(struct sockaddr_in);
101                 pfd[num].fd = socket(AF_INET, ipv4->ai_socktype,
102                                      ipv4->ai_protocol);
103                 if (pfd[num].fd != -1)
104                         num++;
105         }
106
107         for (i = 0; i < num; i++) {
108                 if (!set_nonblock(pfd[i].fd, true)) {
109                         remove_fd(pfd, addr, slen, &num, i--);
110                         continue;
111                 }
112                 /* Connect *can* be instant. */
113                 if (connect(pfd[i].fd, addr[i]->ai_addr, slen[i]) == 0)
114                         goto got_one;
115                 if (errno != EINPROGRESS) {
116                         /* Remove dead one. */
117                         remove_fd(pfd, addr, slen, &num, i--);
118                 }
119                 pfd[i].events = POLLOUT;
120         }
121
122         while (num && poll(pfd, num, -1) != -1) {
123                 for (i = 0; i < num; i++) {
124                         int err;
125                         socklen_t errlen = sizeof(err);
126                         if (!pfd[i].revents)
127                                 continue;
128                         if (getsockopt(pfd[i].fd, SOL_SOCKET, SO_ERROR, &err,
129                                        &errlen) != 0)
130                                 goto out;
131                         if (err == 0)
132                                 goto got_one;
133
134                         /* Remove dead one. */
135                         errno = err;
136                         remove_fd(pfd, addr, slen, &num, i--);
137                 }
138         }
139
140 got_one:
141         /* We don't want to hand them a non-blocking socket! */
142         if (set_nonblock(pfd[i].fd, false))
143                 sockfd = pfd[i].fd;
144
145 out:
146         saved_errno = errno;
147         for (i = 0; i < num; i++)
148                 if (pfd[i].fd != sockfd)
149                         close(pfd[i].fd);
150         errno = saved_errno;
151         return sockfd;
152 }