io: io_set_alloc()
authorRusty Russell <rusty@rustcorp.com.au>
Mon, 21 Oct 2013 05:10:02 +0000 (15:40 +1030)
committerRusty Russell <rusty@rustcorp.com.au>
Mon, 21 Oct 2013 05:10:02 +0000 (15:40 +1030)
Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
ccan/io/backend.h
ccan/io/io.c
ccan/io/io.h
ccan/io/poll.c
ccan/io/test/run-set_alloc.c [new file with mode: 0644]

index 30a338f77a73958eaf83acad62eb87b027d27025..77d51dda9bf1825a6e02a6be85b3a0f3184c5804 100644 (file)
@@ -4,6 +4,13 @@
 #include <stdbool.h>
 #include <ccan/timer/timer.h>
 
+struct io_alloc {
+       void *(*alloc)(size_t size);
+       void *(*realloc)(void *ptr, size_t size);
+       void (*free)(void *ptr);
+};
+extern struct io_alloc io_alloc;
+
 struct fd {
        int fd;
        bool listener;
index 34b7cb0ce7ca6e739246a33b3a457aafb3e5a5b8..e2dd44430b79d0f3a0977a92861cb72a86f56df6 100644 (file)
 
 void *io_loop_return;
 
+struct io_alloc io_alloc = {
+       malloc, realloc, free
+};
+
 #ifdef DEBUG
 /* Set to skip the next plan. */
 bool io_plan_nodebug;
@@ -125,7 +129,7 @@ struct io_listener *io_new_listener_(int fd,
                                     void (*init)(int fd, void *arg),
                                     void *arg)
 {
-       struct io_listener *l = malloc(sizeof(*l));
+       struct io_listener *l = io_alloc.alloc(sizeof(*l));
 
        if (!l)
                return NULL;
@@ -135,7 +139,7 @@ struct io_listener *io_new_listener_(int fd,
        l->init = init;
        l->arg = arg;
        if (!add_listener(l)) {
-               free(l);
+               io_alloc.free(l);
                return NULL;
        }
        return l;
@@ -145,12 +149,12 @@ void io_close_listener(struct io_listener *l)
 {
        close(l->fd.fd);
        del_listener(l);
-       free(l);
+       io_alloc.free(l);
 }
 
 struct io_conn *io_new_conn_(int fd, struct io_plan plan)
 {
-       struct io_conn *conn = malloc(sizeof(*conn));
+       struct io_conn *conn = io_alloc.alloc(sizeof(*conn));
 
        io_plan_debug_again();
 
@@ -165,7 +169,7 @@ struct io_conn *io_new_conn_(int fd, struct io_plan plan)
        conn->duplex = NULL;
        conn->timeout = NULL;
        if (!add_conn(conn)) {
-               free(conn);
+               io_alloc.free(conn);
                return NULL;
        }
        return conn;
@@ -187,7 +191,7 @@ struct io_conn *io_duplex_(struct io_conn *old, struct io_plan plan)
 
        assert(!old->duplex);
 
-       conn = malloc(sizeof(*conn));
+       conn = io_alloc.alloc(sizeof(*conn));
        if (!conn)
                return NULL;
 
@@ -199,7 +203,7 @@ struct io_conn *io_duplex_(struct io_conn *old, struct io_plan plan)
        conn->finish_arg = NULL;
        conn->timeout = NULL;
        if (!add_duplex(conn)) {
-               free(conn);
+               io_alloc.free(conn);
                return NULL;
        }
        old->duplex = conn;
@@ -212,7 +216,7 @@ bool io_timeout_(struct io_conn *conn, struct timespec ts,
        assert(cb);
 
        if (!conn->timeout) {
-               conn->timeout = malloc(sizeof(*conn->timeout));
+               conn->timeout = io_alloc.alloc(sizeof(*conn->timeout));
                if (!conn->timeout)
                        return false;
        } else
@@ -467,3 +471,12 @@ struct io_plan io_break_(void *ret, struct io_plan plan)
 
        return plan;
 }
+
+void io_set_alloc(void *(*allocfn)(size_t size),
+                 void *(*reallocfn)(void *ptr, size_t size),
+                 void (*freefn)(void *ptr))
+{
+       io_alloc.alloc = allocfn;
+       io_alloc.realloc = reallocfn;
+       io_alloc.free = freefn;
+}
index b5ffdd243bfd8a1710e551741261bb9159923cff..067a69c1dec441a752508dfe50ee8f83b8421466 100644 (file)
@@ -490,4 +490,17 @@ struct io_plan io_close_cb(struct io_conn *, void *unused);
  *     io_loop();
  */
 void *io_loop(void);
+
+/**
+ * io_set_alloc - set alloc/realloc/free function for io to use.
+ * @allocfn: allocator function
+ * @reallocfn: reallocator function, ptr may be NULL, size never 0.
+ * @freefn: free function
+ *
+ * By default io uses malloc/realloc/free, and returns NULL if they fail.
+ * You can set your own variants here.
+ */
+void io_set_alloc(void *(*allocfn)(size_t size),
+                 void *(*reallocfn)(void *ptr, size_t size),
+                 void (*freefn)(void *ptr));
 #endif /* CCAN_IO_H */
index f15644025cf3fdb756e2e1e20bb53935f02f7b64..0078fc658411c1d50e01a5c3e9eed278f5214a83 100644 (file)
@@ -28,7 +28,7 @@ static void io_loop_exit(void)
                while (free_later) {
                        struct io_conn *c = free_later;
                        free_later = c->finish_arg;
-                       free(c);
+                       io_alloc.free(c);
                }
        }
 }
@@ -42,7 +42,7 @@ static void free_conn(struct io_conn *conn)
                conn->finish_arg = free_later;
                free_later = conn;
        } else
-               free(conn);
+               io_alloc.free(conn);
 }
 #else
 static void io_loop_enter(void)
