]> git.ozlabs.org Git - ccan/blob - ccan/nfs/socket.c
talloc: spelling fix.
[ccan] / ccan / nfs / socket.c
1 /*
2    Copyright (C) by Ronnie Sahlberg <ronniesahlberg@gmail.com> 2010
3    
4    This program is free software; you can redistribute it and/or modify
5    it under the terms of the GNU General Public License as published by
6    the Free Software Foundation; either version 3 of the License, or
7    (at your option) any later version.
8    
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13    
14    You should have received a copy of the GNU General Public License
15    along with this program; if not, see <http://www.gnu.org/licenses/>.
16 */
17
18 #include <stdio.h>
19 #include <unistd.h>
20 #include <fcntl.h>
21 #include <poll.h>
22 #include <string.h>
23 #include <errno.h>
24 #include <time.h>
25 #include <rpc/xdr.h>
26 #include <arpa/inet.h>
27 #include <sys/ioctl.h>
28 #include "nfs.h"
29 #include "libnfs-raw.h"
30 #include "libnfs-private.h"
31 #include "dlinklist.h"
32
/*
 * Switch a descriptor to non-blocking mode while keeping its other
 * file status flags.
 *
 * This is best effort: if the current flags cannot be read the
 * descriptor is left untouched, and a failure to install the new
 * flags is not reported to the caller.
 */
static void set_nonblocking(int fd)
{
        int flags = fcntl(fd, F_GETFL, 0);

        if (flags == -1) {
                /* never OR O_NONBLOCK into the -1 error value */
                return;
        }
        (void)fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}
