net: add server support.
[ccan] / ccan / net / net.c
1 /* Licensed under BSD-MIT - see LICENSE file for details */
2 #include <ccan/net/net.h>
3 #include <sys/types.h>
4 #include <sys/socket.h>
5 #include <poll.h>
6 #include <netdb.h>
7 #include <string.h>
8 #include <stdlib.h>
9 #include <unistd.h>
10 #include <fcntl.h>
11 #include <errno.h>
12 #include <stdbool.h>
13 #include <netinet/in.h>
14
15 struct addrinfo *net_client_lookup(const char *hostname,
16                                    const char *service,
17                                    int family,
18                                    int socktype)
19 {
20         struct addrinfo hints;
21         struct addrinfo *res;
22
23         memset(&hints, 0, sizeof(hints));
24         hints.ai_family = family;
25         hints.ai_socktype = socktype;
26         hints.ai_flags = 0;
27         hints.ai_protocol = 0;
28
29         if (getaddrinfo(hostname, service, &hints, &res) != 0)
30                 return NULL;
31
32         return res;
33 }
34
35 static bool set_nonblock(int fd, bool nonblock)
36 {
37         long flags;
38
39         flags = fcntl(fd, F_GETFL);
40         if (flags == -1)
41                 return false;
42
43         if (nonblock)
44                 flags |= O_NONBLOCK;
45         else
46                 flags &= ~(long)O_NONBLOCK;
47
48         return (fcntl(fd, F_SETFL, flags) == 0);
49 }
50
51 /* We only handle IPv4 and IPv6 */
52 #define MAX_PROTOS 2
53
54 static void remove_fd(struct pollfd pfd[],
55                       const struct addrinfo *addr[],
56                       socklen_t slen[],
57                       unsigned int *num,
58                       unsigned int i)
59 {
60         memmove(pfd + i, pfd + i + 1, (*num - i - 1) * sizeof(pfd[0]));
61         memmove(addr + i, addr + i + 1, (*num - i - 1) * sizeof(addr[0]));
62         memmove(slen + i, slen + i + 1, (*num - i - 1) * sizeof(slen[0]));
63         (*num)--;
64 }
65
66 int net_connect(const struct addrinfo *addrinfo)
67 {
68         int sockfd = -1, saved_errno;
69         unsigned int i, num;
70         const struct addrinfo *ipv4 = NULL, *ipv6 = NULL;
71         const struct addrinfo *addr[MAX_PROTOS];
72         socklen_t slen[MAX_PROTOS];
73         struct pollfd pfd[MAX_PROTOS];
74
75         for (; addrinfo; addrinfo = addrinfo->ai_next) {
76                 switch (addrinfo->ai_family) {
77                 case AF_INET:
78                         if (!ipv4)
79                                 ipv4 = addrinfo;
80                         break;
81                 case AF_INET6:
82                         if (!ipv6)
83                                 ipv6 = addrinfo;
84                         break;
85                 }
86         }
87
88         num = 0;
89         /* We give IPv6 a slight edge by connecting it first. */
90         if (ipv6) {
91                 addr[num] = ipv6;
92                 slen[num] = sizeof(struct sockaddr_in6);
93                 pfd[num].fd = socket(AF_INET6, ipv6->ai_socktype,
94                                      ipv6->ai_protocol);
95                 if (pfd[num].fd != -1)
96                         num++;
97         }
98         if (ipv4) {
99                 addr[num] = ipv4;
100                 slen[num] = sizeof(struct sockaddr_in);
101                 pfd[num].fd = socket(AF_INET, ipv4->ai_socktype,
102                                      ipv4->ai_protocol);
103                 if (pfd[num].fd != -1)
104                         num++;
105         }
106
107         for (i = 0; i < num; i++) {
108                 if (!set_nonblock(pfd[i].fd, true)) {
109                         remove_fd(pfd, addr, slen, &num, i--);
110                         continue;
111                 }
112                 /* Connect *can* be instant. */
113                 if (connect(pfd[i].fd, addr[i]->ai_addr, slen[i]) == 0)
114                         goto got_one;
115                 if (errno != EINPROGRESS) {
116                         /* Remove dead one. */
117                         remove_fd(pfd, addr, slen, &num, i--);
118                 }
119                 pfd[i].events = POLLOUT;
120         }
121
122         while (num && poll(pfd, num, -1) != -1) {
123                 for (i = 0; i < num; i++) {
124                         int err;
125                         socklen_t errlen = sizeof(err);
126                         if (!pfd[i].revents)
127                                 continue;
128                         if (getsockopt(pfd[i].fd, SOL_SOCKET, SO_ERROR, &err,
129                                        &errlen) != 0)
130                                 goto out;
131                         if (err == 0)
132                                 goto got_one;
133
134                         /* Remove dead one. */
135                         errno = err;
136                         remove_fd(pfd, addr, slen, &num, i--);
137                 }
138         }
139
140 got_one:
141         /* We don't want to hand them a non-blocking socket! */
142         if (set_nonblock(pfd[i].fd, false))
143                 sockfd = pfd[i].fd;
144
145 out:
146         saved_errno = errno;
147         for (i = 0; i < num; i++)
148                 if (pfd[i].fd != sockfd)
149                         close(pfd[i].fd);
150         errno = saved_errno;
151         return sockfd;
152 }
153
154 struct addrinfo *net_server_lookup(const char *service,
155                                    int family,
156                                    int socktype)
157 {
158         struct addrinfo *res, hints;
159
160         memset(&hints, 0, sizeof(hints));
161         hints.ai_family = family;
162         hints.ai_socktype = socktype;
163         hints.ai_flags = AI_PASSIVE;
164         hints.ai_protocol = 0;
165
166         if (getaddrinfo(NULL, service, &hints, &res) != 0)
167                 return NULL;
168
169         return res;
170 }
171
172 static bool should_listen(const struct addrinfo *addrinfo)
173 {
174 #ifdef SOCK_SEQPACKET
175         if (addrinfo->ai_socktype == SOCK_SEQPACKET)
176                 return true;
177 #endif
178         return (addrinfo->ai_socktype == SOCK_STREAM);
179 }
180
181 static int make_listen_fd(const struct addrinfo *addrinfo)
182 {
183         int saved_errno, fd, on = 1;
184
185         fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
186                     addrinfo->ai_protocol);
187         if (fd < 0)
188                 return -1;
189
190         setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
191         if (bind(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0)
192                 goto fail;
193
194         if (should_listen(addrinfo) && listen(fd, 5) != 0)
195                 goto fail;
196         return fd;
197
198 fail:
199         saved_errno = errno;
200         close(fd);
201         errno = saved_errno;
202         return -1;
203 }
204
205 int net_bind(const struct addrinfo *addrinfo, int fds[2])
206 {
207         const struct addrinfo *ipv6, *ipv4;
208         unsigned int num;
209
210         if (addrinfo->ai_family == AF_INET)
211                 ipv4 = addrinfo;
212         else if (addrinfo->ai_family == AF_INET6)
213                 ipv6 = addrinfo;
214
215         if (addrinfo->ai_next) {
216                 if (addrinfo->ai_next->ai_family == AF_INET)
217                         ipv4 = addrinfo->ai_next;
218                 else if (addrinfo->ai_next->ai_family == AF_INET6)
219                         ipv6 = addrinfo->ai_next;
220         }
221
222         num = 0;
223         /* Take IPv6 first, since it might bind to IPv4 port too. */
224         if (ipv6) {
225                 if ((fds[num] = make_listen_fd(ipv6)) >= 0)
226                         num++;
227                 else
228                         ipv6 = NULL;
229         }
230         if (ipv4) {
231                 if ((fds[num] = make_listen_fd(ipv4)) >= 0)
232                         num++;
233                 else
234                         ipv4 = NULL;
235         }
236         if (num == 0)
237                 return -1;
238
239         return num;
240 }