diff --git a/briar-core/src/net/sf/briar/messaging/duplex/DuplexConnection.java b/briar-core/src/net/sf/briar/messaging/duplex/DuplexConnection.java index 575244faba4217dad7ffcfa240cb2d1c14cbba99..9b43ebd1c9259bc518d84820201cae6311dbedcc 100644 --- a/briar-core/src/net/sf/briar/messaging/duplex/DuplexConnection.java +++ b/briar-core/src/net/sf/briar/messaging/duplex/DuplexConnection.java @@ -135,6 +135,7 @@ abstract class DuplexConnection implements DatabaseListener { public void eventOccurred(DatabaseEvent e) { if(e instanceof ContactRemovedEvent) { ContactRemovedEvent c = (ContactRemovedEvent) e; + // FIXME: Listeners should not block if(contactId.equals(c.getContactId())) dispose(false, true); } else if(e instanceof GroupMessageAddedEvent) { if(canSendOffer.getAndSet(false)) diff --git a/briar-core/src/net/sf/briar/transport/KeyManagerImpl.java b/briar-core/src/net/sf/briar/transport/KeyManagerImpl.java index 6ab4d1730247cdc4040f9d963bfaedfa85009805..409ca79fa72ba23d76e7b882b3db266b6be27135 100644 --- a/briar-core/src/net/sf/briar/transport/KeyManagerImpl.java +++ b/briar-core/src/net/sf/briar/transport/KeyManagerImpl.java @@ -300,6 +300,7 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { if(e instanceof ContactRemovedEvent) { ContactId c = ((ContactRemovedEvent) e).getContactId(); connectionRecogniser.removeSecrets(c); + // FIXME: Listeners should not block synchronized(this) { removeAndEraseSecrets(c, oldSecrets); removeAndEraseSecrets(c, currentSecrets); @@ -307,12 +308,14 @@ class KeyManagerImpl extends TimerTask implements KeyManager, DatabaseListener { } } else if(e instanceof TransportAddedEvent) { TransportAddedEvent t = (TransportAddedEvent) e; + // FIXME: Listeners should not block synchronized(this) { maxLatencies.put(t.getTransportId(), t.getMaxLatency()); } } else if(e instanceof TransportRemovedEvent) { TransportId t = ((TransportRemovedEvent) e).getTransportId(); connectionRecogniser.removeSecrets(t); + // FIXME: Listeners should not block synchronized(this) { maxLatencies.remove(t); removeAndEraseSecrets(t, oldSecrets); diff --git a/briar-core/src/net/sf/briar/transport/TransportConnectionRecogniser.java b/briar-core/src/net/sf/briar/transport/TransportConnectionRecogniser.java index 93de724c68707cec6d5ec59085f31cab0601ed18..a827641d812babd5a72a91bf76c59d0d831c2f9c 100644 --- a/briar-core/src/net/sf/briar/transport/TransportConnectionRecogniser.java +++ b/briar-core/src/net/sf/briar/transport/TransportConnectionRecogniser.java @@ -20,6 +20,7 @@ import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.TemporarySecret; import net.sf.briar.util.ByteUtils; +// FIXME: Don't make alien calls with a lock held /** A connection recogniser for a specific transport. */ class TransportConnectionRecogniser { @@ -56,15 +57,15 @@ class TransportConnectionRecogniser { byte[] tag1 = new byte[TAG_LENGTH]; crypto.encodeTag(tag1, cipher, key, connection1); if(connection1 < connection) { - TagContext old = tagMap.remove(new Bytes(tag1)); - assert old != null; - ByteUtils.erase(old.context.getSecret()); + TagContext removed = tagMap.remove(new Bytes(tag1)); + assert removed != null; + ByteUtils.erase(removed.context.getSecret()); } else { ConnectionContext ctx1 = new ConnectionContext(contactId, transportId, secret.clone(), connection1, alice); TagContext tctx1 = new TagContext(window, ctx1, period); - TagContext old = tagMap.put(new Bytes(tag1), tctx1); - assert old == null; + TagContext duplicate = tagMap.put(new Bytes(tag1), tctx1); + assert duplicate == null; } } key.erase(); @@ -92,8 +93,8 @@ class TransportConnectionRecogniser { ConnectionContext ctx = new ConnectionContext(contactId, transportId, secret.clone(), connection, alice); TagContext tctx = new TagContext(window, ctx, period); - TagContext old = tagMap.put(new Bytes(tag), tctx); - assert old == null; + TagContext duplicate = tagMap.put(new Bytes(tag), tctx); + assert duplicate == null; } key.erase(); // Create a removal context to remove the window later @@ -116,9 +117,9 @@ class TransportConnectionRecogniser { byte[] tag = new byte[TAG_LENGTH]; for(long connection : rctx.window.getUnseen()) { crypto.encodeTag(tag, cipher, key, connection); - TagContext old = tagMap.remove(new Bytes(tag)); - assert old != null; - ByteUtils.erase(old.context.getSecret()); + TagContext removed = tagMap.remove(new Bytes(tag)); + assert removed != null; + ByteUtils.erase(removed.context.getSecret()); } key.erase(); ByteUtils.erase(rctx.secret); @@ -170,8 +171,8 @@ class TransportConnectionRecogniser { @Override public boolean equals(Object o) { if(o instanceof RemovalKey) { - RemovalKey w = (RemovalKey) o; - return contactId.equals(w.contactId) && period == w.period; + RemovalKey k = (RemovalKey) o; + return contactId.equals(k.contactId) && period == k.period; } return false; } diff --git a/briar-tests/build.xml b/briar-tests/build.xml index 8ed3ed747d8baf9b1f12208e5c371bb59b12d5f2..cf248ed5b3b87cdf199046e5403d75638173792a 100644 --- a/briar-tests/build.xml +++ b/briar-tests/build.xml @@ -104,6 +104,7 @@ <test name='net.sf.briar.transport.ConnectionWriterImplTest'/> <test name='net.sf.briar.transport.IncomingEncryptionLayerTest'/> <test name='net.sf.briar.transport.KeyManagerImplTest'/> + <test name='net.sf.briar.transport.KeyRotationIntegrationTest'/> <test name='net.sf.briar.transport.OutgoingEncryptionLayerTest'/> <test name='net.sf.briar.transport.TransportIntegrationTest'/> <test name='net.sf.briar.transport.TransportConnectionRecogniserTest'/> diff --git a/briar-tests/src/net/sf/briar/transport/KeyManagerImplTest.java b/briar-tests/src/net/sf/briar/transport/KeyManagerImplTest.java index 1e6bf41c9e36537b01b644cb77555f9c4f8da868..93c05d567b8ff4faa874ebcef9f7f5ed8522416b 100644 --- a/briar-tests/src/net/sf/briar/transport/KeyManagerImplTest.java +++ b/briar-tests/src/net/sf/briar/transport/KeyManagerImplTest.java @@ -1,10 +1,10 @@ package net.sf.briar.transport; import static net.sf.briar.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE; +import static org.junit.Assert.assertArrayEquals; import java.util.Arrays; import java.util.Collections; -import java.util.TimerTask; import net.sf.briar.BriarTestCase; import net.sf.briar.TestUtils; @@ -15,6 +15,7 @@ import net.sf.briar.api.clock.Timer; import net.sf.briar.api.crypto.CryptoComponent; import net.sf.briar.api.db.DatabaseComponent; import net.sf.briar.api.db.event.DatabaseListener; +import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionRecogniser; import net.sf.briar.api.transport.Endpoint; import net.sf.briar.api.transport.TemporarySecret; @@ -33,6 +34,7 @@ public class KeyManagerImplTest extends BriarTestCase { private final ContactId contactId; private final TransportId transportId; private final byte[] secret0, secret1, secret2, secret3, secret4; + private final byte[] initialSecret; public KeyManagerImplTest() { contactId = new ContactId(234); @@ -47,6 +49,8 @@ public class KeyManagerImplTest extends BriarTestCase { for(int i = 0; i < secret2.length; i++) secret2[i] = 3; for(int i = 0; i < secret3.length; i++) secret3[i] = 4; for(int i = 0; i < secret4.length; i++) secret4[i] = 5; + initialSecret = new byte[32]; + for(int i = 0; i < initialSecret.length; i++) initialSecret[i] = 123; } @Test @@ -71,7 +75,7 @@ public class KeyManagerImplTest extends BriarTestCase { will(returnValue(Collections.emptyMap())); oneOf(clock).currentTimeMillis(); will(returnValue(EPOCH)); - oneOf(timer).scheduleAtFixedRate(with(any(TimerTask.class)), + oneOf(timer).scheduleAtFixedRate(with(keyManager), with(any(long.class)), with(any(long.class))); // stop() oneOf(db).removeListener(with(any(DatabaseListener.class))); @@ -85,6 +89,133 @@ public class KeyManagerImplTest extends BriarTestCase { context.assertIsSatisfied(); } + @Test + public void testEndpointAdded() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ConnectionRecogniser connectionRecogniser = + context.mock(ConnectionRecogniser.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The secrets for periods 0 - 2 should be derived + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Collections.emptyList())); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // endpointAdded() during rotation period 1 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(crypto).deriveNextSecret(initialSecret, 0); + will(returnValue(secret0.clone())); + oneOf(crypto).deriveNextSecret(secret0, 1); + will(returnValue(secret1.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(db).addSecrets(Arrays.asList(s0, s1, s2)); + // The secrets for periods 0 - 2 should be added to the recogniser + oneOf(connectionRecogniser).addSecret(s0); + oneOf(connectionRecogniser).addSecret(s1); + oneOf(connectionRecogniser).addSecret(s2); + // stop() + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + oneOf(connectionRecogniser).removeSecrets(); + }}); + + assertTrue(keyManager.start()); + keyManager.endpointAdded(ep, initialSecret.clone()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testEndpointAddedAndGetConnectionContext() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ConnectionRecogniser connectionRecogniser = + context.mock(ConnectionRecogniser.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The secrets for periods 0 - 2 should be derived + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Collections.emptyList())); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // endpointAdded() during rotation period 1 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(crypto).deriveNextSecret(initialSecret, 0); + will(returnValue(secret0.clone())); + oneOf(crypto).deriveNextSecret(secret0, 1); + will(returnValue(secret1.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(db).addSecrets(Arrays.asList(s0, s1, s2)); + // The secrets for periods 0 - 2 should be added to the recogniser + oneOf(connectionRecogniser).addSecret(s0); + oneOf(connectionRecogniser).addSecret(s1); + oneOf(connectionRecogniser).addSecret(s2); + // getConnectionContext() + oneOf(db).incrementConnectionCounter(contactId, transportId, 1); + will(returnValue(0L)); + // stop() + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + oneOf(connectionRecogniser).removeSecrets(); + }}); + + assertTrue(keyManager.start()); + keyManager.endpointAdded(ep, initialSecret.clone()); + ConnectionContext ctx = + keyManager.getConnectionContext(contactId, transportId); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertArrayEquals(secret1, ctx.getSecret()); + assertEquals(0, ctx.getConnectionNumber()); + assertEquals(true, ctx.getAlice()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + @Test public void testLoadSecretsAtEpoch() throws Exception { Mockery context = new Mockery(); @@ -98,7 +229,7 @@ public class KeyManagerImplTest extends BriarTestCase { final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, connectionRecogniser, clock, timer); - // The DB contains secrets for periods 0 - 2 + // The DB contains the secrets for periods 0 - 2 Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); @@ -119,7 +250,7 @@ public class KeyManagerImplTest extends BriarTestCase { oneOf(connectionRecogniser).addSecret(s0); oneOf(connectionRecogniser).addSecret(s1); oneOf(connectionRecogniser).addSecret(s2); - oneOf(timer).scheduleAtFixedRate(with(any(TimerTask.class)), + oneOf(timer).scheduleAtFixedRate(with(keyManager), with(any(long.class)), with(any(long.class))); // stop() oneOf(db).removeListener(with(any(DatabaseListener.class))); @@ -146,7 +277,7 @@ public class KeyManagerImplTest extends BriarTestCase { final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, connectionRecogniser, clock, timer); - // The DB contains secrets for periods 0 - 2 + // The DB contains the secrets for periods 0 - 2 Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); @@ -177,7 +308,7 @@ public class KeyManagerImplTest extends BriarTestCase { oneOf(connectionRecogniser).addSecret(s1); oneOf(connectionRecogniser).addSecret(s2); oneOf(connectionRecogniser).addSecret(s3); - oneOf(timer).scheduleAtFixedRate(with(any(TimerTask.class)), + oneOf(timer).scheduleAtFixedRate(with(keyManager), with(any(long.class)), with(any(long.class))); // stop() oneOf(db).removeListener(with(any(DatabaseListener.class))); @@ -192,7 +323,7 @@ public class KeyManagerImplTest extends BriarTestCase { } @Test - public void testLoadSecretsAtStartOfPeriod3() throws Exception { + public void testLoadSecretsAtEndOfPeriod3() throws Exception { Mockery context = new Mockery(); final CryptoComponent crypto = context.mock(CryptoComponent.class); final DatabaseComponent db = context.mock(DatabaseComponent.class); @@ -204,7 +335,7 @@ public class KeyManagerImplTest extends BriarTestCase { final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, connectionRecogniser, clock, timer); - // The DB contains secrets for periods 0 - 2 + // The DB contains the secrets for periods 0 - 2 Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); @@ -221,9 +352,9 @@ public class KeyManagerImplTest extends BriarTestCase { oneOf(db).getTransportLatencies(); will(returnValue(Collections.singletonMap(transportId, MAX_LATENCY))); - // The current time is the start of period 3 + // The current time is the end of period 3 oneOf(clock).currentTimeMillis(); - will(returnValue(EPOCH + 2 * ROTATION_PERIOD_LENGTH)); + will(returnValue(EPOCH + 3 * ROTATION_PERIOD_LENGTH - 1)); // The secrets for periods 3 and 4 should be derived from secret 0 oneOf(crypto).deriveNextSecret(secret0, 1); will(returnValue(secret1.clone())); @@ -246,8 +377,221 @@ public class KeyManagerImplTest extends BriarTestCase { oneOf(connectionRecogniser).addSecret(s2); oneOf(connectionRecogniser).addSecret(s3); oneOf(connectionRecogniser).addSecret(s4); - oneOf(timer).scheduleAtFixedRate(with(any(TimerTask.class)), + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // stop() + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + oneOf(connectionRecogniser).removeSecrets(); + }}); + + assertTrue(keyManager.start()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testLoadSecretsAndRotateInSamePeriod() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ConnectionRecogniser connectionRecogniser = + context.mock(ConnectionRecogniser.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The DB contains the secrets for periods 0 - 2 + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Arrays.asList(s0, s1, s2))); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + // The current time is the epoch, the start of period 1 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + // The secrets for periods 0 - 2 should be added to the recogniser + oneOf(connectionRecogniser).addSecret(s0); + oneOf(connectionRecogniser).addSecret(s1); + oneOf(connectionRecogniser).addSecret(s2); + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // run() during period 1: the secrets should not be affected + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH + 1)); + // getConnectionContext() + oneOf(db).incrementConnectionCounter(contactId, transportId, 1); + will(returnValue(0L)); + // stop() + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + oneOf(connectionRecogniser).removeSecrets(); + }}); + + assertTrue(keyManager.start()); + keyManager.run(); + ConnectionContext ctx = + keyManager.getConnectionContext(contactId, transportId); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertArrayEquals(secret1, ctx.getSecret()); + assertEquals(0, ctx.getConnectionNumber()); + assertEquals(true, ctx.getAlice()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testLoadSecretsAndRotateInNextPeriod() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ConnectionRecogniser connectionRecogniser = + context.mock(ConnectionRecogniser.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The DB contains the secrets for periods 0 - 2 + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + // The secret for period 3 should be derived and stored + final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Arrays.asList(s0, s1, s2))); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + // The current time is the epoch, the start of period 1 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + // The secrets for periods 0 - 2 should be added to the recogniser + oneOf(connectionRecogniser).addSecret(s0); + oneOf(connectionRecogniser).addSecret(s1); + oneOf(connectionRecogniser).addSecret(s2); + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // run() during period 2: the secrets should be rotated + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH + ROTATION_PERIOD_LENGTH + 1)); + oneOf(crypto).deriveNextSecret(secret0, 1); + will(returnValue(secret1.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(crypto).deriveNextSecret(secret2, 3); + will(returnValue(secret3.clone())); + oneOf(connectionRecogniser).removeSecret(contactId, transportId, 0); + oneOf(db).addSecrets(Arrays.asList(s3)); + oneOf(connectionRecogniser).addSecret(s3); + // getConnectionContext() + oneOf(db).incrementConnectionCounter(contactId, transportId, 2); + will(returnValue(0L)); + // stop() + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + oneOf(connectionRecogniser).removeSecrets(); + }}); + + assertTrue(keyManager.start()); + keyManager.run(); + ConnectionContext ctx = + keyManager.getConnectionContext(contactId, transportId); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertArrayEquals(secret2, ctx.getSecret()); + assertEquals(0, ctx.getConnectionNumber()); + assertEquals(true, ctx.getAlice()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testLoadSecretsAndRotateAWholePeriodLate() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final ConnectionRecogniser connectionRecogniser = + context.mock(ConnectionRecogniser.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The DB contains the secrets for periods 0 - 2 + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + // The secrets for periods 3 and 4 should be derived and stored + final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3.clone()); + final TemporarySecret s4 = new TemporarySecret(ep, 4, secret4.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Arrays.asList(s0, s1, s2))); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + // The current time is the epoch, the start of period 1 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + // The secrets for periods 0 - 2 should be added to the recogniser + oneOf(connectionRecogniser).addSecret(s0); + oneOf(connectionRecogniser).addSecret(s1); + oneOf(connectionRecogniser).addSecret(s2); + oneOf(timer).scheduleAtFixedRate(with(keyManager), with(any(long.class)), with(any(long.class))); + // run() during period 3 (late): the secrets should be rotated + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH + 2 * ROTATION_PERIOD_LENGTH + 1)); + oneOf(crypto).deriveNextSecret(secret0, 1); + will(returnValue(secret1.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(crypto).deriveNextSecret(secret2, 3); + will(returnValue(secret3.clone())); + oneOf(crypto).deriveNextSecret(secret3, 4); + will(returnValue(secret4.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(crypto).deriveNextSecret(secret2, 3); + will(returnValue(secret3.clone())); + oneOf(crypto).deriveNextSecret(secret3, 4); + will(returnValue(secret4.clone())); + oneOf(connectionRecogniser).removeSecret(contactId, transportId, 0); + oneOf(connectionRecogniser).removeSecret(contactId, transportId, 1); + oneOf(db).addSecrets(Arrays.asList(s3, s4)); + oneOf(connectionRecogniser).addSecret(s3); + oneOf(connectionRecogniser).addSecret(s4); + // getConnectionContext() + oneOf(db).incrementConnectionCounter(contactId, transportId, 3); + will(returnValue(0L)); // stop() oneOf(db).removeListener(with(any(DatabaseListener.class))); oneOf(timer).cancel(); @@ -255,6 +599,15 @@ public class KeyManagerImplTest extends BriarTestCase { }}); assertTrue(keyManager.start()); + keyManager.run(); + ConnectionContext ctx = + keyManager.getConnectionContext(contactId, transportId); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertArrayEquals(secret3, ctx.getSecret()); + assertEquals(0, ctx.getConnectionNumber()); + assertEquals(true, ctx.getAlice()); keyManager.stop(); context.assertIsSatisfied(); diff --git a/briar-tests/src/net/sf/briar/transport/KeyRotationIntegrationTest.java b/briar-tests/src/net/sf/briar/transport/KeyRotationIntegrationTest.java new file mode 100644 index 0000000000000000000000000000000000000000..ed5230b8337acac5580bdb7a548d7a0d83accee9 --- /dev/null +++ b/briar-tests/src/net/sf/briar/transport/KeyRotationIntegrationTest.java @@ -0,0 +1,983 @@ +package net.sf.briar.transport; + +import static net.sf.briar.api.transport.TransportConstants.MAX_CLOCK_DIFFERENCE; +import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH; +import static org.junit.Assert.assertArrayEquals; + +import java.util.Arrays; +import java.util.Collections; + +import javax.crypto.Cipher; +import javax.crypto.NullCipher; + +import net.sf.briar.BriarTestCase; +import net.sf.briar.TestUtils; +import net.sf.briar.api.ContactId; +import net.sf.briar.api.TransportId; +import net.sf.briar.api.clock.Clock; +import net.sf.briar.api.clock.Timer; +import net.sf.briar.api.crypto.CryptoComponent; +import net.sf.briar.api.crypto.ErasableKey; +import net.sf.briar.api.db.DatabaseComponent; +import net.sf.briar.api.db.event.DatabaseListener; +import net.sf.briar.api.transport.ConnectionContext; +import net.sf.briar.api.transport.ConnectionRecogniser; +import net.sf.briar.api.transport.Endpoint; +import net.sf.briar.api.transport.TemporarySecret; +import net.sf.briar.util.ByteUtils; + +import org.hamcrest.Description; +import org.jmock.Expectations; +import org.jmock.Mockery; +import org.jmock.api.Action; +import org.jmock.api.Invocation; +import org.junit.Test; + +public class KeyRotationIntegrationTest extends BriarTestCase { + + private static final long EPOCH = 1000L * 1000L * 1000L * 1000L; + private static final long MAX_LATENCY = 2 * 60 * 1000; // 2 minutes + private static final long ROTATION_PERIOD_LENGTH = + MAX_LATENCY + MAX_CLOCK_DIFFERENCE; + + private final ContactId contactId; + private final TransportId transportId; + private final byte[] secret0, secret1, secret2, secret3, secret4; + private final byte[] key0, key1, key2, key3, key4; + private final byte[] initialSecret; + private final Cipher cipher; + + public KeyRotationIntegrationTest() { + contactId = new ContactId(234); + transportId = new TransportId(TestUtils.getRandomId()); + secret0 = new byte[32]; + secret1 = new byte[32]; + secret2 = new byte[32]; + secret3 = new byte[32]; + secret4 = new byte[32]; + for(int i = 0; i < secret0.length; i++) secret0[i] = 1; + for(int i = 0; i < secret1.length; i++) secret1[i] = 2; + for(int i = 0; i < secret2.length; i++) secret2[i] = 3; + for(int i = 0; i < secret3.length; i++) secret3[i] = 4; + for(int i = 0; i < secret4.length; i++) secret4[i] = 5; + key0 = new byte[32]; + key1 = new byte[32]; + key2 = new byte[32]; + key3 = new byte[32]; + key4 = new byte[32]; + for(int i = 0; i < key0.length; i++) key0[i] = 1; + for(int i = 0; i < key1.length; i++) key1[i] = 2; + for(int i = 0; i < key2.length; i++) key2[i] = 3; + for(int i = 0; i < key3.length; i++) key3[i] = 4; + for(int i = 0; i < key4.length; i++) key4[i] = 5; + initialSecret = new byte[32]; + for(int i = 0; i < initialSecret.length; i++) initialSecret[i] = 123; + cipher = new NullCipher(); + } + + @Test + public void testStartAndStop() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + + final ConnectionRecogniser connectionRecogniser = + new ConnectionRecogniserImpl(crypto, db); + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Collections.emptyList())); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.emptyMap())); + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // stop() + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + }}); + + assertTrue(keyManager.start()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testEndpointAdded() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + final ErasableKey k0 = context.mock(ErasableKey.class, "k0"); + final ErasableKey k1 = context.mock(ErasableKey.class, "k1"); + final ErasableKey k2 = context.mock(ErasableKey.class, "k2"); + + final ConnectionRecogniser connectionRecogniser = + new ConnectionRecogniserImpl(crypto, db); + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The secrets for periods 0 - 2 should be derived + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Collections.emptyList())); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // endpointAdded() during rotation period 1 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(crypto).deriveNextSecret(initialSecret, 0); + will(returnValue(secret0.clone())); + oneOf(crypto).deriveNextSecret(secret0, 1); + will(returnValue(secret1.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(db).addSecrets(Arrays.asList(s0, s1, s2)); + // The recogniser should derive the tags for period 0 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret0, false); + will(returnValue(k0)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k0), with((long) i)); + will(new EncodeTagAction()); + oneOf(k0).getEncoded(); + will(returnValue(key0)); + } + oneOf(k0).erase(); + // The recogniser should derive the tags for period 1 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret1, false); + will(returnValue(k1)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k1), with((long) i)); + will(new EncodeTagAction()); + oneOf(k1).getEncoded(); + will(returnValue(key1)); + } + oneOf(k1).erase(); + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // stop() + // The recogniser should derive the tags for period 0 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret0, false); + will(returnValue(k0)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k0), with((long) i)); + will(new EncodeTagAction()); + oneOf(k0).getEncoded(); + will(returnValue(key0)); + } + oneOf(k0).erase(); + // The recogniser should derive the tags for period 1 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret1, false); + will(returnValue(k1)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k1), with((long) i)); + will(new EncodeTagAction()); + oneOf(k1).getEncoded(); + will(returnValue(key1)); + } + oneOf(k1).erase(); + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // Remove the listener and stop the timer + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + }}); + + assertTrue(keyManager.start()); + keyManager.endpointAdded(ep, initialSecret.clone()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testEndpointAddedAndGetConnectionContext() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + final ErasableKey k0 = context.mock(ErasableKey.class, "k0"); + final ErasableKey k1 = context.mock(ErasableKey.class, "k1"); + final ErasableKey k2 = context.mock(ErasableKey.class, "k2"); + + final ConnectionRecogniser connectionRecogniser = + new ConnectionRecogniserImpl(crypto, db); + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The secrets for periods 0 - 2 should be derived + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Collections.emptyList())); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // endpointAdded() during rotation period 1 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(crypto).deriveNextSecret(initialSecret, 0); + will(returnValue(secret0.clone())); + oneOf(crypto).deriveNextSecret(secret0, 1); + will(returnValue(secret1.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(db).addSecrets(Arrays.asList(s0, s1, s2)); + // The recogniser should derive the tags for period 0 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret0, false); + will(returnValue(k0)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k0), with((long) i)); + will(new EncodeTagAction()); + oneOf(k0).getEncoded(); + will(returnValue(key0)); + } + oneOf(k0).erase(); + // The recogniser should derive the tags for period 1 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret1, false); + will(returnValue(k1)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k1), with((long) i)); + will(new EncodeTagAction()); + oneOf(k1).getEncoded(); + will(returnValue(key1)); + } + oneOf(k1).erase(); + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // getConnectionContext() + oneOf(db).incrementConnectionCounter(contactId, transportId, 1); + will(returnValue(0L)); + // stop() + // The recogniser should derive the tags for period 0 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret0, false); + will(returnValue(k0)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k0), with((long) i)); + will(new EncodeTagAction()); + oneOf(k0).getEncoded(); + will(returnValue(key0)); + } + oneOf(k0).erase(); + // The recogniser should derive the tags for period 1 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret1, false); + will(returnValue(k1)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k1), with((long) i)); + will(new EncodeTagAction()); + oneOf(k1).getEncoded(); + will(returnValue(key1)); + } + oneOf(k1).erase(); + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // Remove the listener and stop the timer + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + }}); + + assertTrue(keyManager.start()); + keyManager.endpointAdded(ep, initialSecret.clone()); + ConnectionContext ctx = + keyManager.getConnectionContext(contactId, transportId); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertArrayEquals(secret1, ctx.getSecret()); + assertEquals(0, ctx.getConnectionNumber()); + assertEquals(true, ctx.getAlice()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testEndpointAddedAndAcceptConnection() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + final ErasableKey k0 = context.mock(ErasableKey.class, "k0"); + final ErasableKey k1 = context.mock(ErasableKey.class, "k1"); + final ErasableKey k2 = context.mock(ErasableKey.class, "k2"); + + final ConnectionRecogniser connectionRecogniser = + new ConnectionRecogniserImpl(crypto, db); + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The secrets for periods 0 - 2 should be derived + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Collections.emptyList())); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // endpointAdded() during rotation period 1 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + oneOf(crypto).deriveNextSecret(initialSecret, 0); + will(returnValue(secret0.clone())); + oneOf(crypto).deriveNextSecret(secret0, 1); + will(returnValue(secret1.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(db).addSecrets(Arrays.asList(s0, s1, s2)); + // The recogniser should derive the tags for period 0 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret0, false); + will(returnValue(k0)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k0), with((long) i)); + will(new EncodeTagAction()); + oneOf(k0).getEncoded(); + will(returnValue(key0)); + } + oneOf(k0).erase(); + // The recogniser should derive the tags for period 1 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret1, false); + will(returnValue(k1)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k1), with((long) i)); + will(new EncodeTagAction()); + oneOf(k1).getEncoded(); + will(returnValue(key1)); + } + oneOf(k1).erase(); + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // acceptConnection() + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with(16L)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + oneOf(db).setConnectionWindow(contactId, transportId, 2, 1, + new byte[] {0, 1, 0, 0}); + oneOf(k2).erase(); + // stop() + // The recogniser should derive the tags for period 0 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret0, false); + will(returnValue(k0)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k0), with((long) i)); + will(new EncodeTagAction()); + oneOf(k0).getEncoded(); + will(returnValue(key0)); + } + oneOf(k0).erase(); + // The recogniser should derive the tags for period 1 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret1, false); + will(returnValue(k1)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k1), with((long) i)); + will(new EncodeTagAction()); + oneOf(k1).getEncoded(); + will(returnValue(key1)); + } + oneOf(k1).erase(); + // The recogniser should derive the updated tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 1; i < 17; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // Remove the listener and stop the timer + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + }}); + + assertTrue(keyManager.start()); + keyManager.endpointAdded(ep, initialSecret.clone()); + // Recognise the tag for connection 0 in period 2 + byte[] tag = new byte[TAG_LENGTH]; + encodeTag(tag, key2, 0); + ConnectionContext ctx = + connectionRecogniser.acceptConnection(transportId, tag); + assertNotNull(ctx); + assertEquals(contactId, ctx.getContactId()); + assertEquals(transportId, ctx.getTransportId()); + assertArrayEquals(secret2, ctx.getSecret()); + assertEquals(0, ctx.getConnectionNumber()); + assertEquals(true, ctx.getAlice()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testLoadSecretsAtEpoch() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + final ErasableKey k0 = context.mock(ErasableKey.class, "k0"); + final ErasableKey k1 = context.mock(ErasableKey.class, "k1"); + final ErasableKey k2 = context.mock(ErasableKey.class, "k2"); + + final ConnectionRecogniser connectionRecogniser = + new ConnectionRecogniserImpl(crypto, db); + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The DB contains the secrets for periods 0 - 2 + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Arrays.asList(s0, s1, s2))); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + // The current time is the epoch, the start of period 1 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH)); + // The recogniser should derive the tags for period 0 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret0, false); + will(returnValue(k0)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k0), with((long) i)); + will(new EncodeTagAction()); + oneOf(k0).getEncoded(); + will(returnValue(key0)); + } + oneOf(k0).erase(); + // The recogniser should derive the tags for period 1 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret1, false); + will(returnValue(k1)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k1), with((long) i)); + will(new EncodeTagAction()); + oneOf(k1).getEncoded(); + will(returnValue(key1)); + } + oneOf(k1).erase(); + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // Start the timer + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // stop() + // The recogniser should remove the tags for period 0 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret0, false); + will(returnValue(k0)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k0), with((long) i)); + will(new EncodeTagAction()); + oneOf(k0).getEncoded(); + will(returnValue(key0)); + } + oneOf(k0).erase(); + // The recogniser should derive the tags for period 1 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret1, false); + will(returnValue(k1)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k1), with((long) i)); + will(new EncodeTagAction()); + oneOf(k1).getEncoded(); + will(returnValue(key1)); + } + oneOf(k1).erase(); + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // Remove the listener and stop the timer + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + }}); + + assertTrue(keyManager.start()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testLoadSecretsAtStartOfPeriod2() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + final ErasableKey k1 = context.mock(ErasableKey.class, "k1"); + final ErasableKey k2 = context.mock(ErasableKey.class, "k2"); + final ErasableKey k3 = context.mock(ErasableKey.class, "k3"); + + final ConnectionRecogniser connectionRecogniser = + new ConnectionRecogniserImpl(crypto, db); + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The DB contains the secrets for periods 0 - 2 + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + // The secret for period 3 should be derived and stored + final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Arrays.asList(s0, s1, s2))); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + // The current time is the start of period 2 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH + ROTATION_PERIOD_LENGTH)); + // The secret for period 3 should be derived and stored + oneOf(crypto).deriveNextSecret(secret0, 1); + will(returnValue(secret1.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(crypto).deriveNextSecret(secret2, 3); + will(returnValue(secret3.clone())); + oneOf(db).addSecrets(Arrays.asList(s3)); + // The recogniser should derive the tags for period 1 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret1, false); + will(returnValue(k1)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k1), with((long) i)); + will(new EncodeTagAction()); + oneOf(k1).getEncoded(); + will(returnValue(key1)); + } + oneOf(k1).erase(); + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // The recogniser should derive the tags for period 3 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret3, false); + will(returnValue(k3)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k3), with((long) i)); + will(new EncodeTagAction()); + oneOf(k3).getEncoded(); + will(returnValue(key3)); + } + oneOf(k3).erase(); + // Start the timer + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // stop() + // The recogniser should derive the tags for period 1 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret1, false); + will(returnValue(k1)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k1), with((long) i)); + will(new EncodeTagAction()); + oneOf(k1).getEncoded(); + will(returnValue(key1)); + } + oneOf(k1).erase(); + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // The recogniser should remove the tags for period 3 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret3, false); + will(returnValue(k3)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k3), with((long) i)); + will(new EncodeTagAction()); + oneOf(k3).getEncoded(); + will(returnValue(key3)); + } + oneOf(k3).erase(); + // Remove the listener and stop the timer + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + }}); + + assertTrue(keyManager.start()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + @Test + public void testLoadSecretsAtEndOfPeriod3() throws Exception { + Mockery context = new Mockery(); + final CryptoComponent crypto = context.mock(CryptoComponent.class); + final DatabaseComponent db = context.mock(DatabaseComponent.class); + final Clock clock = context.mock(Clock.class); + final Timer timer = context.mock(Timer.class); + final ErasableKey k2 = context.mock(ErasableKey.class, "k2"); + final ErasableKey k3 = context.mock(ErasableKey.class, "k3"); + final ErasableKey k4 = context.mock(ErasableKey.class, "k4"); + + final ConnectionRecogniser connectionRecogniser = + new ConnectionRecogniserImpl(crypto, db); + final KeyManagerImpl keyManager = new KeyManagerImpl(crypto, db, + connectionRecogniser, clock, timer); + + // The DB contains the secrets for periods 0 - 2 + Endpoint ep = new Endpoint(contactId, transportId, EPOCH, true); + final TemporarySecret s0 = new TemporarySecret(ep, 0, secret0.clone()); + final TemporarySecret s1 = new TemporarySecret(ep, 1, secret1.clone()); + final TemporarySecret s2 = new TemporarySecret(ep, 2, secret2.clone()); + // The secrets for periods 3 and 4 should be derived and stored + final TemporarySecret s3 = new TemporarySecret(ep, 3, secret3.clone()); + final TemporarySecret s4 = new TemporarySecret(ep, 4, secret4.clone()); + + context.checking(new Expectations() {{ + // start() + oneOf(db).addListener(with(any(DatabaseListener.class))); + oneOf(db).getSecrets(); + will(returnValue(Arrays.asList(s0, s1, s2))); + oneOf(db).getTransportLatencies(); + will(returnValue(Collections.singletonMap(transportId, + MAX_LATENCY))); + // The current time is the end of period 3 + oneOf(clock).currentTimeMillis(); + will(returnValue(EPOCH + 3 * ROTATION_PERIOD_LENGTH - 1)); + // The secrets for periods 3 and 4 should be derived from secret 0 + oneOf(crypto).deriveNextSecret(secret0, 1); + will(returnValue(secret1.clone())); + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(crypto).deriveNextSecret(secret2, 3); + will(returnValue(secret3.clone())); + oneOf(crypto).deriveNextSecret(secret3, 4); + will(returnValue(secret4.clone())); + // The secrets for periods 3 and 4 should be derived from secret 1 + oneOf(crypto).deriveNextSecret(secret1, 2); + will(returnValue(secret2.clone())); + oneOf(crypto).deriveNextSecret(secret2, 3); + will(returnValue(secret3.clone())); + oneOf(crypto).deriveNextSecret(secret3, 4); + will(returnValue(secret4.clone())); + // One copy of each of the new secrets should be stored + oneOf(db).addSecrets(Arrays.asList(s3, s4)); + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // The recogniser should derive the tags for period 3 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret3, false); + will(returnValue(k3)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k3), with((long) i)); + will(new EncodeTagAction()); + oneOf(k3).getEncoded(); + will(returnValue(key3)); + } + oneOf(k3).erase(); + // The recogniser should derive the tags for period 4 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret4, false); + will(returnValue(k4)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k4), with((long) i)); + will(new EncodeTagAction()); + oneOf(k4).getEncoded(); + will(returnValue(key4)); + } + oneOf(k4).erase(); + // Start the timer + oneOf(timer).scheduleAtFixedRate(with(keyManager), + with(any(long.class)), with(any(long.class))); + // stop() + // The recogniser should derive the tags for period 2 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret2, false); + will(returnValue(k2)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k2), with((long) i)); + will(new EncodeTagAction()); + oneOf(k2).getEncoded(); + will(returnValue(key2)); + } + oneOf(k2).erase(); + // The recogniser should remove the tags for period 3 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret3, false); + will(returnValue(k3)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k3), with((long) i)); + will(new EncodeTagAction()); + oneOf(k3).getEncoded(); + will(returnValue(key3)); + } + oneOf(k3).erase(); + // The recogniser should derive the tags for period 4 + oneOf(crypto).getTagCipher(); + will(returnValue(cipher)); + oneOf(crypto).deriveTagKey(secret4, false); + will(returnValue(k4)); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), with(cipher), + with(k4), with((long) i)); + will(new EncodeTagAction()); + oneOf(k4).getEncoded(); + will(returnValue(key4)); + } + oneOf(k4).erase(); + // Remove the listener and stop the timer + oneOf(db).removeListener(with(any(DatabaseListener.class))); + oneOf(timer).cancel(); + }}); + + assertTrue(keyManager.start()); + keyManager.stop(); + + context.assertIsSatisfied(); + } + + private void encodeTag(byte[] tag, byte[] rawKey, long connection) { + // Encode a fake tag based on the key and connection number + System.arraycopy(rawKey, 0, tag, 0, tag.length); + ByteUtils.writeUint32(connection, tag, 0); + } + + private class EncodeTagAction implements Action { + + public void describeTo(Description description) { + description.appendText("Encodes a tag"); + } + + public Object invoke(Invocation invocation) throws Throwable { + byte[] tag = (byte[]) invocation.getParameter(0); + ErasableKey key = (ErasableKey) invocation.getParameter(2); + long connection = (Long) invocation.getParameter(3); + byte[] rawKey = key.getEncoded(); + encodeTag(tag, rawKey, connection); + return null; + } + } +} diff --git a/briar-tests/src/net/sf/briar/transport/TransportConnectionRecogniserTest.java b/briar-tests/src/net/sf/briar/transport/TransportConnectionRecogniserTest.java index dddcc83d07c4ddf94f66ba10d479271706bc9646..fe6bd2cf567889d179adaf94977c3038ca286a20 100644 --- a/briar-tests/src/net/sf/briar/transport/TransportConnectionRecogniserTest.java +++ b/briar-tests/src/net/sf/briar/transport/TransportConnectionRecogniserTest.java @@ -31,12 +31,12 @@ public class TransportConnectionRecogniserTest extends BriarTestCase { private final ContactId contactId = new ContactId(234); private final TransportId transportId = new TransportId(TestUtils.getRandomId()); + private final Cipher tagCipher = new NullCipher(); @Test public void testAddAndRemoveSecret() { Mockery context = new Mockery(); final CryptoComponent crypto = context.mock(CryptoComponent.class); - final Cipher tagCipher = new NullCipher(); final byte[] secret = new byte[32]; new Random().nextBytes(secret); final boolean alice = false; @@ -48,18 +48,22 @@ public class TransportConnectionRecogniserTest extends BriarTestCase { will(returnValue(tagCipher)); oneOf(crypto).deriveTagKey(secret, !alice); will(returnValue(tagKey)); - exactly(16).of(crypto).encodeTag(with(any(byte[].class)), - with(tagCipher), with(tagKey), with(any(long.class))); - will(new EncodeTagAction()); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), + with(tagCipher), with(tagKey), with((long) i)); + will(new EncodeTagAction()); + } oneOf(tagKey).erase(); // Remove secret oneOf(crypto).getTagCipher(); will(returnValue(tagCipher)); oneOf(crypto).deriveTagKey(secret, !alice); will(returnValue(tagKey)); - exactly(16).of(crypto).encodeTag(with(any(byte[].class)), - with(tagCipher), with(tagKey), with(any(long.class))); - will(new EncodeTagAction()); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), + with(tagCipher), with(tagKey), with((long) i)); + will(new EncodeTagAction()); + } oneOf(tagKey).erase(); }}); TemporarySecret s = new TemporarySecret(contactId, transportId, 123, @@ -77,7 +81,6 @@ public class TransportConnectionRecogniserTest extends BriarTestCase { public void testAcceptConnection() throws Exception { Mockery context = new Mockery(); final CryptoComponent crypto = context.mock(CryptoComponent.class); - final Cipher tagCipher = new NullCipher(); final byte[] secret = new byte[32]; new Random().nextBytes(secret); final boolean alice = false; @@ -89,11 +92,13 @@ public class TransportConnectionRecogniserTest extends BriarTestCase { will(returnValue(tagCipher)); oneOf(crypto).deriveTagKey(secret, !alice); will(returnValue(tagKey)); - exactly(16).of(crypto).encodeTag(with(any(byte[].class)), - with(tagCipher), with(tagKey), with(any(long.class))); - will(new EncodeTagAction()); + for(int i = 0; i < 16; i++) { + oneOf(crypto).encodeTag(with(any(byte[].class)), + with(tagCipher), with(tagKey), with((long) i)); + will(new EncodeTagAction()); + } oneOf(tagKey).erase(); - // Accept connection + // Accept connection 0 oneOf(crypto).getTagCipher(); will(returnValue(tagCipher)); oneOf(crypto).deriveTagKey(secret, !alice); @@ -134,6 +139,7 @@ public class TransportConnectionRecogniserTest extends BriarTestCase { public Object invoke(Invocation invocation) throws Throwable { byte[] tag = (byte[]) invocation.getParameter(0); long connection = (Long) invocation.getParameter(3); + // Encode a fake tag based on the connection number ByteUtils.writeUint32(connection, tag, 0); return null; }