From 60dee4611cb90af694c75d3e7ae67218d621acbe Mon Sep 17 00:00:00 2001
From: akwizgran <michael@briarproject.org>
Date: Wed, 10 Apr 2013 12:48:25 +0100
Subject: [PATCH] Validate key derivation inputs: always 32 bytes, never blank.

---
 .../sf/briar/crypto/CryptoComponentImpl.java  | 38 ++++++++++++++++++-
 .../OutgoingSimplexConnectionTest.java        |  2 +
 2 files changed, 38 insertions(+), 2 deletions(-)

diff --git a/briar-core/src/net/sf/briar/crypto/CryptoComponentImpl.java b/briar-core/src/net/sf/briar/crypto/CryptoComponentImpl.java
index e8ae752a46..8b634030aa 100644
--- a/briar-core/src/net/sf/briar/crypto/CryptoComponentImpl.java
+++ b/briar-core/src/net/sf/briar/crypto/CryptoComponentImpl.java
@@ -23,6 +23,7 @@ import java.security.spec.ECFieldFp;
 import java.security.spec.ECParameterSpec;
 import java.security.spec.ECPoint;
 import java.security.spec.EllipticCurve;
+import java.util.Arrays;
 
 import javax.crypto.Cipher;
 import javax.crypto.KeyAgreement;
