Browse Source

typesafe_cb, hashtable: revise typesafe_cb arg order, neaten.
hashtable: make traverse callback typesafe.

Rusty Russell 16 years ago
parent
commit
b1c867121f

+ 4 - 4
ccan/asearch/asearch.h

@@ -23,10 +23,10 @@
 #if HAVE_TYPEOF
 #if HAVE_TYPEOF
 #define asearch(key, base, num, cmp)					\
 #define asearch(key, base, num, cmp)					\
 	((__typeof__(*(base))*)(bsearch((key), (base), (num), sizeof(*(base)), \
 	((__typeof__(*(base))*)(bsearch((key), (base), (num), sizeof(*(base)), \
-	     cast_if_type((cmp),					\
-			  int (*)(const __typeof__(*(key)) *,		\
-				  const __typeof__(*(base)) *),		\
-			  int (*)(const void *, const void *)))))
+		cast_if_type(int (*)(const void *, const void *),	\
+			     (cmp),					\
+			     int (*)(const __typeof__(*(key)) *,	\
+				     const __typeof__(*(base)) *)))))
 #else
 #else
 #define asearch(key, base, num, cmp)				\
 #define asearch(key, base, num, cmp)				\
 	(bsearch((key), (base), (num), sizeof(*(base)),		\
 	(bsearch((key), (base), (num), sizeof(*(base)),		\

+ 3 - 9
ccan/asort/asort.h

@@ -16,19 +16,13 @@
  * The @cmp function should exactly match the type of the @base and
  * The @cmp function should exactly match the type of the @base and
  * @ctx arguments.  Otherwise it can take three const void *.
  * @ctx arguments.  Otherwise it can take three const void *.
  */
  */
-#if HAVE_TYPEOF
 #define asort(base, num, cmp, ctx)					\
 #define asort(base, num, cmp, ctx)					\
 _asort((base), (num), sizeof(*(base)),					\
 _asort((base), (num), sizeof(*(base)),					\
-       cast_if_type((cmp),						\
+       cast_if_type(int (*)(const void *, const void *, const void *),	\
+		    (cmp),						\
 		    int (*)(const __typeof__(*(base)) *,		\
 		    int (*)(const __typeof__(*(base)) *,		\
 			    const __typeof__(*(base)) *,		\
 			    const __typeof__(*(base)) *,		\
-			    __typeof__(ctx)),				\
-		    int (*)(const void *, const void *, const void *)), (ctx))
-#else
-#define asort(base, num, cmp, ctx)				\
-	_asort((base), (num), sizeof(*(base)),			\
-	       (int (*)(const void *, const void *, const void *))(cmp), ctx)
-#endif
+			    __typeof__(ctx))), (ctx))
 
 
 void _asort(void *base, size_t nmemb, size_t size,
 void _asort(void *base, size_t nmemb, size_t size,
 	    int(*compar)(const void *, const void *, const void *),
 	    int(*compar)(const void *, const void *, const void *),

+ 4 - 4
ccan/hashtable/hashtable.c

@@ -186,10 +186,10 @@ bool hashtable_del(struct hashtable *ht, unsigned long hash, const void *p)
 	delete_run(ht, i);
 	delete_run(ht, i);
 	return true;
 	return true;
 }
 }
-	
-	
-void hashtable_traverse(struct hashtable *ht, bool (*cb)(void *p, void *cbarg),
-			void *cbarg)
+
+void _hashtable_traverse(struct hashtable *ht,
+			 bool (*cb)(void *p, void *cbarg),
+			 void *cbarg)
 {
 {
 	size_t i;
 	size_t i;
 
 

+ 18 - 2
ccan/hashtable/hashtable.h

@@ -2,6 +2,7 @@
 #define CCAN_HASHTABLE_H
 #define CCAN_HASHTABLE_H
 #include "config.h"
 #include "config.h"
 #include <stdbool.h>
 #include <stdbool.h>
+#include <ccan/typesafe_cb/typesafe_cb.h>
 
 
 struct hashtable;
 struct hashtable;
 
 
@@ -63,12 +64,27 @@ bool hashtable_del(struct hashtable *ht, unsigned long hash, const void *p);
 /**
 /**
  * hashtable_traverse - call a function on every pointer in hash tree
  * hashtable_traverse - call a function on every pointer in hash tree
  * @ht: the hashtable
  * @ht: the hashtable
+ * @type: the type of the element in the hashtable.
  * @cb: the callback: returns true to abort traversal.
  * @cb: the callback: returns true to abort traversal.
  * @cbarg: the argument to the callback
  * @cbarg: the argument to the callback
  *
  *
  * Note that your traversal callback may delete any entry (it won't crash),
  * Note that your traversal callback may delete any entry (it won't crash),
  * but it may make the traverse unreliable.
  * but it may make the traverse unreliable.
  */
  */
-void hashtable_traverse(struct hashtable *ht, bool (*cb)(void *p, void *cbarg),
-			void *cbarg);
+#define hashtable_traverse(ht, type, cb, cbarg)				\
+	_hashtable_traverse(ht, cast_if_type(bool (*)(void *, void *),	\
+			     cast_if_any(bool (*)(void *,		\
+						  void *), (cb),	\
+					 bool (*)(const type *,		\
+						  const typeof(*cbarg) *), \
+					 bool (*)(type *,		\
+						  const typeof(*cbarg) *), \
+					 bool (*)(const type *,		\
+						  typeof(*cbarg) *)),	\
+					     bool (*)(type *,		\
+						      typeof(*cbarg) *)), \
+			    cbarg)
+
+void _hashtable_traverse(struct hashtable *ht,
+			 bool (*cb)(void *p, void *cbarg), void *cbarg);
 #endif /* CCAN_HASHTABLE_H */
 #endif /* CCAN_HASHTABLE_H */

+ 15 - 24
ccan/hashtable/test/run.c

@@ -87,49 +87,40 @@ struct travarg {
 	uint64_t *val;
 	uint64_t *val;
 };
 };
 
 
-static bool count(void *p, void *cbarg)
+static bool count(void *p, struct travarg *travarg)
 {
 {
-	struct travarg *travarg = cbarg;
 	travarg->count++;
 	travarg->count++;
 	travarg->touched[*(uint64_t *)p]++;
 	travarg->touched[*(uint64_t *)p]++;
 	return false;
 	return false;
 }
 }
 
 
-static bool delete_self(void *p, void *cbarg)
+static bool delete_self(uint64_t *p, struct travarg *travarg)
 {
 {
-	struct travarg *travarg = cbarg;
-	uint64_t val = *(uint64_t *)p;
-
 	travarg->count++;
 	travarg->count++;
-	travarg->touched[val]++;
+	travarg->touched[*p]++;
 	return !hashtable_del(travarg->ht, hash(p, NULL), p);
 	return !hashtable_del(travarg->ht, hash(p, NULL), p);
 }
 }
 
 
-static bool delete_next(void *p, void *cbarg)
+static bool delete_next(uint64_t *p, struct travarg *travarg)
 {
 {
-	struct travarg *travarg = cbarg;
-	uint64_t val = *(uint64_t *)p;
-	uint64_t *next = &travarg->val[(val + 1) % NUM_VALS];
+	uint64_t *next = &travarg->val[((*p) + 1) % NUM_VALS];
 
 
 	travarg->count++;
 	travarg->count++;
-	travarg->touched[val]++;
+	travarg->touched[*p]++;
 	return !hashtable_del(travarg->ht, hash(next, NULL), next);
 	return !hashtable_del(travarg->ht, hash(next, NULL), next);
 }
 }
 
 
-static bool delete_prev(void *p, void *cbarg)
+static bool delete_prev(uint64_t *p, struct travarg *travarg)
 {
 {
-	struct travarg *travarg = cbarg;
-	uint64_t val = *(uint64_t *)p;
-	uint64_t *prev = &travarg->val[(val - 1) % NUM_VALS];
+	uint64_t *prev = &travarg->val[((*p) - 1) % NUM_VALS];
 
 
 	travarg->count++;
 	travarg->count++;
-	travarg->touched[val]++;
+	travarg->touched[*p]++;
 	return !hashtable_del(travarg->ht, hash(prev, NULL), prev);
 	return !hashtable_del(travarg->ht, hash(prev, NULL), prev);
 }
 }
 
 
-static bool stop_halfway(void *p, void *cbarg)
+static bool stop_halfway(void *p, struct travarg *travarg)
 {
 {
-	struct travarg *travarg = cbarg;
 	travarg->count++;
 	travarg->count++;
 	travarg->touched[*(uint64_t *)p]++;
 	travarg->touched[*(uint64_t *)p]++;
 
 
@@ -207,13 +198,13 @@ int main(int argc, char *argv[])
 	travarg.count = 0;
 	travarg.count = 0;
 
 
 	/* Traverse. */
 	/* Traverse. */
-	hashtable_traverse(ht, count, &travarg);
+	hashtable_traverse(ht, void, count, &travarg);
 	ok1(travarg.count == NUM_VALS);
 	ok1(travarg.count == NUM_VALS);
 	check_all_touched_once(&travarg);
 	check_all_touched_once(&travarg);
 
 
 	memset(travarg.touched, 0, sizeof(travarg.touched));
 	memset(travarg.touched, 0, sizeof(travarg.touched));
 	travarg.count = 0;
 	travarg.count = 0;
-	hashtable_traverse(ht, stop_halfway, &travarg);
+	hashtable_traverse(ht, void, stop_halfway, &travarg);
 	ok1(travarg.count == NUM_VALS / 2);
 	ok1(travarg.count == NUM_VALS / 2);
 	check_only_touched_once(&travarg);
 	check_only_touched_once(&travarg);
 
 
@@ -222,7 +213,7 @@ int main(int argc, char *argv[])
 	i = 0;
 	i = 0;
 	/* Delete until we make no more progress. */
 	/* Delete until we make no more progress. */
 	for (;;) {
 	for (;;) {
-		hashtable_traverse(ht, delete_self, &travarg);
+		hashtable_traverse(ht, uint64_t, delete_self, &travarg);
 		if (travarg.count == i || travarg.count > NUM_VALS)
 		if (travarg.count == i || travarg.count > NUM_VALS)
 			break;
 			break;
 		i = travarg.count;
 		i = travarg.count;
@@ -233,14 +224,14 @@ int main(int argc, char *argv[])
 	memset(travarg.touched, 0, sizeof(travarg.touched));
 	memset(travarg.touched, 0, sizeof(travarg.touched));
 	travarg.count = 0;
 	travarg.count = 0;
 	refill_vals(ht, val, NUM_VALS);
 	refill_vals(ht, val, NUM_VALS);
-	hashtable_traverse(ht, delete_next, &travarg);
+	hashtable_traverse(ht, uint64_t, delete_next, &travarg);
 	ok1(travarg.count <= NUM_VALS);
 	ok1(travarg.count <= NUM_VALS);
 	check_only_touched_once(&travarg);
 	check_only_touched_once(&travarg);
 
 
 	memset(travarg.touched, 0, sizeof(travarg.touched));
 	memset(travarg.touched, 0, sizeof(travarg.touched));
 	travarg.count = 0;
 	travarg.count = 0;
 	refill_vals(ht, val, NUM_VALS);
 	refill_vals(ht, val, NUM_VALS);
-	hashtable_traverse(ht, delete_prev, &travarg);
+	hashtable_traverse(ht, uint64_t, delete_prev, &travarg);
 	ok1(travarg.count <= NUM_VALS);
 	ok1(travarg.count <= NUM_VALS);
 	check_only_touched_once(&travarg);
 	check_only_touched_once(&travarg);
 
 

+ 2 - 5
ccan/likely/likely.c

@@ -88,11 +88,8 @@ static double right_ratio(const struct trace *t)
 	return (double)t->right / t->count;
 	return (double)t->right / t->count;
 }
 }
 
 
-static bool get_stats(void *elem, void *vinfo)
+static bool get_stats(struct trace *trace, struct get_stats_info *info)
 {
 {
-	struct trace *trace = elem;
-	struct get_stats_info *info = vinfo;
-
 	if (trace->count < info->min_hits)
 	if (trace->count < info->min_hits)
 		return false;
 		return false;
 
 
@@ -116,7 +113,7 @@ const char *likely_stats(unsigned int min_hits, unsigned int percent)
 	info.worst_ratio = 2;
 	info.worst_ratio = 2;
 
 
 	/* This is O(n), but it's not likely called that often. */
 	/* This is O(n), but it's not likely called that often. */
-	hashtable_traverse(htable, get_stats, &info);
+	hashtable_traverse(htable, struct trace, get_stats, &info);
 
 
 	if (info.worst_ratio * 100 > percent)
 	if (info.worst_ratio * 100 > percent)
 		return NULL;
 		return NULL;

+ 87 - 0
ccan/typesafe_cb/_info

@@ -46,6 +46,93 @@
  * cast_if_type() and friend become an unconditional cast, so your
  * cast_if_type() and friend become an unconditional cast, so your
  * code will compile but you won't get type checking.
  * code will compile but you won't get type checking.
  *
  *
+ * Example:
+ *	#include <ccan/typesafe_cb/typesafe_cb.h>
+ *	#include <stdlib.h>
+ *	#include <stdio.h>
+ *
+ *	// Generic callback infrastructure.
+ *	struct callback {
+ *		struct callback *next;
+ *		int value;
+ *		int (*callback)(int value, void *arg);
+ *		void *arg;
+ *	};
+ *	static struct callback *callbacks;
+ *	
+ *	static void _register_callback(int value, int (*cb)(int, void *),
+ *				       void *arg)
+ *	{
+ *		struct callback *new = malloc(sizeof(*new));
+ *		new->next = callbacks;
+ *		new->value = value;
+ *		new->callback = cb;
+ *		new->arg = arg;
+ *		callbacks = new;
+ *	}
+ *	#define register_callback(value, cb, arg)			\
+ *		_register_callback(value,				\
+ *				   typesafe_cb_preargs(int, (cb), (arg), int),\
+ *				   (arg))
+ *	
+ *	static struct callback *find_callback(int value)
+ *	{
+ *		struct callback *i;
+ *	
+ *		for (i = callbacks; i; i = i->next)
+ *			if (i->value == value)
+ *				return i;
+ *		return NULL;
+ *	}   
+ *
+ *	// Define several silly callbacks.  Note they don't use void *!
+ *	#define DEF_CALLBACK(name, op)			\
+ *		static int name(int val, const int *arg)\
+ *		{					\
+ *			printf("%s", #op);		\
+ *			return val op *arg;		\
+ *		}
+ *	DEF_CALLBACK(multiply, *);
+ *	DEF_CALLBACK(add, +);
+ *	DEF_CALLBACK(divide, /);
+ *	DEF_CALLBACK(sub, -);
+ *	DEF_CALLBACK(or, |);
+ *	DEF_CALLBACK(and, &);
+ *	DEF_CALLBACK(xor, ^);
+ *	DEF_CALLBACK(assign, =);
+ *
+ *	// Silly game to find the longest chain of values.
+ *	int main(int argc, char *argv[])
+ *	{
+ *		int i, run = 1, num = argv[1] ? atoi(argv[1]) : 0;
+ *	
+ *		for (i = 1; i < 1024;) {
+ *			// Since run is an int, compiler checks "add" does too.
+ *			register_callback(i++, add, &run);
+ *			register_callback(i++, divide, &run);
+ *			register_callback(i++, sub, &run);
+ *			register_callback(i++, multiply, &run);
+ *			register_callback(i++, or, &run);
+ *			register_callback(i++, and, &run);
+ *			register_callback(i++, xor, &run);
+ *			register_callback(i++, assign, &run);
+ *		}
+ *	
+ *		printf("%i ", num);
+ *		while (run < 56) {
+ *			struct callback *cb = find_callback(num % i);
+ *			if (!cb) {
+ *				printf("-> STOP\n");
+ *				return 1;
+ *			}
+ *			num = cb->callback(num, cb->arg);
+ *			printf("->%i ", num);
+ *			run++;
+ *		}
+ *		printf("-> Winner!\n");
+ *		return 0;
+ *	}
+ *
  * Licence: LGPL (2 or any later version)
  * Licence: LGPL (2 or any later version)
  */
  */
 int main(int argc, char *argv[])
 int main(int argc, char *argv[])

+ 1 - 1
ccan/typesafe_cb/test/compile_fail-cast_if_type.c

@@ -7,7 +7,7 @@ void _set_some_value(void *val)
 }
 }
 
 
 #define set_some_value(expr)						\
 #define set_some_value(expr)						\
-	_set_some_value(cast_if_type((expr), unsigned long, void *))
+	_set_some_value(cast_if_type(void *, (expr), unsigned long))
 
 
 int main(int argc, char *argv[])
 int main(int argc, char *argv[])
 {
 {

+ 1 - 1
ccan/typesafe_cb/test/run.c

@@ -11,7 +11,7 @@ static void _set_some_value(void *val)
 }
 }
 
 
 #define set_some_value(expr)						\
 #define set_some_value(expr)						\
-	_set_some_value(cast_if_type((expr), unsigned long, void *))
+	_set_some_value(cast_if_type(void *, (expr), unsigned long))
 
 
 static void _callback_onearg(void (*fn)(void *arg), void *arg)
 static void _callback_onearg(void (*fn)(void *arg), void *arg)
 {
 {

+ 48 - 37
ccan/typesafe_cb/typesafe_cb.h

@@ -5,9 +5,9 @@
 #if HAVE_TYPEOF && HAVE_BUILTIN_CHOOSE_EXPR && HAVE_BUILTIN_TYPES_COMPATIBLE_P
 #if HAVE_TYPEOF && HAVE_BUILTIN_CHOOSE_EXPR && HAVE_BUILTIN_TYPES_COMPATIBLE_P
 /**
 /**
  * cast_if_type - only cast an expression if it is of a given type
  * cast_if_type - only cast an expression if it is of a given type
+ * @desttype: the type to cast to
  * @expr: the expression to cast
  * @expr: the expression to cast
  * @oktype: the type we allow
  * @oktype: the type we allow
- * @desttype: the type to cast to
  *
  *
  * This macro is used to create functions which allow multiple types.
  * This macro is used to create functions which allow multiple types.
  * The result of this macro is used somewhere that a @desttype type is
  * The result of this macro is used somewhere that a @desttype type is
@@ -27,15 +27,41 @@
  *	// We can take either an unsigned long or a void *.
  *	// We can take either an unsigned long or a void *.
  *	void _set_some_value(void *val);
  *	void _set_some_value(void *val);
  *	#define set_some_value(expr)			\
  *	#define set_some_value(expr)			\
- *		_set_some_value(cast_if_type((expr), unsigned long, void *))
+ *		_set_some_value(cast_if_type(void *, (expr), unsigned long))
  */
  */
-#define cast_if_type(expr, oktype, desttype)				\
+#define cast_if_type(desttype, expr, oktype)				\
 __builtin_choose_expr(__builtin_types_compatible_p(typeof(1?(expr):0), oktype), \
 __builtin_choose_expr(__builtin_types_compatible_p(typeof(1?(expr):0), oktype), \
 			(desttype)(expr), (expr))
 			(desttype)(expr), (expr))
 #else
 #else
 #define cast_if_type(expr, oktype, desttype) ((desttype)(expr))
 #define cast_if_type(expr, oktype, desttype) ((desttype)(expr))
 #endif
 #endif
 
 
+/**
+ * cast_if_any - only cast an expression if it is one of the three given types
+ * @desttype: the type to cast to
+ * @expr: the expression to cast
+ * @ok1: the first type we allow
+ * @ok2: the second type we allow
+ * @ok3: the third type we allow
+ *
+ * This is a convenient wrapper for multiple cast_if_type() calls.  You can
+ * chain them inside each other (ie. use cast_if_any() for expr) if you need
+ * more than 3 arguments.
+ *
+ * Example:
+ *	// We can take either a long, unsigned long, void * or a const void *.
+ *	void _set_some_value(void *val);
+ *	#define set_some_value(expr)					\
+ *		_set_some_value(cast_if_any(void *, (expr),		\
+ *					    long, unsigned long, const void *))
+ */
+#define cast_if_any(desttype, expr, ok1, ok2, ok3)			\
+	cast_if_type(desttype,						\
+		     cast_if_type(desttype,				\
+				  cast_if_type(desttype, (expr), ok1),	\
+				  ok2),					\
+		     ok3)
+
 /**
 /**
  * typesafe_cb - cast a callback function if it matches the arg
  * typesafe_cb - cast a callback function if it matches the arg
  * @rtype: the return type of the callback function
  * @rtype: the return type of the callback function
@@ -54,14 +80,11 @@ __builtin_choose_expr(__builtin_types_compatible_p(typeof(1?(expr):0), oktype),
  *	#define register_callback(fn, arg) \
  *	#define register_callback(fn, arg) \
  *		_register_callback(typesafe_cb(void, (fn), (arg)), (arg))
  *		_register_callback(typesafe_cb(void, (fn), (arg)), (arg))
  */
  */
-#define typesafe_cb(rtype, fn, arg)					\
-	cast_if_type(cast_if_type(cast_if_type((fn),			\
-					       rtype (*)(const typeof(*arg)*), \
-					       rtype (*)(void *)),	\
-				  rtype (*)(volatile typeof(*arg) *),	\
-				  rtype (*)(void *)),			\
-		     rtype (*)(typeof(arg)),				\
-		     rtype (*)(void *))
+#define typesafe_cb(rtype, fn, arg)			\
+	cast_if_any(rtype (*)(void *), (fn),		\
+		    rtype (*)(typeof(*arg)*),		\
+		    rtype (*)(const typeof(*arg)*),	\
+		    rtype (*)(volatile typeof(*arg)*))
 
 
 /**
 /**
  * typesafe_cb_const - cast a const callback function if it matches the arg
  * typesafe_cb_const - cast a const callback function if it matches the arg
@@ -82,8 +105,8 @@ __builtin_choose_expr(__builtin_types_compatible_p(typeof(1?(expr):0), oktype),
  *		_register_callback(typesafe_cb_const(void, (fn), (arg)), (arg))
  *		_register_callback(typesafe_cb_const(void, (fn), (arg)), (arg))
  */
  */
 #define typesafe_cb_const(rtype, fn, arg)				\
 #define typesafe_cb_const(rtype, fn, arg)				\
-	cast_if_type((fn),						\
-		     rtype (*)(const typeof(*arg)*), rtype (*)(const void *))
+	cast_if_type(rtype (*)(const void *), (fn),			\
+		     rtype (*)(const typeof(*arg)*))
 
 
 /**
 /**
  * typesafe_cb_preargs - cast a callback function if it matches the arg
  * typesafe_cb_preargs - cast a callback function if it matches the arg
@@ -101,16 +124,10 @@ __builtin_choose_expr(__builtin_types_compatible_p(typeof(1?(expr):0), oktype),
  *				   (arg))
  *				   (arg))
  */
  */
 #define typesafe_cb_preargs(rtype, fn, arg, ...)			\
 #define typesafe_cb_preargs(rtype, fn, arg, ...)			\
-	cast_if_type(cast_if_type(cast_if_type((fn),			\
-					       rtype (*)(__VA_ARGS__,	\
-							 const typeof(*arg) *),\
-					       rtype (*)(__VA_ARGS__,	\
-							 void *)),	\
-				  rtype (*)(__VA_ARGS__,		\
-					    volatile typeof(*arg) *),	\
-				  rtype (*)(__VA_ARGS__, void *)),	\
-		     rtype (*)(__VA_ARGS__, typeof(arg)),		\
-		     rtype (*)(__VA_ARGS__, void *))
+	cast_if_any(rtype (*)(__VA_ARGS__, void *), (fn),		\
+		    rtype (*)(__VA_ARGS__, typeof(arg)),		\
+		    rtype (*)(__VA_ARGS__, const typeof(*arg) *),	\
+		    rtype (*)(__VA_ARGS__, volatile typeof(*arg) *))
 
 
 /**
 /**
  * typesafe_cb_postargs - cast a callback function if it matches the arg
  * typesafe_cb_postargs - cast a callback function if it matches the arg
@@ -124,20 +141,14 @@ __builtin_choose_expr(__builtin_types_compatible_p(typeof(1?(expr):0), oktype),
  * Example:
  * Example:
  *	void _register_callback(void (*fn)(void *arg, int), void *arg);
  *	void _register_callback(void (*fn)(void *arg, int), void *arg);
  *	#define register_callback(fn, arg) \
  *	#define register_callback(fn, arg) \
- *		_register_callback(typesafe_cb_preargs(void, (fn), (arg), int),\
+ *		_register_callback(typesafe_cb_postargs(void, (fn), (arg), int),\
  *				   (arg))
  *				   (arg))
  */
  */
 #define typesafe_cb_postargs(rtype, fn, arg, ...)			\
 #define typesafe_cb_postargs(rtype, fn, arg, ...)			\
-	cast_if_type(cast_if_type(cast_if_type((fn),			\
-					       rtype (*)(const typeof(*arg) *, \
-							 __VA_ARGS__),	\
-					       rtype (*)(void *,	\
-							 __VA_ARGS__)), \
-				  rtype (*)(volatile typeof(*arg) *,	\
-					    __VA_ARGS__),		\
-				  rtype (*)(void *, __VA_ARGS__)),	\
-		     rtype (*)(typeof(arg), __VA_ARGS__),		\
-		     rtype (*)(void *, __VA_ARGS__))
+	cast_if_any(rtype (*)(void *, __VA_ARGS__), (fn),		\
+		    rtype (*)(typeof(arg), __VA_ARGS__),		\
+		    rtype (*)(const typeof(*arg) *, __VA_ARGS__),	\
+		    rtype (*)(volatile typeof(*arg) *, __VA_ARGS__))
 
 
 /**
 /**
  * typesafe_cb_cmp - cast a compare function if it matches the arg
  * typesafe_cb_cmp - cast a compare function if it matches the arg
@@ -160,7 +171,7 @@ __builtin_choose_expr(__builtin_types_compatible_p(typeof(1?(expr):0), oktype),
  *			  typesafe_cb_cmp(int, (cmpfn), (base)), (arg))
  *			  typesafe_cb_cmp(int, (cmpfn), (base)), (arg))
  */
  */
 #define typesafe_cb_cmp(rtype, cmpfn, arg)				\
 #define typesafe_cb_cmp(rtype, cmpfn, arg)				\
-	cast_if_type((cmpfn),						\
-		     rtype (*)(const typeof(*arg)*, const typeof(*arg)*), \
-		     rtype (*)(const void *, const void *))
+	cast_if_type(rtype (*)(const void *, const void *), (cmpfn),	\
+		     rtype (*)(const typeof(*arg)*, const typeof(*arg)*))
+		     
 #endif /* CCAN_CAST_IF_TYPE_H */
 #endif /* CCAN_CAST_IF_TYPE_H */