From 57e6f2ea9c5ff392fd34559f8b08416292b7439f Mon Sep 17 00:00:00 2001
From: akwizgran <michael@briarproject.org>
Date: Wed, 28 Mar 2018 10:51:30 +0100
Subject: [PATCH] Unit tests for removing unbound keys.

---
 .../bramble/db/DatabaseComponentImplTest.java |  16 +-
 .../bramble/db/JdbcDatabaseTest.java          | 113 ++++++++++--
 .../TransportKeyManagerImplTest.java          | 164 +++++++++++-------
 3 files changed, 213 insertions(+), 80 deletions(-)

diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java
index 4832d9d8bc..61a4991628 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java
@@ -778,13 +778,13 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
 			// endTransaction()
 			oneOf(database).commitTransaction(txn);
 			// Check whether the transport is in the DB (which it's not)
-			exactly(5).of(database).startTransaction();
+			exactly(6).of(database).startTransaction();
 			will(returnValue(txn));
 			exactly(2).of(database).containsContact(txn, contactId);
 			will(returnValue(true));
-			exactly(5).of(database).containsTransport(txn, transportId);
+			exactly(6).of(database).containsTransport(txn, transportId);
 			will(returnValue(false));
-			exactly(5).of(database).abortTransaction(txn);
+			exactly(6).of(database).abortTransaction(txn);
 		}});
 		DatabaseComponent db = createDatabaseComponent(database, eventBus,
 				shutdown);
@@ -839,6 +839,16 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase {
 			db.endTransaction(transaction);
 		}
 