@@ -82,6 +83,8 @@ class CryptoComponentImpl implements CryptoComponent {
 	// Blank plaintext for key derivation
 	private static final byte[] KEY_DERIVATION_BLANK_PLAINTEXT =
 			new byte[SECRET_KEY_BYTES];
+	// Blank secret for argument validation
+	private static final byte[] BLANK_SECRET = new byte[SECRET_KEY_BYTES];
 
 	// Parameters for NIST elliptic curve P-384 - see "Suite B Implementer's
 	// Guide to NIST SP 800-56A", section A.2
@@ -217,6 +220,10 @@ class CryptoComponentImpl implements CryptoComponent {
 	}
 
 	public int[] deriveConfirmationCodes(byte[] secret) {
+		if(secret.length != SECRET_KEY_BYTES)
+			throw new IllegalArgumentException();
+		if(Arrays.equals(secret, BLANK_SECRET))
+			throw new IllegalArgumentException();
 		byte[] alice = counterModeKdf(secret, CODE, 0);
 		byte[] bob = counterModeKdf(secret, CODE, 1);
 		int[] codes = new int[2];
@@ -228,6 +235,10 @@ class CryptoComponentImpl implements CryptoComponent {
 	}
 
 	public byte[][] deriveInvitationNonces(byte[] secret) {
+		if(secret.length != SECRET_KEY_BYTES)
+			throw new IllegalArgumentException();
+		if(Arrays.equals(secret, BLANK_SECRET))
+			throw new IllegalArgumentException();
 		byte[] alice = counterModeKdf(secret, NONCE, 0);
 		byte[] bob = counterModeKdf(secret, NONCE, 1);
 		return new byte[][] { alice, bob };
@@ -235,7 +246,6 @@ class CryptoComponentImpl implements CryptoComponent {
 
 	public byte[] deriveMasterSecret(byte[] theirPublicKey,
 			KeyPair ourKeyPair, boolean alice) throws GeneralSecurityException {
-		PublicKey theirPub = agreementKeyParser.parsePublicKey(theirPublicKey);
 		MessageDigest messageDigest = getMessageDigest();
 		byte[] ourPublicKey = ourKeyPair.getPublic().getEncoded();
 		byte[] ourHash = messageDigest.digest(ourPublicKey);
@@ -249,6 +259,7 @@ class CryptoComponentImpl implements CryptoComponent {
 			bobInfo = ourHash;
 		}
 		PrivateKey ourPriv = ourKeyPair.getPrivate();
+		PublicKey theirPub = agreementKeyParser.parsePublicKey(theirPublicKey);
 		// The raw secret comes from the key agreement algorithm
 		byte[] raw = deriveSharedSecret(ourPriv, theirPub);
 		// Derive the cooked secret from the raw secret using the
@@ -269,23 +280,41 @@ class CryptoComponentImpl implements CryptoComponent {
 	}
 
 	public byte[] deriveInitialSecret(byte[] secret, int transportIndex) {
+		if(secret.length != SECRET_KEY_BYTES)
+			throw new IllegalArgumentException();
+		if(Arrays.equals(secret, BLANK_SECRET))
+			throw new IllegalArgumentException();
 		if(transportIndex < 0) throw new IllegalArgumentException();
 		return counterModeKdf(secret, FIRST, transportIndex);
 	}
 
 	public byte[] deriveNextSecret(byte[] secret, long period) {
+		if(secret.length != SECRET_KEY_BYTES)
+			throw new IllegalArgumentException();
+		if(Arrays.equals(secret, BLANK_SECRET))
+			throw new IllegalArgumentException();
 		if(period < 0 || period > MAX_32_BIT_UNSIGNED)
 			throw new IllegalArgumentException();
 		return counterModeKdf(secret, ROTATE, period);
 	}
 
 	public ErasableKey deriveTagKey(byte[] secret, boolean alice) {
+		if(secret.length != SECRET_KEY_BYTES)
+			throw new IllegalArgumentException();
+		if(Arrays.equals(secret, BLANK_SECRET))
+			throw new IllegalArgumentException();
 		if(alice) return deriveKey(secret, A_TAG, 0);
 		else return deriveKey(secret, B_TAG, 0);
 	}
 
 	public ErasableKey deriveFrameKey(byte[] secret, long connection,
 			boolean alice, boolean initiator) {
+		if(secret.length != SECRET_KEY_BYTES)
+			throw new IllegalArgumentException();
+		if(Arrays.equals(secret, BLANK_SECRET))
+			throw new IllegalArgumentException();
+		if(connection < 0 || connection > MAX_32_BIT_UNSIGNED)
+			throw new IllegalArgumentException();
 		if(alice) {
 			if(initiator) return deriveKey(secret, A_FRAME_A, connection);
 			else return deriveKey(secret, A_FRAME_B, connection);
@@ -296,6 +325,10 @@ class CryptoComponentImpl implements CryptoComponent {
 	}
 
 	private ErasableKey deriveKey(byte[] secret, byte[] label, long context) {
+		if(secret.length != SECRET_KEY_BYTES)
+			throw new IllegalArgumentException();
+		if(Arrays.equals(secret, BLANK_SECRET))
+			throw new IllegalArgumentException();
 		byte[] key = counterModeKdf(secret, label, context);
 		return new ErasableKeyImpl(key, SECRET_KEY_ALGO);
 	}
@@ -424,9 +457,10 @@ class CryptoComponentImpl implements CryptoComponent {
 	// Key derivation function based on a block cipher in CTR mode - see
 	// NIST SP 800-108, section 5.1
 	private byte[] counterModeKdf(byte[] secret, byte[] label, long context) {
-		// The secret must be usable as a key
 		if(secret.length != SECRET_KEY_BYTES)
 			throw new IllegalArgumentException();
+		if(Arrays.equals(secret, BLANK_SECRET))
+			throw new IllegalArgumentException();
 		// The label and context must leave a byte free for the counter
 		if(label.length + 4 >= KEY_DERIVATION_IV_BYTES)
 			throw new IllegalArgumentException();
diff --git a/briar-tests/src/net/sf/briar/messaging/simplex/OutgoingSimplexConnectionTest.java b/briar-tests/src/net/sf/briar/messaging/simplex/OutgoingSimplexConnectionTest.java
index 4c2684caa3..83476b8d58 100644
--- a/briar-tests/src/net/sf/briar/messaging/simplex/OutgoingSimplexConnectionTest.java
+++ b/briar-tests/src/net/sf/briar/messaging/simplex/OutgoingSimplexConnectionTest.java
@@ -8,6 +8,7 @@ import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
 
 import java.io.ByteArrayOutputStream;
 import java.util.Arrays;
+import java.util.Random;
 import java.util.concurrent.Executor;
 import java.util.concurrent.Executors;
 
@@ -78,6 +79,7 @@ public class OutgoingSimplexConnectionTest extends BriarTestCase {
 		messageId = new MessageId(TestUtils.getRandomId());
 		transportId = new TransportId(TestUtils.getRandomId());
 		secret = new byte[32];
+		new Random().nextBytes(secret);
 	}
 
 	@Test
-- 
GitLab