@@ -53,7 +53,7 @@ static void io_loop_exit(void)
 }
 static void free_conn(struct io_conn *conn)
 {
-       free(conn);
+       io_alloc.free(conn);
 }
 #endif
 
@@ -64,11 +64,11 @@ static bool add_fd(struct fd *fd, short events)
                struct fd **newfds;
                size_t num = max_fds ? max_fds * 2 : 8;
 
-               newpollfds = realloc(pollfds, sizeof(*newpollfds) * num);
+               newpollfds = io_alloc.realloc(pollfds, sizeof(*newpollfds)*num);
                if (!newpollfds)
                        return false;
                pollfds = newpollfds;
-               newfds = realloc(fds, sizeof(*newfds) * num);
+               newfds = io_alloc.realloc(fds, sizeof(*newfds) * num);
                if (!newfds)
                        return false;
                fds = newfds;
@@ -107,8 +107,8 @@ static void del_fd(struct fd *fd)
                fds[n]->backend_info = n;
        } else if (num_fds == 1) {
                /* Free everything when no more fds. */
-               free(pollfds);
-               free(fds);
+               io_alloc.free(pollfds);
+               io_alloc.free(fds);
                pollfds = NULL;
                fds = NULL;
                max_fds = 0;
@@ -181,7 +181,7 @@ void backend_del_conn(struct io_conn *conn)
        }
        if (timeout_active(conn))
                backend_del_timeout(conn);