+		transaction = db.startTransaction(false);
+		try {
+			db.removeTransportKeys(transaction, transportId, keySetId);
+			fail();
+		} catch (NoSuchTransportException expected) {
+			// Expected
+		} finally {
+			db.endTransaction(transaction);
+		}
+
 		transaction = db.startTransaction(false);
 		try {
 			db.setReorderingWindow(transaction, keySetId, transportId, 0, 0,
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java
index 4237af191b..6babc64594 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java
@@ -91,7 +91,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 	private final Message message;
 	private final TransportId transportId;
 	private final ContactId contactId;
-	private final KeySetId keySetId;
+	private final KeySetId keySetId, keySetId1;
 
 	JdbcDatabaseTest() throws Exception {
 		groupId = new GroupId(getRandomId());
@@ -108,6 +108,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		transportId = new TransportId("id");
 		contactId = new ContactId(1);
 		keySetId = new KeySetId(1);
+		keySetId1 = new KeySetId(2);
 	}
 
 	protected abstract JdbcDatabase createDatabase(DatabaseConfig config,
@@ -667,6 +668,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 	@Test
 	public void testTransportKeys() throws Exception {
 		TransportKeys keys = createTransportKeys();
+		TransportKeys keys1 = createTransportKeys();
 
 		Database<Connection> db = open(false);
 		Connection txn = db.startTransaction();
@@ -680,23 +682,20 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 				true, true));
 		db.addTransport(txn, transportId, 123);
 		assertEquals(keySetId, db.addTransportKeys(txn, contactId, keys));
+		assertEquals(keySetId1, db.addTransportKeys(txn, contactId, keys1));
 
 		// Retrieve the transport keys
-		Collection<KeySet> newKeys = db.getTransportKeys(txn, transportId);
-		assertEquals(1, newKeys.size());
-		KeySet ks = newKeys.iterator().next();
-		assertEquals(keySetId, ks.getKeySetId());
-		assertEquals(contactId, ks.getContactId());
-		TransportKeys k = ks.getTransportKeys();
-		assertEquals(transportId, k.getTransportId());
-		assertKeysEquals(keys.getPreviousIncomingKeys(),
-				k.getPreviousIncomingKeys());
-		assertKeysEquals(keys.getCurrentIncomingKeys(),
-				k.getCurrentIncomingKeys());
-		assertKeysEquals(keys.getNextIncomingKeys(),
-				k.getNextIncomingKeys());
-		assertKeysEquals(keys.getCurrentOutgoingKeys(),
-				k.getCurrentOutgoingKeys());
+		Collection<KeySet> allKeys = db.getTransportKeys(txn, transportId);
+		assertEquals(2, allKeys.size());
+		for (KeySet ks : allKeys) {
+			assertEquals(contactId, ks.getContactId());
+			if (ks.getKeySetId().equals(keySetId)) {
+				assertKeysEquals(keys, ks.getTransportKeys());
+			} else {
+				assertEquals(keySetId1, ks.getKeySetId());
+				assertKeysEquals(keys1, ks.getTransportKeys());
+			}
+		}
 
 		// Removing the contact should remove the transport keys
 		db.removeContact(txn, contactId);
@@ -706,6 +705,88 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase {
 		db.close();
 	}
 
+	@Test
+	public void testUnboundTransportKeys() throws Exception {
+		TransportKeys keys = createTransportKeys();
+		TransportKeys keys1 = createTransportKeys();
+
+		Database<Connection> db = open(false);
+		Connection txn = db.startTransaction();
+
+		// Initially there should be no transport keys in the database
+		assertEquals(emptyList(), db.getTransportKeys(txn, transportId));
+
+		// Add the contact, the transport and the unbound transport keys
+		db.addLocalAuthor(txn, localAuthor);
+		assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(),
+				true, true));
+		db.addTransport(txn, transportId, 123);
+		assertEquals(keySetId, db.addTransportKeys(txn, null, keys));
+		assertEquals(keySetId1, db.addTransportKeys(txn, null, keys1));
+
+		// Retrieve the transport keys
+		Collection<KeySet> allKeys = db.getTransportKeys(txn, transportId);
+		assertEquals(2, allKeys.size());
+		for (KeySet ks : allKeys) {
+			assertNull(ks.getContactId());
+			if (ks.getKeySetId().equals(keySetId)) {
+				assertKeysEquals(keys, ks.getTransportKeys());
+			} else {
+				assertEquals(keySetId1, ks.getKeySetId());
+				assertKeysEquals(keys1, ks.getTransportKeys());
+			}
+		}
+
+		// Bind the first set of transport keys
+		db.bindTransportKeys(txn, contactId, transportId, keySetId);
+
+		// Retrieve the keys again - the first set should be bound
+		allKeys = db.getTransportKeys(txn, transportId);
+		assertEquals(2, allKeys.size());
+		for (KeySet ks : allKeys) {
+			if (ks.getKeySetId().equals(keySetId)) {
+				assertEquals(contactId, ks.getContactId());
+				assertKeysEquals(keys, ks.getTransportKeys());
+			} else {
+				assertEquals(keySetId1, ks.getKeySetId());
+				assertNull(ks.getContactId());
+				assertKeysEquals(keys1, ks.getTransportKeys());
+			}
+		}
+
+		// Remove the unbound transport keys
+		db.removeTransportKeys(txn, transportId, keySetId1);
+
+		// Retrieve the keys again - the second set should be gone
+		allKeys = db.getTransportKeys(txn, transportId);
+		assertEquals(1, allKeys.size());
+		KeySet ks = allKeys.iterator().next();
+		assertEquals(keySetId, ks.getKeySetId());
+		assertEquals(contactId, ks.getContactId());
+		assertKeysEquals(keys, ks.getTransportKeys());
+
+		// Removing the transport should remove the remaining transport keys
+		db.removeTransport(txn, transportId);
+		assertEquals(emptyList(), db.getTransportKeys(txn, transportId));
+
+		db.commitTransaction(txn);
+		db.close();
+	}
+
+	private void assertKeysEquals(TransportKeys expected,
+			TransportKeys actual) {
+		assertEquals(expected.getTransportId(), actual.getTransportId());
+		assertEquals(expected.getRotationPeriod(), actual.getRotationPeriod());
+		assertKeysEquals(expected.getPreviousIncomingKeys(),
+				actual.getPreviousIncomingKeys());
+		assertKeysEquals(expected.getCurrentIncomingKeys(),
+				actual.getCurrentIncomingKeys());
+		assertKeysEquals(expected.getNextIncomingKeys(),
+				actual.getNextIncomingKeys());
+		assertKeysEquals(expected.getCurrentOutgoingKeys(),
+				actual.getCurrentOutgoingKeys());
+	}
+
 	private void assertKeysEquals(IncomingKeys expected, IncomingKeys actual) {
 		assertArrayEquals(expected.getTagKey().getBytes(),
 				actual.getTagKey().getBytes());
diff --git a/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java
index a57ce4e742..afc5dfed2d 100644
--- a/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java
+++ b/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java
@@ -203,27 +203,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 				MAX_32_BIT_UNSIGNED + 1);
 		Transaction txn = new Transaction(null, false);
 