39
40 int rpc_get_fd(struct rpc_context *rpc)
41 {
42         return rpc->fd;
43 }
44
45 int rpc_which_events(struct rpc_context *rpc)
46 {
47         int events = POLLIN;
48
49         if (rpc->is_connected == 0) {
50                 events |= POLLOUT;
51         }
52
53         if (rpc->outqueue) {
54                 events |= POLLOUT;
55         }
56         return events;
57 }
58
59 static int rpc_write_to_socket(struct rpc_context *rpc)
60 {
61         ssize_t count;
62
63         if (rpc == NULL) {
64                 printf("trying to write to socket for NULL context\n");
65                 return -1;
66         }
67         if (rpc->fd == -1) {
68                 printf("trying to write but not connected\n");
69                 return -2;
70         }
71
72         while (rpc->outqueue != NULL) {
73                 ssize_t total;
74
75                 total = rpc->outqueue->outdata.size;
76
77                 count = write(rpc->fd, rpc->outqueue->outdata.data + rpc->outqueue->written, total - rpc->outqueue->written);
78                 if (count == -1) {
79                         if (errno == EAGAIN || errno == EWOULDBLOCK) {
80                                 printf("socket would block, return from write to socket\n");
81                                 return 0;
82                         }
83                         printf("Error when writing to socket :%s(%d)\n", strerror(errno), errno);
84                         return -3;
85                 }
86
87                 rpc->outqueue->written += count;
88                 if (rpc->outqueue->written == total) {
89                         struct rpc_pdu *pdu = rpc->outqueue;
90
91                         DLIST_REMOVE(rpc->outqueue, pdu);
92                         DLIST_ADD_END(rpc->waitpdu, pdu, NULL);
93                 }
94         }
95         return 0;
96 }
97
98 static int rpc_read_from_socket(struct rpc_context *rpc)
99 {
100         int available;
101         int size;
102         unsigned char *buf;
103         ssize_t count;
104
105         if (ioctl(rpc->fd, FIONREAD, &available) != 0) {
106                 rpc_set_error(rpc, "Ioctl FIONREAD returned error : %d. Closing socket.", errno);
107                 return -1;
108         }
109         if (available == 0) {
110                 rpc_set_error(rpc, "Socket has been closed");
111                 return -2;
112         }
113         size = rpc->insize - rpc->inpos + available;
114         buf = malloc(size);
115         if (buf == NULL) {
116                 rpc_set_error(rpc, "Out of memory: failed to allocate %d bytes for input buffer. Closing socket.", size);
117                 return -3;
118         }
119         if (rpc->insize > rpc->inpos) {
120                 memcpy(buf, rpc->inbuf + rpc->inpos, rpc->insize - rpc->inpos);
121                 rpc->insize -= rpc->inpos;
122                 rpc->inpos   = 0;
123         }
124
125         count = read(rpc->fd, buf + rpc->insize, available);
126         if (count == -1) {
127                 if (errno == EINTR) {
128                         free(buf);
129                         buf = NULL;
130                         return 0;
131                 }
132                 rpc_set_error(rpc, "Read from socket failed, errno:%d. Closing socket.", errno);
133                 free(buf);
134                 buf = NULL;
135                 return -4;
136         }
137
138         if (rpc->inbuf != NULL) {
139                 free(rpc->inbuf);
140         }
141         rpc->inbuf   = (char *)buf;
142         rpc->insize += count;
143
144         while (1) {
145                 if (rpc->insize - rpc->inpos < 4) {
146                         return 0;
147                 }
148                 count = rpc_get_pdu_size(rpc->inbuf + rpc->inpos);
149                 if (rpc->insize + rpc->inpos < count) {
150                         return 0;
151                 }
152                 if (rpc_process_pdu(rpc, rpc->inbuf + rpc->inpos, count) != 0) {
153                         rpc_set_error(rpc, "Invalid/garbage pdu received from server. Closing socket");
154                         return -5;
155                 }
156                 rpc->inpos += count;
157                 if (rpc->inpos == rpc->insize) {
158                         free(rpc->inbuf);
159                         rpc->inbuf = NULL;
160                         rpc->insize = 0;
161                         rpc->inpos = 0;
162                 }
163         }
164         return 0;
165 }
166
167
168
169 int rpc_service(struct rpc_context *rpc, int revents)
170 {
171         if (revents & POLLERR) {
172                 printf("rpc_service: POLLERR, socket error\n");
173                 if (rpc->is_connected == 0) {
174                         rpc_set_error(rpc, "Failed to connect to server socket.");
175                 } else {
176                         rpc_set_error(rpc, "Socket closed with POLLERR");
177                 }
178                 rpc->connect_cb(rpc, RPC_STATUS_ERROR, rpc->error_string, rpc->connect_data);
179                 return -1;
180         }
181         if (revents & POLLHUP) {
182                 printf("rpc_service: POLLHUP, socket error\n");
183                 rpc_set_error(rpc, "Socket failed with POLLHUP");
184                 rpc->connect_cb(rpc, RPC_STATUS_ERROR, rpc->error_string, rpc->connect_data);
185                 return -2;
186         }
187
188         if (rpc->is_connected == 0 && rpc->fd != -1 && revents&POLLOUT) {
189                 rpc->is_connected = 1;
190                 rpc->connect_cb(rpc, RPC_STATUS_SUCCESS, NULL, rpc->connect_data);
191                 return 0;
192         }
193
194         if (revents & POLLOUT && rpc->outqueue != NULL) {
195                 if (rpc_write_to_socket(rpc) != 0) {
196                         printf("write to socket failed\n");
197                         return -3;
198                 }
199         }
200
201         if (revents & POLLIN) {
202                 if (rpc_read_from_socket(rpc) != 0) {
203                         rpc_disconnect(rpc, rpc_get_error(rpc));
204                         return -4;
205                 }
206         }
207
208         return 0;
209 }
210
211
212 int rpc_connect_async(struct rpc_context *rpc, const char *server, int port, int use_privileged_port, rpc_cb cb, void *private_data)
213 {
214         struct sockaddr_storage s;
215         struct sockaddr_in *sin = (struct sockaddr_in *)&s;
216         int socksize;
217
218         if (rpc->fd != -1) {
219                 rpc_set_error(rpc, "Trying to connect while already connected");
220                 printf("%s\n", rpc->error_string);
221                 return -1;
222         }
223
224         sin->sin_family = AF_INET;
225         sin->sin_port   = htons(port);
226         if (inet_pton(AF_INET, server, &sin->sin_addr) != 1) {
227                 rpc_set_error(rpc, "Not a valid server ip address");
228                 printf("%s\n", rpc->error_string);
229                 return -2;
230         }
231
232         switch (s.ss_family) {
233         case AF_INET:
234                 rpc->fd = socket(AF_INET, SOCK_STREAM, 0);
235                 socksize = sizeof(struct sockaddr_in);
236                 break;
237         }
238
239         if (rpc->fd == -1) {
240                 rpc_set_error(rpc, "Failed to open socket");
241                 printf("%s\n", rpc->error_string);
242                 return -3;
243         }
244
245         /* if we are root, try to find a privileged port to use (512 - 1023) */
246         if (geteuid() == 0 && use_privileged_port != 0) {
247                 struct sockaddr_storage ls;
248                 int ret, count = 0;
249                 static int local_port = 0;
250
251                 if (local_port == 0) {
252                         srandom(getpid() ^ time(NULL));
253                         local_port = random()%512 + 512;
254                 }
255
256                 do {
257                         count ++;
258                         if (local_port >= 1024) {
259                                 local_port = 512;
260                         }
261                         switch (s.ss_family) {
262                         case AF_INET:
263                                 bzero(&ls, socksize);
264                                 ((struct sockaddr_in *)&ls)->sin_family      = AF_INET;
265                                 ((struct sockaddr_in *)&ls)->sin_addr.s_addr = INADDR_ANY;
266                                 ((struct sockaddr_in *)&ls)->sin_port        = htons(local_port++);
267                                 break;
268                         }
269
270                         ret  = bind(rpc->fd, (struct sockaddr *)&ls, socksize);
271                 } while (ret != 0  && count < 50);
272         }
273
274         rpc->connect_cb  = cb;
275         rpc->connect_data = private_data;
276
277         set_nonblocking(rpc->fd);
278
279         if (connect(rpc->fd, (struct sockaddr *)&s, socksize) != 0 && errno != EINPROGRESS) {
280                 rpc_set_error(rpc, "connect() to server failed");
281                 printf("%s\n", rpc->error_string);
282                 return -4;
283         }
284
285         return 0;
286 }
287
288 int rpc_disconnect(struct rpc_context *rpc, char *error)
289 {
290         if (rpc->fd != -1) {
291                 close(rpc->fd);
292         }
293         rpc->fd  = -1;
294
295         rpc->is_connected = 0;
296
297         rpc_error_all_pdus(rpc, error);
298
299         return 0;
300 }