-       free(conn->timeout);
+       io_alloc.free(conn->timeout);
        if (conn->duplex) {
                /* In case fds[] pointed to the other one. */
                fds[conn->fd.backend_info] = &conn->duplex->fd;
diff --git a/ccan/io/test/run-set_alloc.c b/ccan/io/test/run-set_alloc.c
new file mode 100644 (file)
index 0000000..fd0c83a
--- /dev/null
@@ -0,0 +1,240 @@
+#include <ccan/tap/tap.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <signal.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+
+/* Make sure we override these! */
+static void *no_malloc(size_t size)
+{
+       abort();
+}
+static void *no_realloc(void *p, size_t size)
+{
+       abort();
+}
+static void no_free(void *p)
+{
+       abort();
+}
+#define malloc no_malloc
+#define realloc no_realloc
+#define free no_free
+
+#include <ccan/io/poll.c>
+#include <ccan/io/io.c>
+
+#undef malloc
+#undef realloc
+#undef free
+
+static unsigned int alloc_count, realloc_count, free_count;
+static void *ptrs[100];
+
+static void **find_ptr(void *p)
+{
+       unsigned int i;
+
+       for (i = 0; i < 100; i++)
+               if (ptrs[i] == p)
+                       return ptrs + i;
+       return NULL;
+}
+
+static void *allocfn(size_t size)
+{
+       alloc_count++;
+       return *find_ptr(NULL) = malloc(size);
+}
+
+static void *reallocfn(void *ptr, size_t size)
+{
+       realloc_count++;
+       if (!ptr)
+               alloc_count++;
+
+       return *find_ptr(ptr) = realloc(ptr, size);
+}
+
+static void freefn(void *ptr)
+{
+       free_count++;
+       free(ptr);
+       *find_ptr(ptr) = NULL;
+}
+
+#ifndef PORT
+#define PORT "65015"
+#endif
+
+struct data {
+       int state;
+       int timeout_usec;
+       bool timed_out;
+       char buf[4];
+};
+
+
+static struct io_plan no_timeout(struct io_conn *conn, struct data *d)
+{
+       ok1(d->state == 1);
+       d->state++;
+       return io_close();
+}
+
+static struct io_plan timeout(struct io_conn *conn, struct data *d)
+{
+       ok1(d->state == 1);
+       d->state++;
+       d->timed_out = true;
+       return io_close();
+}
+
+static void finish_ok(struct io_conn *conn, struct data *d)
+{
+       ok1(d->state == 2);
+       d->state++;
+       io_break(d, io_idle());
+}
+
+static void init_conn(int fd, struct data *d)
+{
+       struct io_conn *conn;
+
+       ok1(d->state == 0);
+       d->state++;
+
+       conn = io_new_conn(fd, io_read(d->buf, sizeof(d->buf), no_timeout, d));
+       io_set_finish(conn, finish_ok, d);
+       io_timeout(conn, time_from_usec(d->timeout_usec), timeout, d);
+}
+
+static int make_listen_fd(const char *port, struct addrinfo **info)
+{
+       int fd, on = 1;
+       struct addrinfo *addrinfo, hints;
+
+       memset(&hints, 0, sizeof(hints));
+       hints.ai_family = AF_UNSPEC;
+       hints.ai_socktype = SOCK_STREAM;
+       hints.ai_flags = AI_PASSIVE;
+       hints.ai_protocol = 0;
+
+       if (getaddrinfo(NULL, port, &hints, &addrinfo) != 0)
+               return -1;
+
+       fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
+                   addrinfo->ai_protocol);
+       if (fd < 0)
+               return -1;
+
+       setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
+       if (bind(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0) {
+               close(fd);
+               return -1;
+       }
+       if (listen(fd, 1) != 0) {
+               close(fd);
+               return -1;
+       }
+       *info = addrinfo;
+       return fd;
+}
+
+int main(void)
+{
+       struct data *d = allocfn(sizeof(*d));
+       struct addrinfo *addrinfo;
+       struct io_listener *l;
+       int fd, status;
+
+       io_set_alloc(allocfn, reallocfn, freefn);
+
+       /* This is how many tests you plan to run */
+       plan_tests(25);
+       d->state = 0;
+       d->timed_out = false;
+       d->timeout_usec = 100000;
+       fd = make_listen_fd(PORT, &addrinfo);
+       ok1(fd >= 0);
+       l = io_new_listener(fd, init_conn, d);
+       ok1(l);
+       fflush(stdout);
+
+       if (!fork()) {
+               int i;
+
+               io_close_listener(l);
+               fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
+                           addrinfo->ai_protocol);
+               if (fd < 0)
+                       exit(1);
+               if (connect(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0)
+                       exit(2);
+               signal(SIGPIPE, SIG_IGN);
+               usleep(500000);
+               for (i = 0; i < strlen("hellothere"); i++) {
+                       if (write(fd, "hellothere" + i, 1) != 1)
+                               break;
+               }
+               close(fd);
+               freeaddrinfo(addrinfo);
+               free(d);
+               exit(i);
+       }
+       ok1(io_loop() == d);
+       ok1(d->state == 3);
+       ok1(d->timed_out == true);
+       ok1(wait(&status));
+       ok1(WIFEXITED(status));
+       ok1(WEXITSTATUS(status) < sizeof(d->buf));
+
+       /* This one shouldn't time out. */
+       d->state = 0;
+       d->timed_out = false;
+       d->timeout_usec = 500000;
+       fflush(stdout);
+
+       if (!fork()) {
+               int i;
+
+               io_close_listener(l);
+               fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
+                           addrinfo->ai_protocol);
+               if (fd < 0)
+                       exit(1);
+               if (connect(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0)
+                       exit(2);
+               signal(SIGPIPE, SIG_IGN);
+               usleep(100000);
+               for (i = 0; i < strlen("hellothere"); i++) {
+                       if (write(fd, "hellothere" + i, 1) != 1)
+                               break;
+               }
+               close(fd);
+               freeaddrinfo(addrinfo);
+               free(d);
+               exit(i);
+       }
+       ok1(io_loop() == d);
+       ok1(d->state == 3);
+       ok1(d->timed_out == false);
+       ok1(wait(&status));
+       ok1(WIFEXITED(status));
+       ok1(WEXITSTATUS(status) >= sizeof(d->buf));
+
+       io_close_listener(l);
+       freeaddrinfo(addrinfo);
+
+       /* We should have tested each one at least once! */
+       ok1(realloc_count);
+       ok1(alloc_count);
+       ok1(free_count);
+
+       ok1(free_count < alloc_count);
+       freefn(d);
+       ok1(free_count == alloc_count);
+
+       return exit_status();
+}