Browse Source

io: io_set_alloc()

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
Rusty Russell 12 years ago
parent
commit
e40f5c50a7
5 changed files with 289 additions and 16 deletions
  1. 7 0
      ccan/io/backend.h
  2. 21 8
      ccan/io/io.c
  3. 13 0
      ccan/io/io.h
  4. 8 8
      ccan/io/poll.c
  5. 240 0
      ccan/io/test/run-set_alloc.c

+ 7 - 0
ccan/io/backend.h

@@ -4,6 +4,13 @@
 #include <stdbool.h>
 #include <stdbool.h>
 #include <ccan/timer/timer.h>
 #include <ccan/timer/timer.h>
 
 
+struct io_alloc {
+	void *(*alloc)(size_t size);
+	void *(*realloc)(void *ptr, size_t size);
+	void (*free)(void *ptr);
+};
+extern struct io_alloc io_alloc;
+
 struct fd {
 struct fd {
 	int fd;
 	int fd;
 	bool listener;
 	bool listener;

+ 21 - 8
ccan/io/io.c

@@ -14,6 +14,10 @@
 
 
 void *io_loop_return;
 void *io_loop_return;
 
 
+struct io_alloc io_alloc = {
+	malloc, realloc, free
+};
+
 #ifdef DEBUG
 #ifdef DEBUG
 /* Set to skip the next plan. */
 /* Set to skip the next plan. */
 bool io_plan_nodebug;
 bool io_plan_nodebug;
@@ -125,7 +129,7 @@ struct io_listener *io_new_listener_(int fd,
 				     void (*init)(int fd, void *arg),
 				     void (*init)(int fd, void *arg),
 				     void *arg)
 				     void *arg)
 {
 {
-	struct io_listener *l = malloc(sizeof(*l));
+	struct io_listener *l = io_alloc.alloc(sizeof(*l));
 
 
 	if (!l)
 	if (!l)
 		return NULL;
 		return NULL;
@@ -135,7 +139,7 @@ struct io_listener *io_new_listener_(int fd,
 	l->init = init;
 	l->init = init;
 	l->arg = arg;
 	l->arg = arg;
 	if (!add_listener(l)) {
 	if (!add_listener(l)) {
-		free(l);
+		io_alloc.free(l);
 		return NULL;
 		return NULL;
 	}
 	}
 	return l;
 	return l;
@@ -145,12 +149,12 @@ void io_close_listener(struct io_listener *l)
 {
 {
 	close(l->fd.fd);
 	close(l->fd.fd);
 	del_listener(l);
 	del_listener(l);
-	free(l);
+	io_alloc.free(l);
 }
 }
 
 
 struct io_conn *io_new_conn_(int fd, struct io_plan plan)
 struct io_conn *io_new_conn_(int fd, struct io_plan plan)
 {
 {
-	struct io_conn *conn = malloc(sizeof(*conn));
+	struct io_conn *conn = io_alloc.alloc(sizeof(*conn));
 
 
 	io_plan_debug_again();
 	io_plan_debug_again();
 
 
@@ -165,7 +169,7 @@ struct io_conn *io_new_conn_(int fd, struct io_plan plan)
 	conn->duplex = NULL;
 	conn->duplex = NULL;
 	conn->timeout = NULL;
 	conn->timeout = NULL;
 	if (!add_conn(conn)) {
 	if (!add_conn(conn)) {
-		free(conn);
+		io_alloc.free(conn);
 		return NULL;
 		return NULL;
 	}
 	}
 	return conn;
 	return conn;
@@ -187,7 +191,7 @@ struct io_conn *io_duplex_(struct io_conn *old, struct io_plan plan)
 
 
 	assert(!old->duplex);
 	assert(!old->duplex);
 
 
-	conn = malloc(sizeof(*conn));
+	conn = io_alloc.alloc(sizeof(*conn));
 	if (!conn)
 	if (!conn)
 		return NULL;
 		return NULL;
 
 
@@ -199,7 +203,7 @@ struct io_conn *io_duplex_(struct io_conn *old, struct io_plan plan)
 	conn->finish_arg = NULL;
 	conn->finish_arg = NULL;
 	conn->timeout = NULL;
 	conn->timeout = NULL;
 	if (!add_duplex(conn)) {
 	if (!add_duplex(conn)) {
-		free(conn);
+		io_alloc.free(conn);
 		return NULL;
 		return NULL;
 	}
 	}
 	old->duplex = conn;
 	old->duplex = conn;
@@ -212,7 +216,7 @@ bool io_timeout_(struct io_conn *conn, struct timespec ts,
 	assert(cb);
 	assert(cb);
 
 
 	if (!conn->timeout) {
 	if (!conn->timeout) {
-		conn->timeout = malloc(sizeof(*conn->timeout));
+		conn->timeout = io_alloc.alloc(sizeof(*conn->timeout));
 		if (!conn->timeout)
 		if (!conn->timeout)
 			return false;
 			return false;
 	} else
 	} else
@@ -467,3 +471,12 @@ struct io_plan io_break_(void *ret, struct io_plan plan)
 
 
 	return plan;
 	return plan;
 }
 }
+
+void io_set_alloc(void *(*allocfn)(size_t size),
+		  void *(*reallocfn)(void *ptr, size_t size),
+		  void (*freefn)(void *ptr))
+{
+	io_alloc.alloc = allocfn;
+	io_alloc.realloc = reallocfn;
+	io_alloc.free = freefn;
+}

+ 13 - 0
ccan/io/io.h

@@ -490,4 +490,17 @@ struct io_plan io_close_cb(struct io_conn *, void *unused);
  *	io_loop();
  *	io_loop();
  */
  */
 void *io_loop(void);
 void *io_loop(void);
+
+/**
+ * io_set_alloc - set alloc/realloc/free function for io to use.
+ * @allocfn: allocator function
+ * @reallocfn: reallocator function, ptr may be NULL, size never 0.
+ * @freefn: free function
+ *
+ * By default io uses malloc/realloc/free, and returns NULL if they fail.
+ * You can set your own variants here.
+ */
+void io_set_alloc(void *(*allocfn)(size_t size),
+		  void *(*reallocfn)(void *ptr, size_t size),
+		  void (*freefn)(void *ptr));
 #endif /* CCAN_IO_H */
 #endif /* CCAN_IO_H */

+ 8 - 8
ccan/io/poll.c

@@ -28,7 +28,7 @@ static void io_loop_exit(void)
 		while (free_later) {
 		while (free_later) {
 			struct io_conn *c = free_later;
 			struct io_conn *c = free_later;
 			free_later = c->finish_arg;
 			free_later = c->finish_arg;
-			free(c);
+			io_alloc.free(c);
 		}
 		}
 	}
 	}
 }
 }
@@ -42,7 +42,7 @@ static void free_conn(struct io_conn *conn)
 		conn->finish_arg = free_later;
 		conn->finish_arg = free_later;
 		free_later = conn;
 		free_later = conn;
 	} else
 	} else
