Browse Source

io: io_close_taken_fd to steal fd from conn.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
Rusty Russell 9 years ago
parent
commit
17a81baf84
4 changed files with 127 additions and 4 deletions
  1. 0 1
      ccan/io/backend.h
  2. 18 0
      ccan/io/io.h
  3. 16 3
      ccan/io/poll.c
  4. 93 0
      ccan/io/test/run-21-io_close_taken_fd.c

+ 0 - 1
ccan/io/backend.h

@@ -79,7 +79,6 @@ void remove_from_always(struct io_conn *conn);
 void backend_plan_done(struct io_conn *conn);
 void backend_plan_done(struct io_conn *conn);
 
 
 void backend_wake(const void *wait);
 void backend_wake(const void *wait);
-void backend_del_conn(struct io_conn *conn);
 
 
 void io_ready(struct io_conn *conn, int pollflags);
 void io_ready(struct io_conn *conn, int pollflags);
 void io_do_always(struct io_conn *conn);
 void io_do_always(struct io_conn *conn);

+ 18 - 0
ccan/io/io.h

@@ -627,6 +627,24 @@ struct io_plan *io_close(struct io_conn *conn);
  */
  */
 struct io_plan *io_close_cb(struct io_conn *, void *unused);
 struct io_plan *io_close_cb(struct io_conn *, void *unused);
 
 
+/**
+ * io_close_taken_fd - close a connection, but remove the filedescriptor first.
+ * @conn: the connection to take the file descriptor from and close,
+ *
+ * io_close closes the file descriptor underlying the io_conn; this version does
+ * not.  Presumably you have used io_conn_fd() on it beforehand and will take
+ * care of the fd yourself.
+ *
+ * Example:
+ * static struct io_plan *steal_fd(struct io_conn *conn, int *fd)
+ * {
+ *	*fd = io_conn_fd(conn);
+ *	printf("stealing fd %i and closing\n", *fd);
+ *	return io_close_taken_fd(conn);
+ * }
+ */
+struct io_plan *io_close_taken_fd(struct io_conn *conn);
+
 /**
 /**
  * io_loop - process fds until all closed on io_break.
  * io_loop - process fds until all closed on io_break.
  * @timers - timers which are waiting to go off (or NULL for none)
  * @timers - timers which are waiting to go off (or NULL for none)

+ 16 - 3
ccan/io/poll.c

@@ -157,11 +157,12 @@ void backend_wake(const void *wait)
 	}
 	}
 }
 }
 
 
-static void destroy_conn(struct io_conn *conn)
+static void destroy_conn(struct io_conn *conn, bool close_fd)
 {
 {
 	int saved_errno = errno;
 	int saved_errno = errno;
 
 
-	close(conn->fd.fd);
+	if (close_fd)
+		close(conn->fd.fd);
 	del_fd(&conn->fd);
 	del_fd(&conn->fd);
 	/* In case it's on always list, remove it. */
 	/* In case it's on always list, remove it. */
 	list_del_init(&conn->always);
 	list_del_init(&conn->always);
@@ -173,14 +174,26 @@ static void destroy_conn(struct io_conn *conn)
 	}
 	}
 }
 }
 
 
+static void destroy_conn_close_fd(struct io_conn *conn)
+{
+	destroy_conn(conn, true);
+}
+
 bool add_conn(struct io_conn *c)
 bool add_conn(struct io_conn *c)
 {
 {
 	if (!add_fd(&c->fd, 0))
 	if (!add_fd(&c->fd, 0))
 		return false;
 		return false;
-	tal_add_destructor(c, destroy_conn);
+	tal_add_destructor(c, destroy_conn_close_fd);
 	return true;
 	return true;
 }
 }
 
 
+struct io_plan *io_close_taken_fd(struct io_conn *conn)
+{
+	tal_del_destructor(conn, destroy_conn_close_fd);
+	destroy_conn(conn, false);
+	return io_close(conn);
+}
+
 static void accept_conn(struct io_listener *l)
 static void accept_conn(struct io_listener *l)
 {
 {
 	int fd = accept(l->fd.fd, NULL, NULL);
 	int fd = accept(l->fd.fd, NULL, NULL);

+ 93 - 0
ccan/io/test/run-21-io_close_taken_fd.c

@@ -0,0 +1,93 @@
+#include <ccan/io/io.h>
+/* Include the C files directly. */
+#include <ccan/io/poll.c>
+#include <ccan/io/io.c>
+#include <ccan/tap/tap.h>
+#include <sys/wait.h>
+#include <stdio.h>
+
+#define PORT "65021"
+
+static struct io_listener *l;
+
+static struct io_plan *steal_fd(struct io_conn *conn, int *fd)
+{
+	io_close_listener(l);
+	*fd = io_conn_fd(conn);
+	return io_close_taken_fd(conn);
+}
+
+static int make_listen_fd(const char *port, struct addrinfo **info)
+{
+	int fd, on = 1;
+	struct addrinfo *addrinfo, hints;
+
+	memset(&hints, 0, sizeof(hints));
+	hints.ai_family = AF_UNSPEC;
+	hints.ai_socktype = SOCK_STREAM;
+	hints.ai_flags = AI_PASSIVE;
+	hints.ai_protocol = 0;
+
+	if (getaddrinfo(NULL, port, &hints, &addrinfo) != 0)
+		return -1;
+
+	fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
+		    addrinfo->ai_protocol);
+	if (fd < 0)
+		return -1;
+
+	setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
+	if (bind(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0) {
+		close(fd);
+		return -1;
+	}
+	if (listen(fd, 1) != 0) {
+		close(fd);
+		return -1;
+	}
+	*info = addrinfo;
+	return fd;
+}
+
+int main(void)
+{
+	struct addrinfo *addrinfo = NULL;
+	int i, fd, in_fd, status;
+	char buf[strlen("hellothere")];
+
+	/* This is how many tests you plan to run */
+	plan_tests(15);
+	fd = make_listen_fd(PORT, &addrinfo);
+	l = io_new_listener(NULL, fd, steal_fd, &in_fd);
+	fflush(stdout);
+	if (!fork()) {
+		io_close_listener(l);
+		fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
+			    addrinfo->ai_protocol);
+		if (fd < 0)
+			exit(1);
+		if (connect(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) != 0)
+			exit(2);
+		signal(SIGPIPE, SIG_IGN);
+		for (i = 0; i < strlen("hellothere"); i++) {
+			if (write(fd, "hellothere" + i, 1) != 1)
+				break;
+		}
+		close(fd);
+		freeaddrinfo(addrinfo);
+		exit(0);
+	}
+	freeaddrinfo(addrinfo);
+	ok1(io_loop(NULL, NULL) == NULL);
+
+	for (i = 0; i < strlen("hellothere"); i++)
+		ok1(read(in_fd, buf + i, 1) == 1);
+
+	ok1(memcmp(buf, "hellothere", sizeof(buf)) == 0);
+	ok1(wait(&status));
+	ok1(WIFEXITED(status));
+	ok1(WEXITSTATUS(status) == 0);
+
+	/* This exits depending on whether all tests passed */
+	return exit_status();
+}