Browse Source

crypto/shachain: detect if we're inserting a bogus hash.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
Rusty Russell 10 years ago
parent
commit
9fc0711160

+ 11 - 2
ccan/crypto/shachain/shachain.c

@@ -49,7 +49,7 @@ void shachain_init(struct shachain *shachain)
 	shachain->num_valid = 0;
 }
 
-void shachain_add_hash(struct shachain *chain,
+bool shachain_add_hash(struct shachain *chain,
 		       shachain_index_t index, const struct sha256 *hash)
 {
 	int i;
@@ -57,8 +57,16 @@ void shachain_add_hash(struct shachain *chain,
 	for (i = 0; i < chain->num_valid; i++) {
 		/* If we could derive this value, we don't need it,
 		 * not any others (since they're in order). */
-		if (can_derive(index, chain->known[i].index))
+		if (can_derive(index, chain->known[i].index)) {
+			struct sha256 expect;
+
+			/* Make sure the others derive as expected! */
+			derive(index, chain->known[i].index, hash, &expect);
+			if (memcmp(&expect, &chain->known[i].hash,
+				   sizeof(expect)) != 0)
+				return false;
 			break;
+		}
 	}
 
 	/* This can happen if you skip indices! */
@@ -66,6 +74,7 @@ void shachain_add_hash(struct shachain *chain,
 	chain->known[i].index = index;
 	chain->known[i].hash = *hash;
 	chain->num_valid = i+1;
+	return true;
 }
 
 bool shachain_get_hash(const struct shachain *chain,

+ 1 - 1
ccan/crypto/shachain/shachain.h

@@ -24,7 +24,7 @@ struct shachain {
 
 void shachain_init(struct shachain *shachain);
 
-void shachain_add_hash(struct shachain *shachain,
+bool shachain_add_hash(struct shachain *shachain,
 		       shachain_index_t index, const struct sha256 *hash);
 
 bool shachain_get_hash(const struct shachain *shachain,

+ 2 - 2
ccan/crypto/shachain/test/run-8bit.c

@@ -17,7 +17,7 @@ int main(void)
 	size_t i, j;
 
 	/* This is how many tests you plan to run */
-	plan_tests(NUM_TESTS + NUM_TESTS * (NUM_TESTS + 1) + NUM_TESTS);
+	plan_tests(NUM_TESTS * 3 + NUM_TESTS * (NUM_TESTS + 1));
 
 	memset(&seed, 0, sizeof(seed));
 	/* Generate a whole heap. */
@@ -34,7 +34,7 @@ int main(void)
 	for (i = 0; i < NUM_TESTS; i++) {
 		struct sha256 hash;
 
-		shachain_add_hash(&chain, i, &expect[i]);
+		ok1(shachain_add_hash(&chain, i, &expect[i]));
 		for (j = 0; j <= i; j++) {
 			ok1(shachain_get_hash(&chain, j, &hash));
 			ok1(memcmp(&hash, &expect[j], sizeof(hash)) == 0);

+ 38 - 0
ccan/crypto/shachain/test/run-badhash.c

@@ -0,0 +1,38 @@
+#include <ccan/crypto/shachain/shachain.h>
+/* Include the C files directly. */
+#include <ccan/crypto/shachain/shachain.c>
+#include <ccan/tap/tap.h>
+
+#define NUM_TESTS 1000
+
+int main(void)
+{
+	struct sha256 seed;
+	struct shachain chain;
+	size_t i;
+
+	plan_tests(NUM_TESTS);
+
+	memset(&seed, 0xFF, sizeof(seed));
+	shachain_init(&chain);
+
+	for (i = 0; i < NUM_TESTS; i++) {
+		struct sha256 expect;
+		unsigned int num_known = chain.num_valid;
+
+		shachain_from_seed(&seed, i, &expect);
+		/* Screw it up. */
+		expect.u.u8[0]++;
+
+		/* Either it should fail, or it couldn't derive any others. */
+		if (shachain_add_hash(&chain, i, &expect)) {
+			ok1(chain.num_valid == num_known + 1);
+			/* Fix it up in-place */
+			chain.known[num_known].hash.u.u8[0]--;
+		} else {
+			expect.u.u8[0]--;
+			ok1(shachain_add_hash(&chain, i, &expect));
+		}
+	}
+	return exit_status();
+}

+ 2 - 2
ccan/crypto/shachain/test/run.c

@@ -13,7 +13,7 @@ int main(void)
 	size_t i, j;
 
 	/* This is how many tests you plan to run */
-	plan_tests(NUM_TESTS + NUM_TESTS * (NUM_TESTS + 1) + NUM_TESTS);
+	plan_tests(NUM_TESTS * 3 + NUM_TESTS * (NUM_TESTS + 1));
 
 	memset(&seed, 0, sizeof(seed));
 	/* Generate a whole heap. */
@@ -30,7 +30,7 @@ int main(void)
 	for (i = 0; i < NUM_TESTS; i++) {
 		struct sha256 hash;
 
-		shachain_add_hash(&chain, i, &expect[i]);
+		ok1(shachain_add_hash(&chain, i, &expect[i]));
 		for (j = 0; j <= i; j++) {
 			ok1(shachain_get_hash(&chain, j, &hash));
 			ok1(memcmp(&hash, &expect[j], sizeof(hash)) == 0);