-		free(conn);
+		io_alloc.free(conn);
 }
 }
 #else
 #else
 static void io_loop_enter(void)
 static void io_loop_enter(void)
@@ -53,7 +53,7 @@ static void io_loop_exit(void)
 }
 }
 static void free_conn(struct io_conn *conn)
 static void free_conn(struct io_conn *conn)
 {
 {
-	free(conn);
+	io_alloc.free(conn);
 }
 }
 #endif
 #endif
 
 
@@ -64,11 +64,11 @@ static bool add_fd(struct fd *fd, short events)
 		struct fd **newfds;
 		struct fd **newfds;
 		size_t num = max_fds ? max_fds * 2 : 8;
 		size_t num = max_fds ? max_fds * 2 : 8;
 
 
-		newpollfds = realloc(pollfds, sizeof(*newpollfds) * num);
+		newpollfds = io_alloc.realloc(pollfds, sizeof(*newpollfds)*num);
 		if (!newpollfds)
 		if (!newpollfds)
 			return false;
 			return false;
 		pollfds = newpollfds;
 		pollfds = newpollfds;
-		newfds = realloc(fds, sizeof(*newfds) * num);
+		newfds = io_alloc.realloc(fds, sizeof(*newfds) * num);
 		if (!newfds)
 		if (!newfds)
 			return false;
 			return false;
 		fds = newfds;
 		fds = newfds;
