From 9220bb3426e6b3980e5d16769e8e6430f6dfe1c1 Mon Sep 17 00:00:00 2001
From: akwizgran <akwizgran@users.sourceforge.net>
Date: Tue, 15 Nov 2011 17:19:11 +0000
Subject: [PATCH] Key derivation function based on NIST SP 800-108.

---
 .../sf/briar/api/crypto/CryptoComponent.java  |   8 +-
 .../sf/briar/crypto/CryptoComponentImpl.java  | 111 +++++++++++-------
 .../net/sf/briar/ProtocolIntegrationTest.java |   2 +-
 .../sf/briar/db/DatabaseComponentTest.java    |   4 +-
 test/net/sf/briar/db/H2DatabaseTest.java      |   4 +-
 .../ConnectionRecogniserImplTest.java         |   2 +-
 .../briar/transport/ConnectionWriterTest.java |   2 +-
 .../briar/transport/FrameReadWriteTest.java   |   2 +-
 .../batch/BatchConnectionReadWriteTest.java   |   4 +-
 9 files changed, 83 insertions(+), 56 deletions(-)

diff --git a/api/net/sf/briar/api/crypto/CryptoComponent.java b/api/net/sf/briar/api/crypto/CryptoComponent.java
index 2b01765b7d..0c6b7f9ed0 100644
--- a/api/net/sf/briar/api/crypto/CryptoComponent.java
+++ b/api/net/sf/briar/api/crypto/CryptoComponent.java
@@ -9,11 +9,13 @@ import javax.crypto.Mac;
 
 public interface CryptoComponent {
 
-	ErasableKey deriveFrameKey(byte[] source, boolean initiator);
+	ErasableKey deriveFrameKey(byte[] secret, boolean initiator);
 
-	ErasableKey deriveIvKey(byte[] source, boolean initiator);
+	ErasableKey deriveIvKey(byte[] secret, boolean initiator);
 
-	ErasableKey deriveMacKey(byte[] source, boolean initiator);
+	ErasableKey deriveMacKey(byte[] secret, boolean initiator);
+
+	byte[] deriveNextSecret(byte[] secret, long connection);
 
 	KeyPair generateKeyPair();
 
diff --git a/components/net/sf/briar/crypto/CryptoComponentImpl.java b/components/net/sf/briar/crypto/CryptoComponentImpl.java
index 670245c2d4..d82d7f924c 100644
--- a/components/net/sf/briar/crypto/CryptoComponentImpl.java
+++ b/components/net/sf/briar/crypto/CryptoComponentImpl.java
@@ -1,22 +1,21 @@
 package net.sf.briar.crypto;
 
-import java.io.UnsupportedEncodingException;
+import java.security.GeneralSecurityException;
 import java.security.KeyPair;
 import java.security.KeyPairGenerator;
-import java.security.NoSuchAlgorithmException;
-import java.security.NoSuchProviderException;
 import java.security.SecureRandom;
 import java.security.Security;
 import java.security.Signature;
 
 import javax.crypto.Cipher;
 import javax.crypto.Mac;
-import javax.crypto.NoSuchPaddingException;
+import javax.crypto.spec.IvParameterSpec;
 
 import net.sf.briar.api.crypto.CryptoComponent;
 import net.sf.briar.api.crypto.ErasableKey;
 import net.sf.briar.api.crypto.KeyParser;
 import net.sf.briar.api.crypto.MessageDigest;
+import net.sf.briar.util.ByteUtils;
 
 import org.bouncycastle.jce.provider.BouncyCastleProvider;
 
@@ -30,10 +29,21 @@ class CryptoComponentImpl implements CryptoComponent {
 	private static final int KEY_PAIR_BITS = 256;
 	private static final String FRAME_CIPHER_ALGO = "AES/CTR/NoPadding";
 	private static final String SECRET_KEY_ALGO = "AES";
-	private static final int SECRET_KEY_BYTES = 32;
+	private static final int SECRET_KEY_BYTES = 32; // 256 bits
 	private static final String IV_CIPHER_ALGO = "AES/ECB/NoPadding";
 	private static final String MAC_ALGO = "HMacSHA256";
 	private static final String SIGNATURE_ALGO = "ECDSA";
+	private static final String KEY_DERIVATION_ALGO = "AES/CTR/NoPadding";
+	private static final int KEY_DERIVATION_IV_BYTES = 16; // 128 bits
+
+	// Context strings for key derivation
+	private static final byte[] FRAME_I = { 'F', 'R', 'A', 'M', 'E', '_', 'I' };
+	private static final byte[] FRAME_R = { 'F', 'R', 'A', 'M', 'E', '_', 'R' };
+	private static final byte[] IV_I = { 'I', 'V', '_', 'I' };
+	private static final byte[] IV_R = { 'I', 'V', '_', 'R' };
+	private static final byte[] MAC_I = { 'M', 'A', 'C', '_', 'I' };
+	private static final byte[] MAC_R = { 'M', 'A', 'C', '_', 'R' };
+	private static final byte[] NEXT = { 'N', 'E', 'X', 'T' };
 
 	private final KeyParser keyParser;
 	private final KeyPairGenerator keyPairGenerator;
@@ -46,38 +56,67 @@ class CryptoComponentImpl implements CryptoComponent {
 			keyPairGenerator = KeyPairGenerator.getInstance(KEY_PAIR_ALGO,
 					PROVIDER);
 			keyPairGenerator.initialize(KEY_PAIR_BITS);
-		} catch(NoSuchAlgorithmException e) {
-			throw new RuntimeException(e);
-		} catch(NoSuchProviderException e) {
+		} catch(GeneralSecurityException e) {
 			throw new RuntimeException(e);
 		}
 	}
 
-	public ErasableKey deriveFrameKey(byte[] source, boolean initiator) {
-		if(initiator) return deriveKey("FRAME_I", source);
-		else return deriveKey("FRAME_R", source);
+	public ErasableKey deriveFrameKey(byte[] secret, boolean initiator) {
+		if(initiator) return deriveKey(secret, FRAME_I);
+		else return deriveKey(secret, FRAME_R);
+	}
+
+	public ErasableKey deriveIvKey(byte[] secret, boolean initiator) {
+		if(initiator) return deriveKey(secret, IV_I);
+		else return deriveKey(secret, IV_R);
 	}
 
-	public ErasableKey deriveIvKey(byte[] source, boolean initiator) {
-		if(initiator) return deriveKey("IV_I", source);
-		else return deriveKey("IV_R", source);
+	public ErasableKey deriveMacKey(byte[] secret, boolean initiator) {
+		if(initiator) return deriveKey(secret, MAC_I);
+		else return deriveKey(secret, MAC_R);
 	}
 
-	public ErasableKey deriveMacKey(byte[] source, boolean initiator) {
-		if(initiator) return deriveKey("MAC_I", source);
-		else return deriveKey("MAC_R", source);
+	private ErasableKey deriveKey(byte[] secret, byte[] context) {
+		byte[] key = counterModeKdf(secret, context);
+		return new ErasableKeyImpl(key, SECRET_KEY_ALGO);
 	}
 
-	private ErasableKey deriveKey(String name, byte[] source) {
-		MessageDigest digest = getMessageDigest();
-		assert digest.getDigestLength() == SECRET_KEY_BYTES;
+	// Key derivation function based on a block cipher in CTR mode - see
+	// NIST SP 800-108, section 5.1
+	private byte[] counterModeKdf(byte[] secret, byte[] context) {
+		// The secret must be usable as a key
+		if(secret.length != SECRET_KEY_BYTES)
+			throw new IllegalArgumentException();
+		ErasableKey key = new ErasableKeyImpl(secret, SECRET_KEY_ALGO);
+		// The context must leave four bytes free for the length
+		if(context.length + 4 > SECRET_KEY_BYTES)
+			throw new IllegalArgumentException();
+		byte[] input = new byte[SECRET_KEY_BYTES];
+		// The initial bytes of the input are the context
+		System.arraycopy(context, 0, input, 0, context.length);
+		// The final bytes of the input are the length as a big-endian uint32
+		ByteUtils.writeUint32(context.length, input, input.length - 4);
+		// Initialise the counter to zero
+		byte[] zero = new byte[KEY_DERIVATION_IV_BYTES];
+		IvParameterSpec iv = new IvParameterSpec(zero);
 		try {
-			digest.update(name.getBytes("UTF-8"));
-		} catch(UnsupportedEncodingException e) {
+			Cipher cipher = Cipher.getInstance(KEY_DERIVATION_ALGO, PROVIDER);
+			cipher.init(Cipher.ENCRYPT_MODE, key, iv);
+			byte[] output = cipher.doFinal(input);
+			assert output.length == SECRET_KEY_BYTES;
+			return output;
+		} catch(GeneralSecurityException e) {
 			throw new RuntimeException(e);
 		}
-		digest.update(source);
-		return new ErasableKeyImpl(digest.digest(), SECRET_KEY_ALGO);
+	}
+
+	public byte[] deriveNextSecret(byte[] secret, long connection) {
+		if(connection < 0 || connection > ByteUtils.MAX_32_BIT_UNSIGNED)
+			throw new IllegalArgumentException();
+		byte[] context = new byte[NEXT.length + 4];
+		System.arraycopy(NEXT, 0, context, 0, NEXT.length);
+		ByteUtils.writeUint32(connection, context, NEXT.length);
+		return counterModeKdf(secret, context);
 	}
 
 	public KeyPair generateKeyPair() {
@@ -93,11 +132,7 @@ class CryptoComponentImpl implements CryptoComponent {
 	public Cipher getFrameCipher() {
 		try {
 			return Cipher.getInstance(FRAME_CIPHER_ALGO, PROVIDER);
-		} catch(NoSuchAlgorithmException e) {
-			throw new RuntimeException(e);
-		} catch(NoSuchPaddingException e) {
-			throw new RuntimeException(e);
-		} catch(NoSuchProviderException e) {
+		} catch(GeneralSecurityException e) {
 			throw new RuntimeException(e);
 		}
 	}
@@ -105,11 +140,7 @@ class CryptoComponentImpl implements CryptoComponent {
 	public Cipher getIvCipher() {
 		try {
 			return Cipher.getInstance(IV_CIPHER_ALGO, PROVIDER);
-		} catch(NoSuchAlgorithmException e) {
-			throw new RuntimeException(e);
-		} catch(NoSuchPaddingException e) {
-			throw new RuntimeException(e);
-		} catch(NoSuchProviderException e) {
+		} catch(GeneralSecurityException e) {
 			throw new RuntimeException(e);
 		}
 	}
@@ -121,9 +152,7 @@ class CryptoComponentImpl implements CryptoComponent {
 	public Mac getMac() {
 		try {
 			return Mac.getInstance(MAC_ALGO, PROVIDER);
-		} catch(NoSuchAlgorithmException e) {
-			throw new RuntimeException(e);
-		} catch(NoSuchProviderException e) {
+		} catch(GeneralSecurityException e) {
 			throw new RuntimeException(e);
 		}
 	}
@@ -132,9 +161,7 @@ class CryptoComponentImpl implements CryptoComponent {
 		try {
 			return new DoubleDigest(java.security.MessageDigest.getInstance(
 					DIGEST_ALGO, PROVIDER));
-		} catch(NoSuchAlgorithmException e) {
-			throw new RuntimeException(e);
-		} catch(NoSuchProviderException e) {
+		} catch(GeneralSecurityException e) {
 			throw new RuntimeException(e);
 		}
 	}
@@ -147,9 +174,7 @@ class CryptoComponentImpl implements CryptoComponent {
 	public Signature getSignature() {
 		try {
 			return Signature.getInstance(SIGNATURE_ALGO, PROVIDER);
-		} catch(NoSuchAlgorithmException e) {
-			throw new RuntimeException(e);
-		} catch(NoSuchProviderException e) {
+		} catch(GeneralSecurityException e) {
 			throw new RuntimeException(e);
 		}
 	}
diff --git a/test/net/sf/briar/ProtocolIntegrationTest.java b/test/net/sf/briar/ProtocolIntegrationTest.java
index 907fc63d8f..001407cdfe 100644
--- a/test/net/sf/briar/ProtocolIntegrationTest.java
+++ b/test/net/sf/briar/ProtocolIntegrationTest.java
@@ -99,7 +99,7 @@ public class ProtocolIntegrationTest extends TestCase {
 		assertEquals(crypto.getMessageDigest().getDigestLength(),
 				UniqueId.LENGTH);
 		Random r = new Random();
-		aliceToBobSecret = new byte[123];
+		aliceToBobSecret = new byte[32];
 		r.nextBytes(aliceToBobSecret);
 		// Create two groups: one restricted, one unrestricted
 		GroupFactory groupFactory = i.getInstance(GroupFactory.class);
diff --git a/test/net/sf/briar/db/DatabaseComponentTest.java b/test/net/sf/briar/db/DatabaseComponentTest.java
index 763bbf4bb7..178f76d179 100644
--- a/test/net/sf/briar/db/DatabaseComponentTest.java
+++ b/test/net/sf/briar/db/DatabaseComponentTest.java
@@ -97,9 +97,9 @@ public abstract class DatabaseComponentTest extends TestCase {
 				properties);
 		transports = Collections.singletonList(transport);
 		Random r = new Random();
-		inSecret = new byte[123];
+		inSecret = new byte[32];
 		r.nextBytes(inSecret);
-		outSecret = new byte[123];
+		outSecret = new byte[32];
 		r.nextBytes(outSecret);
 	}
 
diff --git a/test/net/sf/briar/db/H2DatabaseTest.java b/test/net/sf/briar/db/H2DatabaseTest.java
index ac3fee99c4..48a7532dc4 100644
--- a/test/net/sf/briar/db/H2DatabaseTest.java
+++ b/test/net/sf/briar/db/H2DatabaseTest.java
@@ -123,9 +123,9 @@ public class H2DatabaseTest extends TestCase {
 		remoteTransports = Collections.singletonList(remoteTransport);
 		subscriptions = Collections.singletonMap(group, 0L);
 		Random r = new Random();
-		inSecret = new byte[123];
+		inSecret = new byte[32];
 		r.nextBytes(inSecret);
-		outSecret = new byte[123];
+		outSecret = new byte[32];
 		r.nextBytes(outSecret);
 	}
 
diff --git a/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java b/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java
index 31eab8d9bd..d9cf7b8db5 100644
--- a/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java
+++ b/test/net/sf/briar/transport/ConnectionRecogniserImplTest.java
@@ -43,7 +43,7 @@ public class ConnectionRecogniserImplTest extends TestCase {
 		Injector i = Guice.createInjector(new CryptoModule());
 		crypto = i.getInstance(CryptoComponent.class);
 		contactId = new ContactId(1);
-		inSecret = new byte[123];
+		inSecret = new byte[32];
 		new Random().nextBytes(inSecret);
 		transportId = new TransportId(TestUtils.getRandomId());
 		localIndex = new TransportIndex(13);
diff --git a/test/net/sf/briar/transport/ConnectionWriterTest.java b/test/net/sf/briar/transport/ConnectionWriterTest.java
index f5ffa882f5..cb27fd9653 100644
--- a/test/net/sf/briar/transport/ConnectionWriterTest.java
+++ b/test/net/sf/briar/transport/ConnectionWriterTest.java
@@ -39,7 +39,7 @@ public class ConnectionWriterTest extends TestCase {
 				new TestDatabaseModule(), new TransportBatchModule(),
 				new TransportModule(), new TransportStreamModule());
 		connectionWriterFactory = i.getInstance(ConnectionWriterFactory.class);
-		outSecret = new byte[123];
+		outSecret = new byte[32];
 		new Random().nextBytes(outSecret);
 	}
 
diff --git a/test/net/sf/briar/transport/FrameReadWriteTest.java b/test/net/sf/briar/transport/FrameReadWriteTest.java
index 5ed26f59a1..b05fff5224 100644
--- a/test/net/sf/briar/transport/FrameReadWriteTest.java
+++ b/test/net/sf/briar/transport/FrameReadWriteTest.java
@@ -44,7 +44,7 @@ public class FrameReadWriteTest extends TestCase {
 		frameCipher = crypto.getFrameCipher();
 		random = new Random();
 		// Since we're sending frames to ourselves, we only need outgoing keys
-		outSecret = new byte[123];
+		outSecret = new byte[32];
 		random.nextBytes(outSecret);
 		ivKey = crypto.deriveIvKey(outSecret, true);
 		frameKey = crypto.deriveFrameKey(outSecret, true);
diff --git a/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java b/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java
index 8ecc2dc580..06d4e7fc6f 100644
--- a/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java
+++ b/test/net/sf/briar/transport/batch/BatchConnectionReadWriteTest.java
@@ -65,9 +65,9 @@ public class BatchConnectionReadWriteTest extends TestCase {
 		transportIndex = new TransportIndex(1);
 		// Create matching secrets for Alice and Bob
 		Random r = new Random();
-		aliceToBobSecret = new byte[123];
+		aliceToBobSecret = new byte[32];
 		r.nextBytes(aliceToBobSecret);
-		bobToAliceSecret = new byte[123];
+		bobToAliceSecret = new byte[32];
 		r.nextBytes(bobToAliceSecret);
 	}
 
-- 
GitLab