Browse Source

RPC: Clean up mcast socket with tidyup_socket

Luke Dashjr 11 years ago
parent
commit
4918e81b48
1 changed files with 21 additions and 9 deletions
  1. 21 9
      api.c

+ 21 - 9
api.c

@@ -3368,6 +3368,14 @@ void _tidyup_socket(SOCKETTYPE * const sockp)
 	}
 	}
 }
 }
 
 
+static
+void tidyup_socket(void * const arg)
+{
+	mutex_lock(&quit_restart_lock);
+	_tidyup_socket(arg);
+	mutex_unlock(&quit_restart_lock);
+}
+
 static void tidyup(__maybe_unused void *arg)
 static void tidyup(__maybe_unused void *arg)
 {
 {
 	mutex_lock(&quit_restart_lock);
 	mutex_lock(&quit_restart_lock);
@@ -3692,7 +3700,7 @@ static void mcast()
 	struct sockaddr_in came_from;
 	struct sockaddr_in came_from;
 	time_t bindstart;
 	time_t bindstart;
 	const char *binderror;
 	const char *binderror;
-	SOCKETTYPE mcast_sock;
+	SOCKETTYPE *mcastsock;
 	SOCKETTYPE reply_sock;
 	SOCKETTYPE reply_sock;
 	socklen_t came_from_siz;
 	socklen_t came_from_siz;
 	char *connectaddr;
 	char *connectaddr;
@@ -3715,10 +3723,14 @@ static void mcast()
 		quit(1, "Invalid Multicast Address");
 		quit(1, "Invalid Multicast Address");
 	grp.imr_interface.s_addr = INADDR_ANY;
 	grp.imr_interface.s_addr = INADDR_ANY;
 
 
-	mcast_sock = socket(AF_INET, SOCK_DGRAM, 0);
+	mcastsock = malloc(sizeof(*mcastsock));
+	*mcastsock = INVSOCK;
+	pthread_cleanup_push(tidyup_socket, mcastsock);
+	
+	*mcastsock = socket(AF_INET, SOCK_DGRAM, 0);
 
 
 	int optval = 1;
 	int optval = 1;
-	if (SOCKETFAIL(setsockopt(mcast_sock, SOL_SOCKET, SO_REUSEADDR, (void *)(&optval), sizeof(optval)))) {
+	if (SOCKETFAIL(setsockopt(*mcastsock, SOL_SOCKET, SO_REUSEADDR, (void *)(&optval), sizeof(optval)))) {
 		applog(LOG_ERR, "API mcast setsockopt SO_REUSEADDR failed (%s)%s", SOCKERRMSG, MUNAVAILABLE);
 		applog(LOG_ERR, "API mcast setsockopt SO_REUSEADDR failed (%s)%s", SOCKERRMSG, MUNAVAILABLE);
 		goto die;
 		goto die;
 	}
 	}
@@ -3732,7 +3744,7 @@ static void mcast()
 	bound = 0;
 	bound = 0;
 	bindstart = time(NULL);
 	bindstart = time(NULL);
 	while (bound == 0) {
 	while (bound == 0) {
-		if (SOCKETFAIL(bind(mcast_sock, (struct sockaddr *)(&listen), sizeof(listen)))) {
+		if (SOCKETFAIL(bind(*mcastsock, (struct sockaddr *)(&listen), sizeof(listen)))) {
 			binderror = SOCKERRMSG;
 			binderror = SOCKERRMSG;
 			if ((time(NULL) - bindstart) > 61)
 			if ((time(NULL) - bindstart) > 61)
 				break;
 				break;
@@ -3747,7 +3759,7 @@ static void mcast()
 		goto die;
 		goto die;
 	}
 	}
 
 
-	if (SOCKETFAIL(setsockopt(mcast_sock, IPPROTO_IP, IP_ADD_MEMBERSHIP, (void *)(&grp), sizeof(grp)))) {
+	if (SOCKETFAIL(setsockopt(*mcastsock, IPPROTO_IP, IP_ADD_MEMBERSHIP, (void *)(&grp), sizeof(grp)))) {
 		applog(LOG_ERR, "API mcast join failed (%s)%s", SOCKERRMSG, MUNAVAILABLE);
 		applog(LOG_ERR, "API mcast join failed (%s)%s", SOCKERRMSG, MUNAVAILABLE);
 		goto die;
 		goto die;
 	}
 	}
@@ -3764,10 +3776,10 @@ static void mcast()
 
 
 		count++;
 		count++;
 		came_from_siz = sizeof(came_from);
 		came_from_siz = sizeof(came_from);
-		if (SOCKETFAIL(rep = recvfrom(mcast_sock, buf, sizeof(buf),
+		if (SOCKETFAIL(rep = recvfrom(*mcastsock, buf, sizeof(buf),
 						0, (struct sockaddr *)(&came_from), &came_from_siz))) {
 						0, (struct sockaddr *)(&came_from), &came_from_siz))) {
 			applog(LOG_DEBUG, "API mcast failed count=%d (%s) (%d)",
 			applog(LOG_DEBUG, "API mcast failed count=%d (%s) (%d)",
-					count, SOCKERRMSG, (int)mcast_sock);
+					count, SOCKERRMSG, (int)*mcastsock);
 			continue;
 			continue;
 		}
 		}
 
 
@@ -3821,8 +3833,8 @@ static void mcast()
 	}
 	}
 
 
 die:
 die:
-
-	CLOSESOCKET(mcast_sock);
+	;  // statement in case pthread_cleanup_pop doesn't start with one
+	pthread_cleanup_pop(true);
 }
 }
 
 
 static void *mcast_thread(void *userdata)
 static void *mcast_thread(void *userdata)