@@ -107,8 +107,8 @@ static void del_fd(struct fd *fd)
 		fds[n]->backend_info = n;
 		fds[n]->backend_info = n;
 	} else if (num_fds == 1) {
 	} else if (num_fds == 1) {
 		/* Free everything when no more fds. */
 		/* Free everything when no more fds. */
-		free(pollfds);
-		free(fds);
+		io_alloc.free(pollfds);
+		io_alloc.free(fds);
 		pollfds = NULL;
 		pollfds = NULL;
 		fds = NULL;
 		fds = NULL;
 		max_fds = 0;
 		max_fds = 0;
@@ -181,7 +181,7 @@ void backend_del_conn(struct io_conn *conn)
 	}
 	}
 	if (timeout_active(conn))
 	if (timeout_active(conn))
 		backend_del_timeout(conn);
 		backend_del_timeout(conn);
-	free(conn->timeout);
+	io_alloc.free(conn->timeout);
 	if (conn->duplex) {
 	if (conn->duplex) {
 		/* In case fds[] pointed to the other one. */
 		/* In case fds[] pointed to the other one. */
 		fds[conn->fd.backend_info] = &conn->duplex->fd;
 		fds[conn->fd.backend_info] = &conn->duplex->fd;

+ 240 - 0
ccan/io/test/run-set_alloc.c

@@ -0,0 +1,240 @@
+#include <ccan/tap/tap.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <signal.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+
+/* Make sure we override these! */
+static void *no_malloc(size_t size)
+{
+	abort();
+}
+static void *no_realloc(void *p, size_t size)
+{
+	abort();
+}
+static void no_free(void *p)
+{
+	abort();
+}
+#define malloc no_malloc
+#define realloc no_realloc
+#define free no_free
+
+#include <ccan/io/poll.c>
+#include <ccan/io/io.c>
+
+#undef malloc
+#undef realloc
+#undef free
+
+static unsigned int alloc_count, realloc_count, free_count;
+static void *ptrs[100];
+
+static void **find_ptr(void *p)
+{
+	unsigned int i;
+
+	for (i = 0; i < 100; i++)
+		if (ptrs[i] == p)
+			return ptrs + i;
+	return NULL;
+}
+
+static void *allocfn(size_t size)
+{
+	alloc_count++;
+	return *find_ptr(NULL) = malloc(size);
+}
+
+static void *reallocfn(void *ptr, size_t size)
+{
+	realloc_count++;
+	if (!ptr)
+		alloc_count++;
+
+	return *find_ptr(ptr) = realloc(ptr, size);
+}
+
+static void freefn(void *ptr)
+{
+	free_count++;
+	free(ptr);
+	*find_ptr(ptr) = NULL;
+}
+
+#ifndef PORT
+#define PORT "65015"
+#endif
+
+struct data {
+	int state;
+	int timeout_usec;
+	bool timed_out;
+	char buf[4];
+};
+
+
+static struct io_plan no_timeout(struct io_conn *conn, struct data *d)
+{
+	ok1(d->state == 1);
+	d->state++;
+	return io_close();
+}
+
+static struct io_plan timeout(struct io_conn *conn, struct data *d)
+{
+	ok1(d->state == 1);
+	d->state++;
+	d->timed_out = true;
+	return io_close();
+}
+
+static void finish_ok(struct io_conn *conn, struct data *d)
+{
+	ok1(d->state == 2);
+	d->state++;
+	io_break(d, io_idle());
+}
+
+static void init_conn(int fd, struct data *d)
+{
+	struct io_conn *conn;
+
+	ok1(d->state == 0);
+	d->state++;
+
+	conn = io_new_conn(fd, io_read(d->buf, sizeof(d->buf), no_timeout, d));
+	io_set_finish(conn, finish_ok, d);
+	io_timeout(conn, time_from_usec(d->timeout_usec), timeout, d);
+}
+
+static int make_listen_fd(const char *port, struct addrinfo **info)
+{
+	int fd, on = 1;
+	struct addrinfo *addrinfo, hints;
+
+	memset(&hints, 0, sizeof(hints));
+	hints.ai_family = AF_UNSPEC;
+	hints.ai_socktype = SOCK_STREAM;
+	hints.ai_flags = AI_PASSIVE;
+	hints.ai_protocol = 0;
+
+	if (getaddrinfo(NULL, port, &hints, &addrinfo) != 0)
+		return -1;
+
+	fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
+		    addrinfo->ai_protocol);
+	if (fd < 0)
+		return -1;
+
+	setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
+	if (bind(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0) {
+		close(fd);
+		return -1;
+	}
+	if (listen(fd, 1) != 0) {
+		close(fd);
+		return -1;
+	}
+	*info = addrinfo;
+	return fd;
+}
+
+int main(void)
+{
+	struct data *d = allocfn(sizeof(*d));
+	struct addrinfo *addrinfo;
+	struct io_listener *l;
+	int fd, status;
+
+	io_set_alloc(allocfn, reallocfn, freefn);
+
+	/* This is how many tests you plan to run */
+	plan_tests(25);
+	d->state = 0;
+	d->timed_out = false;
+	d->timeout_usec = 100000;
+	fd = make_listen_fd(PORT, &addrinfo);
+	ok1(fd >= 0);
+	l = io_new_listener(fd, init_conn, d);
+	ok1(l);
+	fflush(stdout);
+
+	if (!fork()) {
+		int i;
+
+		io_close_listener(l);
+		fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
+			    addrinfo->ai_protocol);
+		if (fd < 0)
+			exit(1);
+		if (connect(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0)
+			exit(2);
+		signal(SIGPIPE, SIG_IGN);
+		usleep(500000);
+		for (i = 0; i < strlen("hellothere"); i++) {
+			if (write(fd, "hellothere" + i, 1) != 1)
+				break;
+		}
+		close(fd);
+		freeaddrinfo(addrinfo);
+		free(d);
+		exit(i);
+	}
+	ok1(io_loop() == d);
+	ok1(d->state == 3);
+	ok1(d->timed_out == true);
+	ok1(wait(&status));
+	ok1(WIFEXITED(status));
+	ok1(WEXITSTATUS(status) < sizeof(d->buf));
+
+	/* This one shouldn't time out. */
+	d->state = 0;
+	d->timed_out = false;
+	d->timeout_usec = 500000;
+	fflush(stdout);
+
+	if (!fork()) {
+		int i;
+
+		io_close_listener(l);
+		fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
+			    addrinfo->ai_protocol);
+		if (fd < 0)
+			exit(1);
+		if (connect(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0)
+			exit(2);
+		signal(SIGPIPE, SIG_IGN);
+		usleep(100000);
+		for (i = 0; i < strlen("hellothere"); i++) {
+			if (write(fd, "hellothere" + i, 1) != 1)
+				break;
+		}
+		close(fd);
+		freeaddrinfo(addrinfo);
+		free(d);
+		exit(i);
+	}
+	ok1(io_loop() == d);
+	ok1(d->state == 3);
+	ok1(d->timed_out == false);
+	ok1(wait(&status));
+	ok1(WIFEXITED(status));
+	ok1(WEXITSTATUS(status) >= sizeof(d->buf));
+
+	io_close_listener(l);
+	freeaddrinfo(addrinfo);
+
+	/* We should have tested each one at least once! */
+	ok1(realloc_count);
+	ok1(alloc_count);
+	ok1(free_count);
+
+	ok1(free_count < alloc_count);
+	freefn(d);
+	ok1(free_count == alloc_count);
+
+	return exit_status();
+}