-		context.checking(new Expectations() {{
-			oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
-					1000, alice);
-			will(returnValue(transportKeys));
-			// Get the current time (the start of rotation period 1000)
-			oneOf(clock).currentTimeMillis();
-			will(returnValue(rotationPeriodLength * 1000));
-			// Encode the tags (3 sets)
-			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
-				exactly(3).of(transportCrypto).encodeTag(
-						with(any(byte[].class)), with(tagKey),
-						with(PROTOCOL_VERSION), with(i));
-				will(new EncodeTagAction());
-			}
-			// Rotate the transport keys (the keys are unaffected)
-			oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
-			will(returnValue(transportKeys));
-			// Save the keys
-			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
-			will(returnValue(keySetId));
-		}});
+		expectAddContactNoRotation(alice, transportKeys, txn);
 
 		TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
 				db, transportCrypto, dbExecutor, scheduler, clock, transportId,
@@ -243,26 +223,9 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 				MAX_32_BIT_UNSIGNED);
 		Transaction txn = new Transaction(null, false);
 
+		expectAddContactNoRotation(alice, transportKeys, txn);
+
 		context.checking(new Expectations() {{
-			oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
-					1000, alice);
-			will(returnValue(transportKeys));
-			// Get the current time (the start of rotation period 1000)
-			oneOf(clock).currentTimeMillis();
-			will(returnValue(rotationPeriodLength * 1000));
-			// Encode the tags (3 sets)
-			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
-				exactly(3).of(transportCrypto).encodeTag(
-						with(any(byte[].class)), with(tagKey),
-						with(PROTOCOL_VERSION), with(i));
-				will(new EncodeTagAction());
-			}
-			// Rotate the transport keys (the keys are unaffected)
-			oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
-			will(returnValue(transportKeys));
-			// Save the keys
-			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
-			will(returnValue(keySetId));
 			// Increment the stream counter
 			oneOf(db).incrementStreamCounter(txn, contactId, transportId, 1000);
 		}});
@@ -294,27 +257,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 		TransportKeys transportKeys = createTransportKeys(1000, 0);
 		Transaction txn = new Transaction(null, false);
 
-		context.checking(new Expectations() {{
-			oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
-					1000, alice);
-			will(returnValue(transportKeys));
-			// Get the current time (the start of rotation period 1000)
-			oneOf(clock).currentTimeMillis();
-			will(returnValue(rotationPeriodLength * 1000));
-			// Encode the tags (3 sets)
-			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
-				exactly(3).of(transportCrypto).encodeTag(
-						with(any(byte[].class)), with(tagKey),
-						with(PROTOCOL_VERSION), with(i));
-				will(new EncodeTagAction());
-			}
-			// Rotate the transport keys (the keys are unaffected)
-			oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
-			will(returnValue(transportKeys));
-			// Save the keys
-			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
-			will(returnValue(keySetId));
-		}});
+		expectAddContactNoRotation(alice, transportKeys, txn);
 
 		TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
 				db, transportCrypto, dbExecutor, scheduler, clock, transportId,
@@ -323,6 +266,7 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 		long timestamp = rotationPeriodLength * 1000;
 		transportKeyManager.addContact(txn, contactId, masterKey, timestamp,
 				alice);
+		// The tag should not be recognised
 		assertNull(transportKeyManager.getStreamContext(txn,
 				new byte[TAG_LENGTH]));
 	}
