diff --git a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java index d521531b0a5ef9641271d47c2968758d4d5005d3..63bcdbc778a5cd6bb6ed7e48c96d6ee3737dc874 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/transport/TransportKeyManagerImpl.java @@ -355,6 +355,16 @@ class TransportKeyManagerImpl implements TransportKeyManager { db.setReorderingWindow(txn, tagCtx.keySetId, transportId, inKeys.getRotationPeriod(), window.getBase(), window.getBitmap()); + // If the outgoing keys are inactive, activate them + MutableKeySet ks = keys.get(tagCtx.keySetId); + MutableOutgoingKeys outKeys = + ks.getTransportKeys().getCurrentOutgoingKeys(); + if (!outKeys.isActive()) { + LOG.info("Activating outgoing keys"); + outKeys.activate(); + considerReplacingOutgoingKeys(ks); + db.setTransportKeysActive(txn, transportId, tagCtx.keySetId); + } return ctx; } finally { lock.unlock(); diff --git a/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java index 84d6c34650050156b7203f4635ebb9079203062d..21350ae6d7414409091bb5fafaa81b6b7fb87a1b 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/transport/TransportKeyManagerImplTest.java @@ -286,9 +286,10 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { public void testTagIsNotRecognisedTwice() throws Exception { boolean alice = random.nextBoolean(); TransportKeys transportKeys = createTransportKeys(1000, 0, true); + Transaction txn = new Transaction(null, false); + // Keep a copy of the tags List<byte[]> tags = new ArrayList<>(); - Transaction txn = new Transaction(null, false); context.checking(new Expectations() {{ oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, @@ -418,19 +419,9 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { TransportKeys transportKeys = createTransportKeys(1000, 0, false); Transaction txn = new Transaction(null, false); + expectAddUnboundKeysNoRotation(alice, transportKeys, txn); + context.checking(new Expectations() {{ - oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, - 1000, alice, false); - will(returnValue(transportKeys)); - // Get the current time (the start of rotation period 1000) - oneOf(clock).currentTimeMillis(); - will(returnValue(rotationPeriodLength * 1000)); - // Rotate the transport keys (the keys are unaffected) - oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000); - will(returnValue(transportKeys)); - // Save the unbound keys - oneOf(db).addTransportKeys(txn, null, transportKeys); - will(returnValue(keySetId)); // When the keys are bound, encode the tags (3 sets) for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { exactly(3).of(transportCrypto).encodeTag( @@ -470,7 +461,74 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { assertEquals(transportId, ctx.getTransportId()); assertEquals(tagKey, ctx.getTagKey()); assertEquals(headerKey, ctx.getHeaderKey()); - assertEquals(0, ctx.getStreamNumber()); + assertEquals(0L, ctx.getStreamNumber()); + } + + @Test + public void testRecognisingTagActivatesOutgoingKeys() throws Exception { + boolean alice = random.nextBoolean(); + TransportKeys transportKeys = createTransportKeys(1000, 0, false); + Transaction txn = new Transaction(null, false); + + // Keep a copy of the tags + List<byte[]> tags = new ArrayList<>(); + + expectAddUnboundKeysNoRotation(alice, transportKeys, txn); + + context.checking(new Expectations() {{ + // When the keys are bound, encode the tags (3 sets) + for (long i = 0; i < REORDERING_WINDOW_SIZE; i++) { + exactly(3).of(transportCrypto).encodeTag( + with(any(byte[].class)), with(tagKey), + with(PROTOCOL_VERSION), with(i)); + will(new EncodeTagAction(tags)); + } + // Save the key binding + oneOf(db).bindTransportKeys(txn, contactId, transportId, keySetId); + // Encode a new tag after sliding the window + oneOf(transportCrypto).encodeTag(with(any(byte[].class)), + with(tagKey), with(PROTOCOL_VERSION), + with((long) REORDERING_WINDOW_SIZE)); + will(new EncodeTagAction(tags)); + // Save the reordering window (previous rotation period, base 1) + oneOf(db).setReorderingWindow(txn, keySetId, transportId, 999, + 1, new byte[REORDERING_WINDOW_SIZE / 8]); + // Activate the keys + oneOf(db).setTransportKeysActive(txn, transportId, keySetId); + // Increment the stream counter + oneOf(db).incrementStreamCounter(txn, transportId, keySetId); + }}); + + TransportKeyManager transportKeyManager = new TransportKeyManagerImpl( + db, transportCrypto, dbExecutor, scheduler, clock, transportId, + maxLatency); + // The timestamp is at the start of rotation period 1000 + long timestamp = rotationPeriodLength * 1000; + assertEquals(keySetId, transportKeyManager.addUnboundKeys(txn, + masterKey, timestamp, alice)); + transportKeyManager.bindKeys(txn, contactId, keySetId); + // The keys are inactive so no stream context should be returned + assertFalse(transportKeyManager.canSendOutgoingStreams(contactId)); + assertNull(transportKeyManager.getStreamContext(txn, contactId)); + // Recognising an incoming tag should activate the outgoing keys + assertEquals(REORDERING_WINDOW_SIZE * 3, tags.size()); + byte[] tag = tags.get(0); + StreamContext ctx = transportKeyManager.getStreamContext(txn, tag); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertEquals(tagKey, ctx.getTagKey()); + assertEquals(headerKey, ctx.getHeaderKey()); + assertEquals(0L, ctx.getStreamNumber()); + // The keys are active so a stream context should be returned + assertTrue(transportKeyManager.canSendOutgoingStreams(contactId)); + ctx = transportKeyManager.getStreamContext(txn, contactId); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertEquals(tagKey, ctx.getTagKey()); + assertEquals(headerKey, ctx.getHeaderKey()); + assertEquals(0L, ctx.getStreamNumber()); } @Test @@ -479,19 +537,9 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { TransportKeys transportKeys = createTransportKeys(1000, 0, false); Transaction txn = new Transaction(null, false); + expectAddUnboundKeysNoRotation(alice, transportKeys, txn); + context.checking(new Expectations() {{ - oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, - 1000, alice, false); - will(returnValue(transportKeys)); - // Get the current time (the start of rotation period 1000) - oneOf(clock).currentTimeMillis(); - will(returnValue(rotationPeriodLength * 1000)); - // Rotate the transport keys (the keys are unaffected) - oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000); - will(returnValue(transportKeys)); - // Save the unbound keys - oneOf(db).addTransportKeys(txn, null, transportKeys); - will(returnValue(keySetId)); // Remove the unbound keys oneOf(db).removeTransportKeys(txn, transportId, keySetId); }}); @@ -533,6 +581,24 @@ public class TransportKeyManagerImplTest extends BrambleMockTestCase { }}); } + private void expectAddUnboundKeysNoRotation(boolean alice, + TransportKeys transportKeys, Transaction txn) throws Exception { + context.checking(new Expectations() {{ + oneOf(transportCrypto).deriveTransportKeys(transportId, masterKey, + 1000, alice, false); + will(returnValue(transportKeys)); + // Get the current time (the start of rotation period 1000) + oneOf(clock).currentTimeMillis(); + will(returnValue(rotationPeriodLength * 1000)); + // Rotate the transport keys (the keys are unaffected) + oneOf(transportCrypto).rotateTransportKeys(transportKeys, 1000); + will(returnValue(transportKeys)); + // Save the unbound keys + oneOf(db).addTransportKeys(txn, null, transportKeys); + will(returnValue(keySetId)); + }}); + } + private TransportKeys createTransportKeys(long rotationPeriod, long streamCounter, boolean active) { IncomingKeys inPrev = new IncomingKeys(tagKey, headerKey,