@@ -466,6 +410,104 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase {
 		transportKeyManager.start(txn);
 	}
 
+	@Test
+	public void testTagsAreEncodedWhenKeysAreBound() throws Exception {
+		boolean alice = random.nextBoolean();
+		TransportKeys transportKeys = createTransportKeys(1000, 0);
+		Transaction txn = new Transaction(null, false);
+
+		context.checking(new Expectations() {{
+			oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
+					1000, alice);
+			will(returnValue(transportKeys));
+			// Get the current time (the start of rotation period 1000)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1000));
+			// Rotate the transport keys (the keys are unaffected)
+			oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
+			will(returnValue(transportKeys));
+			// Save the unbound keys
+			oneOf(db).addTransportKeys(txn, null, transportKeys);
+			will(returnValue(keySetId));
+			// When the keys are bound, encode the tags (3 sets)
+			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
+				exactly(3).of(transportCrypto).encodeTag(
+						with(any(byte[].class)), with(tagKey),
+						with(PROTOCOL_VERSION), with(i));
+				will(new EncodeTagAction());
+			}
+			// Save the key binding
+			oneOf(db).bindTransportKeys(txn, contactId, transportId, keySetId);
+		}});
+
+		TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
+				db, transportCrypto, dbExecutor, scheduler, clock, transportId,
+				maxLatency);
+		// The timestamp is at the start of rotation period 1000
+		long timestamp = rotationPeriodLength * 1000;
+		assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn,
+				masterKey, timestamp, alice));
+		transportKeyManager.bindKeys(txn, contactId, keySetId);
+	}
+
+	@Test
+	public void testRemovingUnboundKeys() throws Exception {
+		boolean alice = random.nextBoolean();
+		TransportKeys transportKeys = createTransportKeys(1000, 0);
+		Transaction txn = new Transaction(null, false);
+
+		context.checking(new Expectations() {{
+			oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
+					1000, alice);
+			will(returnValue(transportKeys));
+			// Get the current time (the start of rotation period 1000)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1000));
+			// Rotate the transport keys (the keys are unaffected)
+			oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
+			will(returnValue(transportKeys));
+			// Save the unbound keys
+			oneOf(db).addTransportKeys(txn, null, transportKeys);
+			will(returnValue(keySetId));
+			// Remove the unbound keys
+			oneOf(db).removeTransportKeys(txn, transportId, keySetId);
+		}});
+
+		TransportKeyManager transportKeyManager = new TransportKeyManagerImpl(
+				db, transportCrypto, dbExecutor, scheduler, clock, transportId,
+				maxLatency);
+		// The timestamp is at the start of rotation period 1000
+		long timestamp = rotationPeriodLength * 1000;
+		assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn,
+				masterKey, timestamp, alice));
+		transportKeyManager.removeKeys(txn, keySetId);
+	}
+
+	private void expectAddContactNoRotation(boolean alice,
+			TransportKeys transportKeys, Transaction txn) throws Exception {
+		context.checking(new Expectations() {{
+			oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey,
+					1000, alice);
+			will(returnValue(transportKeys));
+			// Get the current time (the start of rotation period 1000)
+			oneOf(clock).currentTimeMillis();
+			will(returnValue(rotationPeriodLength * 1000));
+			// Encode the tags (3 sets)
+			for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) {
+				exactly(3).of(transportCrypto).encodeTag(
+						with(any(byte[].class)), with(tagKey),
+						with(PROTOCOL_VERSION), with(i));
+				will(new EncodeTagAction());
+			}
+			// Rotate the transport keys (the keys are unaffected)
+			oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000);
+			will(returnValue(transportKeys));
+			// Save the keys
+			oneOf(db).addTransportKeys(txn, contactId, transportKeys);
+			will(returnValue(keySetId));
+		}});
+	}
+
 	private TransportKeys createTransportKeys(long rotationPeriod,
 			long streamCounter) {
 		IncomingKeys inPrev = new IncomingKeys(tagKey, headerKey,
-- 
GitLab